diff --git a/CMakeLists.txt b/CMakeLists.txt index 5aa5c69..d34cc3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ project(xbot LANGUAGES C CXX ) -find_package(Boost REQUIRED COMPONENTS log) +find_package(Boost 1.83.0 CONFIG COMPONENTS log) find_package(PkgConfig REQUIRED) pkg_check_modules(LIBHS libhs REQUIRED IMPORTED_TARGET) @@ -37,10 +37,13 @@ add_custom_command( DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/irc_commands.gperf VERBATIM) +add_subdirectory(mybase64) + add_executable(xbot main.cpp irc_commands.inc ircmsg.cpp settings.cpp connection.cpp snote_thread.cpp watchdog_thread.cpp write_irc.cpp ping_thread.cpp irc_parse_thread.cpp registration_thread.cpp - self_thread.cpp command_thread.cpp priv_thread.cpp) + self_thread.cpp command_thread.cpp priv_thread.cpp + sasl_thread.cpp) target_include_directories(xbot PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -target_link_libraries(xbot PRIVATE Boost::log Boost::headers tomlplusplus_tomlplusplus eventpp PkgConfig::LIBHS) +target_link_libraries(xbot PRIVATE Boost::log Boost::headers tomlplusplus_tomlplusplus eventpp PkgConfig::LIBHS mybase64) diff --git a/command_thread.hpp b/command_thread.hpp index 38ec1ba..d1b7e66 100644 --- a/command_thread.hpp +++ b/command_thread.hpp @@ -8,11 +8,20 @@ class Connection; struct CommandEvent : Event { + /// @brief oper account of sender std::string_view oper; + + /// @brief nickserv acccount of sender std::string_view account; + + /// @brief nickname of sender std::string_view nick; + + /// @brief command name excluding sigil std::string_view command; - std::string_view arg; + + /// @brief complete argument excluding space after command + std::string_view arg; }; struct CommandThread diff --git a/main.cpp b/main.cpp index 6f9195d..c70ae0b 100644 --- a/main.cpp +++ b/main.cpp @@ -81,7 +81,8 @@ auto start(boost::asio::io_context & io, Settings const& settings) -> void [&io, &settings](std::exception_ptr e) { auto timer = std::make_shared(io); - timer->expires_from_now(5s); + + timer->expires_after(5s); timer->async_wait([&io, &settings, timer](auto) { start(io, settings); }); diff --git a/mybase64/CMakeLists.txt b/mybase64/CMakeLists.txt new file mode 100644 index 0000000..d5c6eab --- /dev/null +++ b/mybase64/CMakeLists.txt @@ -0,0 +1,2 @@ +add_library(mybase64 STATIC mybase64.cpp) +target_include_directories(mybase64 PUBLIC include) diff --git a/mybase64/include/mybase64.hpp b/mybase64/include/mybase64.hpp new file mode 100644 index 0000000..eb2ef11 --- /dev/null +++ b/mybase64/include/mybase64.hpp @@ -0,0 +1,44 @@ +/** + * @file mybase64.hpp + * @author Eric Mertens (emertens@gmail.com) + * @brief Base64 encoding and decoding + * + */ +#pragma once + +#include +#include + +namespace mybase64 +{ + +inline constexpr auto encoded_size(std::size_t len) -> std::size_t +{ + return (len + 2) / 3 * 4; +} + +inline constexpr auto decoded_size(std::size_t len) -> std::size_t +{ + return (len + 3) / 4 * 3; +} + +/** + * @brief Encode a string into base64 + * + * @param input input text + * @param output Target buffer for encoded value + */ +auto encode(std::string_view input, char* output) -> void; + +/** + * @brief Decode a base64 encoded string + * + * @param input Base64 input text + * @param output Target buffer for decoded value + * @param outlen Output parameter for decoded length + * @return true success + * @return false failure + */ +auto decode(std::string_view input, char* output, std::size_t* outlen) -> bool; + +} // namespace diff --git a/mybase64/mybase64.cpp b/mybase64/mybase64.cpp new file mode 100644 index 0000000..a48c9d0 --- /dev/null +++ b/mybase64/mybase64.cpp @@ -0,0 +1,113 @@ +#include "mybase64.hpp" + +#include +#include +#include + +namespace mybase64 +{ + +static_assert(CHAR_BIT == 8); + +auto encode(std::string_view const input, char* output) -> void +{ + static char const* const alphabet = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + auto cursor = std::begin(input); + auto const end = std::end(input); + + while (end - cursor >= 3) + { + uint32_t buffer = uint8_t(*cursor++); + buffer <<= 8; buffer |= uint8_t(*cursor++); + buffer <<= 8; buffer |= uint8_t(*cursor++); + + *output++ = alphabet[(buffer >> 6 * 3) % 64]; + *output++ = alphabet[(buffer >> 6 * 2) % 64]; + *output++ = alphabet[(buffer >> 6 * 1) % 64]; + *output++ = alphabet[(buffer >> 6 * 0) % 64]; + } + + if (cursor < end) + { + uint32_t buffer = uint8_t(*cursor++) << 10; + if (cursor < end) buffer |= uint8_t(*cursor) << 2; + + *output++ = alphabet[(buffer >> 12) % 64]; + *output++ = alphabet[(buffer >> 6) % 64]; + *output++ = cursor < end ? alphabet[(buffer % 64)] : '='; + *output++ = '='; + } + *output = '\0'; +} + +auto decode(std::string_view const input, char* const output, std::size_t* const outlen) -> bool +{ + static int8_t const alphabet_values[] = { + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, 0x3e, -1, -1, -1, 0x3f, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, + 0x3c, 0x3d, -1, -1, -1, -1, -1, -1, + -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, + 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, -1, -1, -1, -1, -1, + -1, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, + 0x31, 0x32, 0x33, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + }; + + uint32_t buffer = 1; + char* cursor = output; + + for (char c : input) { + int8_t const value = alphabet_values[uint8_t(c)]; + if (-1 == value) continue; + + buffer = (buffer << 6) | value; + + if (buffer & 1<<6*4) { + *cursor++ = buffer >> 8*2; + *cursor++ = buffer >> 8*1; + *cursor++ = buffer >> 8*0; + buffer = 1; + } + } + + if (buffer & 1<<6*3) { + *cursor++ = buffer >> 10; + *cursor++ = buffer >> 2; + } else if (buffer & 1<<6*2) { + *cursor++ = buffer >> 4; + } else if (buffer & 1<<6*1) { + return false; + } + *outlen = cursor - output; + return true; +} + +} // namespace diff --git a/registration_thread.cpp b/registration_thread.cpp index 6c2c1dd..d8e0a08 100644 --- a/registration_thread.cpp +++ b/registration_thread.cpp @@ -30,8 +30,6 @@ auto RegistrationThread::on_connect() -> void send_pass(connection_, password_); send_user(connection_, username_, realname_); send_nick(connection_, nickname_); - - connection_.remove_listener(connect_handle_); } auto RegistrationThread::send_req() -> void @@ -53,7 +51,7 @@ auto RegistrationThread::send_req() -> void "solanum.chat/realhost", }; - for (auto cap : want) + for (auto const cap : want) { if (caps.contains(cap)) { @@ -62,9 +60,7 @@ auto RegistrationThread::send_req() -> void outstanding.insert(cap); } } - - connection_.remove_listener(message_handle_); - + if (not outstanding.empty()) { request.pop_back(); @@ -90,8 +86,8 @@ auto RegistrationThread::on_msg_cap_ack(IrcMsg const& msg) -> void ); if (outstanding.empty()) { - send_cap_end(connection_); connection_.remove_listener(message_handle_); + send_cap_end(connection_); } } @@ -135,6 +131,7 @@ auto RegistrationThread::on_msg_cap_ls(IrcMsg const& msg) -> void if (last) { + connection_.remove_listener(message_handle_); send_req(); } } @@ -153,6 +150,7 @@ auto RegistrationThread::start( thread->connect_handle_ = connection.add_listener([thread](ConnectEvent const&) { + thread->connection_.remove_listener(thread->connect_handle_); thread->on_connect(); }); diff --git a/sasl_thread.cpp b/sasl_thread.cpp new file mode 100644 index 0000000..11b3c4c --- /dev/null +++ b/sasl_thread.cpp @@ -0,0 +1,55 @@ +#include "sasl_thread.hpp" + +#include + +#include + +#include "connection.hpp" +#include "write_irc.hpp" +#include "irc_parse_thread.hpp" +#include "ircmsg.hpp" + +auto SaslThread::start(Connection& connection) -> std::shared_ptr +{ + auto thread = std::make_shared(connection); + + connection.add_listener([thread](IrcMsgEvent const& event){ + if (event.command == IrcCommand::AUTHENTICATE) + { + thread->on_authenticate(event.irc.args[0]); + } + }); + + return thread; +} + +auto SaslThread::on_authenticate(std::string_view chunk) -> void +{ + if (chunk != "+") { + buffer_ += chunk; + } + + if (chunk.size() != 400) + { + std::string decoded; + decoded.resize(mybase64::decoded_size(buffer_.size())); + std::size_t len; + + if (mybase64::decode(buffer_, decoded.data(), &len)) + { + decoded.resize(len); + connection_.make_event(std::move(decoded)); + } else { + BOOST_LOG_TRIVIAL(debug) << "Invalid AUTHENTICATE base64"; + send_authenticate(connection_, "*"); // abort SASL + } + + buffer_.clear(); + } + else if (buffer_.size() > MAX_BUFFER) + { + BOOST_LOG_TRIVIAL(debug) << "AUTHENTICATE buffer overflow"; + buffer_.clear(); + send_authenticate(connection_, "*"); // abort SASL + } +} diff --git a/sasl_thread.hpp b/sasl_thread.hpp new file mode 100644 index 0000000..43b7944 --- /dev/null +++ b/sasl_thread.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include +#include +#include +#include + +#include "event.hpp" + +struct Connection; + +class SaslMechanism +{ +public: + virtual ~SaslMechanism() {} + + virtual auto mechanism_name() const -> std::string = 0; + virtual auto step(std::string_view msg) -> std::optional = 0; + virtual auto is_complete() const -> bool = 0; +}; + +class SaslPlain final : public SaslMechanism +{ + std::string authcid_; + std::string authzid_; + std::string password_; + bool complete_; + +public: + SaslPlain(std::string username, std::string password) + : authcid_{std::move(username)} + , password_{std::move(password)} + , complete_{false} + {} + + auto mechanism_name() const -> std::string override + { + return "PLAIN"; + } + + auto step(std::string_view msg) -> std::optional override { + if (complete_) { + return {}; + } else { + std::string reply; + + reply += authzid_; + reply += '\0'; + reply += authcid_; + reply += '\0'; + reply += password_; + + complete_ = true; + + return {std::move(reply)}; + } + } + + auto is_complete() const -> bool override + { + return complete_; + } +}; + +struct SaslComplete : Event +{ + bool success; +}; + +struct SaslMessage : Event +{ + SaslMessage(std::string message) : message{std::move(message)} {} + std::string message; +}; + +class SaslThread +{ + Connection& connection_; + std::string buffer_; + + const std::size_t MAX_BUFFER = 1024; + +public: + SaslThread(Connection& connection) : connection_{connection} {} + + static auto start(Connection& connection) -> std::shared_ptr; + + auto on_authenticate(std::string_view) -> void; +}; diff --git a/self_thread.cpp b/self_thread.cpp index 2a23a7d..2efcc18 100644 --- a/self_thread.cpp +++ b/self_thread.cpp @@ -1,7 +1,5 @@ #include "self_thread.hpp" -#include - #include "connection.hpp" #include "ircmsg.hpp" #include "irc_parse_thread.hpp" diff --git a/watchdog_thread.cpp b/watchdog_thread.cpp index dd71f42..d8dd5fa 100644 --- a/watchdog_thread.cpp +++ b/watchdog_thread.cpp @@ -19,7 +19,7 @@ WatchdogThread::WatchdogThread(Connection& connection) auto WatchdogThread::on_activity() -> void { stalled_ = false; - timer_.expires_from_now(WatchdogThread::TIMEOUT); + timer_.expires_after(WatchdogThread::TIMEOUT); } auto WatchdogThread::start_timer() @@ -46,7 +46,7 @@ auto WatchdogThread::on_timeout() -> void { send_ping(connection_, "watchdog"); stalled_ = true; - timer_.expires_from_now(WatchdogThread::TIMEOUT); + timer_.expires_after(WatchdogThread::TIMEOUT); start_timer(); } } diff --git a/write_irc.cpp b/write_irc.cpp index 19002e1..55a0e4a 100644 --- a/write_irc.cpp +++ b/write_irc.cpp @@ -100,3 +100,8 @@ auto send_notice(Connection& connection, std::string_view target, std::string_vi { write_irc(connection, "NOTICE", target, message); } + +auto send_authenticate(Connection& connection, std::string_view message) -> void +{ + write_irc(connection, "AUTHENTICATE", message); +} diff --git a/write_irc.hpp b/write_irc.hpp index 4a21ebe..e62600f 100644 --- a/write_irc.hpp +++ b/write_irc.hpp @@ -14,3 +14,4 @@ auto send_cap_end(Connection&) -> void; auto send_cap_req(Connection&, std::string_view) -> void; auto send_privmsg(Connection&, std::string_view, std::string_view) -> void; auto send_notice(Connection&, std::string_view, std::string_view) -> void; +auto send_authenticate(Connection& connection, std::string_view message) -> void;