From a6b6a4179c4510042e8b7ce235cab97c40d1c536 Mon Sep 17 00:00:00 2001 From: Eric Mertens Date: Wed, 22 Nov 2023 19:59:34 -0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 6 + CMakeLists.txt | 24 ++++ CMakePresets.json | 24 ++++ connection.cpp | 108 +++++++++++++++++ connection.hpp | 97 +++++++++++++++ irc_thread.hpp | 27 +++++ ircmsg.cpp | 149 +++++++++++++++++++++++ ircmsg.hpp | 62 ++++++++++ linebuffer.hpp | 97 +++++++++++++++ main.cpp | 302 ++++++++++++++++++++++++++++++++++++++++++++++ settings.cpp | 17 +++ settings.hpp | 17 +++ 12 files changed, 930 insertions(+) create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 CMakePresets.json create mode 100644 connection.cpp create mode 100644 connection.hpp create mode 100644 irc_thread.hpp create mode 100644 ircmsg.cpp create mode 100644 ircmsg.hpp create mode 100644 linebuffer.hpp create mode 100644 main.cpp create mode 100644 settings.cpp create mode 100644 settings.hpp diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f58b540 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +/out +/config.toml +/.ccls +/archive +/.vscode +/compile_commands.json diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..19ffbf7 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 3.13) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 20) +project(xbot + VERSION 1 + LANGUAGES C CXX +) + +find_package(PkgConfig REQUIRED) + +pkg_check_modules(LIBIDN IMPORTED_TARGET libidn) +find_package(Boost REQUIRED) +find_package(OpenSSL REQUIRED) + +include(FetchContent) +FetchContent_Declare( + tomlplusplus + GIT_REPOSITORY https://github.com/marzer/tomlplusplus.git + GIT_TAG v3.4.0 +) +FetchContent_MakeAvailable(tomlplusplus) + +add_executable(xbot main.cpp ircmsg.cpp settings.cpp connection.cpp) +target_link_libraries(xbot PRIVATE Boost::headers OpenSSL::SSL tomlplusplus_tomlplusplus) diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 0000000..49b13c6 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,24 @@ +{ + "version": 2, + "configurePresets": [ + { + "name": "arm-mac", + "displayName": "Configure preset using toolchain file", + "description": "Sets Ninja generator, build and install directory", + "generator": "Ninja", + "binaryDir": "${sourceDir}/out/build/${presetName}", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_TOOLCHAIN_FILE": "", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}", + "CMAKE_EXPORT_COMPILE_COMMANDS": "On" + } + } + ], + "buildPresets": [ + { + "name": "arm-mac", + "configurePreset": "arm-mac" + } + ] +} diff --git a/connection.cpp b/connection.cpp new file mode 100644 index 0000000..817718b --- /dev/null +++ b/connection.cpp @@ -0,0 +1,108 @@ +#include "connection.hpp" + +auto Connection::writer_() -> void +{ + std::vector buffers; + buffers.reserve(write_strings_.size()); + for (auto const& elt : write_strings_) + { + buffers.push_back(boost::asio::buffer(elt)); + } + boost::asio::async_write( + stream_, + buffers, + [weak = weak_from_this() + ,strings = std::move(write_strings_) + ](boost::system::error_code const& error, std::size_t) + { + if (not error) + { + if (auto self = weak.lock()) + { + self->writer(); + } + } + }); + write_strings_.clear(); +} + +auto Connection::writer() -> void +{ + if (write_strings_.empty()) + { + write_timer_.async_wait([weak = weak_from_this()](auto){ + if (auto self = weak.lock()) + { + if (not self->write_strings_.empty()) + { + self->writer_(); + } + } + }); + } + else + { + writer_(); + } +} + +auto Connection::connect( + boost::asio::io_context & io, + Settings settings) +-> boost::asio::awaitable +{ + auto self = shared_from_this(); + + { + auto resolver = boost::asio::ip::tcp::resolver{io}; + auto const endpoints = co_await resolver.async_resolve(settings.host, settings.service, boost::asio::use_awaitable); + auto const endpoint = co_await boost::asio::async_connect(stream_, endpoints, boost::asio::use_awaitable); + + self->writer(); + dispatch(&IrcThread::on_connect); + } + + for(LineBuffer buffer{32'768};;) + { + boost::system::error_code error; + auto const n = co_await stream_.async_read_some(buffer.get_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, error)); + if (error) + { + break; + } + buffer.add_bytes(n, [this](char * line) { + dispatch(&IrcThread::on_msg, parse_irc_message(line)); + }); + } + + dispatch(&IrcThread::on_disconnect); +} + +auto Connection::write(std::string message) -> void +{ + auto const need_cancel = write_strings_.empty(); + + message += "\r\n"; + write_strings_.push_back(std::move(message)); + + if (need_cancel) + { + write_timer_.cancel_one(); + } +} + +auto Connection::write(std::string front, std::string_view last) -> void + { + auto const is_invalid = [](char x) -> bool { + return x == '\0' || x == '\r' || x == '\n'; + }; + + if (last.end() != std::find_if(last.begin(), last.end(), is_invalid)) + { + throw std::runtime_error{"bad irc argument"}; + } + + front += " :"; + front += last; + write(std::move(front)); + } diff --git a/connection.hpp b/connection.hpp new file mode 100644 index 0000000..cbc325e --- /dev/null +++ b/connection.hpp @@ -0,0 +1,97 @@ +#pragma once + +#include "irc_thread.hpp" +#include "ircmsg.hpp" +#include "linebuffer.hpp" +#include "settings.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class Connection : public std::enable_shared_from_this +{ + boost::asio::ip::tcp::socket stream_; + boost::asio::steady_timer write_timer_; + + std::list write_strings_; + std::vector> threads_; + + auto writer() -> void; + auto writer_() -> void; + + template + auto dispatch( + IrcThread::callback_result (IrcThread::* method)(Args...), + Args... args + ) -> void + { + std::vector> work; + work.swap(threads_); + std::sort(work.begin(), work.end(), [](auto const& a, auto const& b) { return a->priority() < b->priority(); }); + + std::size_t const n = work.size(); + for (std::size_t i = 0; i < n; i++) + { + auto const [thread_outcome, msg_outcome] = (work[i].get()->*method)(args...); + if (thread_outcome == ThreadOutcome::Continue) + { + threads_.push_back(std::move(work[i])); + } + if (msg_outcome == EventOutcome::Consume) + { + std::move(work.begin() + i + 1, work.end(), std::back_inserter(threads_)); + break; + } + } + } + +public: + Connection(boost::asio::io_context & io) + : stream_{io} + , write_timer_{io, std::chrono::steady_clock::time_point::max()} + { + } + + auto listen(std::unique_ptr thread) -> void + { + threads_.push_back(std::move(thread)); + } + + auto write(std::string front, std::string_view last) -> void; + + template + auto write(std::string front, std::string_view next, Args ...rest) -> void + { + auto const is_invalid = [](char x) -> bool { + return x == '\0' || x == '\r' || x == '\n' || x == ' '; + }; + + if (next.empty() || next.end() != std::find_if(next.begin(), next.end(), is_invalid)) + { + throw std::runtime_error{"bad irc argument"}; + } + + front += " "; + front += next; + write(std::move(front), rest...); + } + + auto write(std::string message) -> void; + + auto connect( + boost::asio::io_context & io, + Settings settings + ) -> boost::asio::awaitable; +}; + diff --git a/irc_thread.hpp b/irc_thread.hpp new file mode 100644 index 0000000..7da32f5 --- /dev/null +++ b/irc_thread.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "ircmsg.hpp" + +enum class EventOutcome +{ + Pass, + Consume, +}; + +enum class ThreadOutcome +{ + Continue, + Finish, +}; + +struct IrcThread +{ + using priority_type = std::uint64_t; + using callback_result = std::pair; + virtual ~IrcThread() {} + + virtual auto on_connect() -> callback_result { return {}; } + virtual auto on_disconnect() -> callback_result { return {}; }; + virtual auto on_msg(ircmsg const&) -> callback_result { return {}; }; + virtual auto priority() const -> priority_type = 0; +}; diff --git a/ircmsg.cpp b/ircmsg.cpp new file mode 100644 index 0000000..fa3542f --- /dev/null +++ b/ircmsg.cpp @@ -0,0 +1,149 @@ +#include +#include + +#include "ircmsg.hpp" + +namespace { +class parser { + char* msg_; + inline static char empty[1]; + + inline void trim() { + while (*msg_ == ' ') msg_++; + } + +public: + parser(char* msg) : msg_(msg) { + if (msg_ == nullptr) { + msg_ = empty; + } else { + trim(); + } + } + + char* word() { + auto const start = msg_; + while (*msg_ != '\0' && *msg_ != ' ') msg_++; + if (*msg_ != '\0') { // prepare for next token + *msg_++ = '\0'; + trim(); + } + return start; + } + + bool match(char c) { + if (c == *msg_) { + msg_++; + return true; + } + return false; + } + + bool isempty() const { + return *msg_ == '\0'; + } + + char const* peek() { + return msg_; + } +}; + +std::string_view unescape_tag_value(char* const val) +{ + // only start copying at the first escape character + // skip everything before that + auto cursor = strchr(val, '\\'); + if (cursor == nullptr) { return {val}; } + + auto write = cursor; + for (; *cursor; cursor++) + { + if (*cursor == '\\') + { + cursor++; + switch (*cursor) + { + default : *write++ = *cursor; break; + case ':' : *write++ = ';' ; break; + case 's' : *write++ = ' ' ; break; + case 'r' : *write++ = '\r' ; break; + case 'n' : *write++ = '\n' ; break; + case '\0': return {val, write}; + } + } + else + { + *write++ = *cursor; + } + } + return {val, write}; +} + +} // namespace + +auto parse_irc_tags(char* str) -> std::vector +{ + std::vector tags; + + do { + auto val = strsep(&str, ";"); + auto key = strsep(&val, "="); + if ('\0' == *key) { + throw irc_parse_error(irc_error_code::MISSING_TAG); + } + if (nullptr == val) { + tags.emplace_back(key, ""); + } else { + tags.emplace_back(std::string_view{key, val-1}, unescape_tag_value(val)); + } + } while(nullptr != str); + + return tags; +} + +auto parse_irc_message(char* msg) -> ircmsg +{ + parser p {msg}; + ircmsg out; + + /* MESSAGE TAGS */ + if (p.match('@')) { + out.tags = parse_irc_tags(p.word()); + } + + /* MESSAGE SOURCE */ + if (p.match(':')) { + out.source = p.word(); + } + + /* MESSAGE COMMANDS */ + out.command = p.word(); + if (out.command.empty()) { + throw irc_parse_error{irc_error_code::MISSING_COMMAND}; + } + + /* MESSAGE ARGUMENTS */ + while (!p.isempty()) { + if (p.match(':')) { + out.args.push_back(p.peek()); + break; + } + out.args.push_back(p.word()); + } + + return out; +} + +auto ircmsg::hassource() const -> bool +{ + return source.data() != nullptr; +} + +auto operator<<(std::ostream& out, irc_error_code code) -> std::ostream& +{ + switch(code) { + case irc_error_code::MISSING_COMMAND: out << "MISSING COMMAND"; return out; + case irc_error_code::MISSING_TAG: out << "MISSING TAG"; return out; + default: return out; + } +} diff --git a/ircmsg.hpp b/ircmsg.hpp new file mode 100644 index 0000000..e50b385 --- /dev/null +++ b/ircmsg.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include + +struct irctag +{ + std::string_view key; + std::string_view val; + + irctag(std::string_view key, std::string_view val) : key{key}, val{val} {} + + friend auto operator==(irctag const&, irctag const&) -> bool = default; +}; + +struct ircmsg +{ + std::vector tags; + std::vector args; + std::string_view source; + std::string_view command; + + ircmsg() = default; + + ircmsg( + std::vector && tags, + std::string_view source, + std::string_view command, + std::vector && args) + : tags(std::move(tags)), + args(std::move(args)), + source{source}, + command{command} {} + + bool hassource() const; + + friend bool operator==(ircmsg const&, ircmsg const&) = default; +}; + +enum class irc_error_code { + MISSING_TAG, + MISSING_COMMAND, +}; + +auto operator<<(std::ostream& out, irc_error_code) -> std::ostream&; + +struct irc_parse_error : public std::exception { + irc_error_code code; + irc_parse_error(irc_error_code code) : code(code) {} +}; + +/** + * Parses the given IRC message into a structured format. + * The original message is mangled to store string fragments + * that are pointed to by the structured message type. + * + * Returns zero for success, non-zero for parse error. + */ +auto parse_irc_message(char* msg) -> ircmsg; + +auto parse_irc_tags(char* msg) -> std::vector; diff --git a/linebuffer.hpp b/linebuffer.hpp new file mode 100644 index 0000000..a832ec0 --- /dev/null +++ b/linebuffer.hpp @@ -0,0 +1,97 @@ +#pragma once +/** + * @file linebuffer.hpp + * @author Eric Mertens + * @brief A line buffering class + * @version 0.1 + * @date 2023-08-22 + * + * @copyright Copyright (c) 2023 + * + */ + +#include + +#include +#include +#include + +/** + * @brief Fixed-size buffer with line-oriented dispatch + * + */ +class LineBuffer +{ + std::vector buffer; + + // [buffer.begin(), end_) contains buffered data + // [end_, buffer.end()) is available buffer space + std::vector::iterator end_; + +public: + /** + * @brief Construct a new Line Buffer object + * + * @param n Buffer size + */ + LineBuffer(std::size_t n) : buffer(n), end_{buffer.begin()} {} + + /** + * @brief Get the available buffer space + * + * @return boost::asio::mutable_buffer + */ + auto get_buffer() -> boost::asio::mutable_buffer + { + return boost::asio::buffer(&*end_, std::distance(end_, buffer.end())); + } + + /** + * @brief Commit new buffer bytes and dispatch line callback + * + * The first n bytes of the buffer will be considered to be + * populated. The line callback function will be called once + * per completed line. Those lines are removed from the buffer + * and the is ready for additional calls to get_buffer and + * add_bytes. + * + * @param n Bytes written to the last call of get_buffer + * @param line_cb Callback function to run on each completed line + */ + auto add_bytes(std::size_t n, std::invocable auto line_cb) -> void + { + auto const start = end_; + std::advance(end_, n); + + // new data is now located in [start, end_) + + // cursor marks the beginning of the current line + auto cursor = buffer.begin(); + + for (auto nl = std::find(start, end_, '\n'); + nl != end_; + nl = std::find(cursor, end_, '\n')) + { + // Null-terminate the line. Support both \n and \r\n + if (cursor < nl && *std::prev(nl) == '\r') + { + *std::prev(nl) = '\0'; + } + else + { + *nl = '\0'; + } + + line_cb(&*cursor); + + cursor = std::next(nl); + } + + // If any lines were processed, move all processed lines to + // the front of the buffer + if (cursor != buffer.begin()) + { + end_ = std::move(cursor, end_, buffer.begin()); + } + } +}; diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..3c39a7b --- /dev/null +++ b/main.cpp @@ -0,0 +1,302 @@ +#include +#include + +#include "linebuffer.hpp" +#include "ircmsg.hpp" +#include "settings.hpp" +#include "irc_thread.hpp" +#include "connection.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std::chrono_literals; + + +struct ChatThread : public IrcThread +{ + auto priority() const -> priority_type override + { + return 100; + } + auto on_msg(ircmsg const& irc) -> std::pair override + { + if (irc.command == "PRIVMSG" && 2 == irc.args.size()) + { + std::cout << "Chat from " << irc.source << ": " << irc.args[1] << std::endl; + return {ThreadOutcome::Continue, EventOutcome::Pass}; + } + else + { + return {ThreadOutcome::Continue, EventOutcome::Pass}; + } + } +}; + +struct UnhandledThread : public IrcThread +{ + auto priority() const -> priority_type override + { + return std::numeric_limits::max(); + } + + auto on_msg(ircmsg const& irc) -> std::pair override + { + std::cout << "Unhandled message " << irc.command; + for (auto const arg : irc.args) + { + std::cout << " " << arg; + } + std::cout << "\n"; + return {ThreadOutcome::Continue, EventOutcome::Pass}; + } +}; + +class PingThread : public IrcThread +{ + Connection * connection_; + +public: + PingThread(Connection * connection) noexcept : connection_{connection} {} + + auto priority() const -> priority_type override + { + return 0; + } + + auto on_msg(ircmsg const& irc) -> std::pair override + { + if (irc.command == "PING" && 1 == irc.args.size()) + { + connection_->write("PONG", irc.args[0]); + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + else + { + return {}; + } + } +}; + +struct RegistrationThread : IrcThread +{ + Connection * connection_; + std::string password_; + std::string username_; + std::string realname_; + std::string nickname_; + + std::unordered_map caps; + std::unordered_set outstanding; + + enum class Stage + { + LsReply, + AckReply, + }; + + Stage stage_; + + RegistrationThread( + Connection * connection_, + std::string password, + std::string username, + std::string realname, + std::string nickname + ) + : connection_{connection_} + , password_{password} + , username_{username} + , realname_{realname} + , nickname_{nickname} + , stage_{Stage::LsReply} + {} + + auto priority() const -> priority_type override { return 1; } + auto on_connect() -> IrcThread::callback_result override + { + connection_->write("CAP", "LS", "302"); + connection_->write("PASS", password_); + connection_->write("USER", username_, "*", "*", realname_); + connection_->write("NICK", nickname_); + return {}; + } + + auto send_req() -> IrcThread::callback_result + { + std::string request; + char const* want[] = { "extended-join", "account-notify", "draft/chathistory", "batch", "soju.im/no-implicit-names", "chghost", "setname", "account-tag", "solanum.chat/oper", "solanum.chat/identify-msg", "solanum.chat/realhost", "server-time", "invite-notify", "extended-join" }; + for (auto cap : want) + { + if (caps.contains(cap)) + { + request.append(cap); + request.push_back(' '); + outstanding.insert(cap); + } + } + if (not outstanding.empty()) + { + request.pop_back(); + connection_->write("CAP", "REQ", request); + stage_ = Stage::AckReply; + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + else + { + connection_->write("CAP", "END"); + return {ThreadOutcome::Finish, EventOutcome::Consume}; + } + } + + auto capack(ircmsg const& msg) -> IrcThread::callback_result + { + auto const n = msg.args.size(); + if ("CAP" == msg.command && n >= 2 && "*" == msg.args[0] && "ACK" == msg.args[1]) + { + auto in = std::istringstream{std::string{msg.args[2]}}; + std::for_each( + std::istream_iterator{in}, + std::istream_iterator{}, + [this](std::string x) { + outstanding.erase(x); + } + ); + if (outstanding.empty()) + { + connection_->write("CAP","END"); + return {ThreadOutcome::Finish, EventOutcome::Consume}; + } + else + { + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + } + else + { + return {}; + } + } + + auto capls(ircmsg const& msg) -> IrcThread::callback_result + { + auto const n = msg.args.size(); + if ("CAP" == msg.command && n >= 2 && "*" == msg.args[0] && "LS" == msg.args[1]) + { + std::string_view const* kvs; + bool last; + + if (3 == n) + { + kvs = &msg.args[2]; + last = true; + } + else if (4 == n && "*" == msg.args[2]) + { + kvs = &msg.args[3]; + last = false; + } + else + { + return {}; + } + + auto in = std::istringstream{std::string{*kvs}}; + + std::for_each( + std::istream_iterator{in}, + std::istream_iterator{}, + [this](std::string x) { + auto const eq = x.find('='); + if (eq == x.npos) + { + caps.emplace(x, std::string{}); + } + else + { + caps.emplace(std::string{x, 0, eq}, std::string{x, eq+1, x.npos}); + } + } + ); + + if (last) + { + return send_req(); + } + + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + else + { + return {}; + } + + } + + auto on_msg(ircmsg const& msg) -> IrcThread::callback_result override + { + switch (stage_) + { + case Stage::LsReply: return capls(msg); + case Stage::AckReply: return capack(msg); + default: return {}; + } + } +}; + +auto start(boost::asio::io_context & io, Settings const& settings) -> void +{ + auto connection = std::make_shared(io); + + connection->listen(std::make_unique(connection.get())); + connection->listen(std::make_unique(connection.get(), settings.password, settings.username, settings.realname, settings.nickname)); + connection->listen(std::make_unique()); + connection->listen(std::make_unique()); + + boost::asio::co_spawn( + io, + connection->connect(io, settings), + [&io, &settings](std::exception_ptr e) + { + auto timer = boost::asio::steady_timer{io}; + timer.expires_from_now(5s); + timer.async_wait([&io, &settings](auto) { + start(io, settings); + }); + }); +} + +auto get_settings() -> Settings +{ + if (auto config_stream = std::ifstream {"config.toml"}) + { + return Settings::from_stream(config_stream); + } + else + { + std::cerr << "Unable to open config.toml\n"; + std::exit(1); + } +} + +auto main() -> int +{ + auto const settings = get_settings(); + auto io = boost::asio::io_context{}; + start(io, settings); + io.run(); +} diff --git a/settings.cpp b/settings.cpp new file mode 100644 index 0000000..20e1a6c --- /dev/null +++ b/settings.cpp @@ -0,0 +1,17 @@ +#include "settings.hpp" + +#define TOML_ENABLE_FORMATTERS 0 +#include + +auto Settings::from_stream(std::istream & in) -> Settings +{ + auto const config = toml::parse(in); + return Settings{ + .host = config["host"].value_or(std::string{"*"}), + .service = config["service"].value_or(std::string{"*"}), + .password = config["password"].value_or(std::string{"*"}), + .username = config["username"].value_or(std::string{"*"}), + .realname = config["realname"].value_or(std::string{"*"}), + .nickname = config["nickname"].value_or(std::string{"*"}) + }; +} diff --git a/settings.hpp b/settings.hpp new file mode 100644 index 0000000..8cba813 --- /dev/null +++ b/settings.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +struct Settings +{ + std::string host; + std::string service; + std::string password; + std::string username; + std::string realname; + std::string nickname; + + static auto from_stream(std::istream & in) -> Settings; +}; +