From 53050bb2a19f6ff3d347851579755e412d82ab44 Mon Sep 17 00:00:00 2001 From: Eric Mertens Date: Sat, 25 Nov 2023 09:22:55 -0800 Subject: [PATCH] pre-eventpp snapshot --- CMakeLists.txt | 2 +- connection.cpp | 61 +++++----- connection.hpp | 91 +++++--------- irc_parse_thread.cpp | 21 ++++ irc_parse_thread.hpp | 20 ++++ irc_thread.hpp | 27 ----- ircmsg.cpp | 12 +- ircmsg.hpp | 10 +- main.cpp | 260 ++++++---------------------------------- ping_thread.cpp | 25 ++++ ping_thread.hpp | 15 +++ registration_thread.cpp | 164 +++++++++++++++++++++++++ registration_thread.hpp | 51 ++++++++ thread.cpp | 46 +++++++ thread.hpp | 42 +++++++ watchdog_thread.cpp | 42 +++++++ watchdog_thread.hpp | 18 +++ write_irc.cpp | 23 ++++ write_irc.hpp | 24 ++++ 19 files changed, 606 insertions(+), 348 deletions(-) create mode 100644 irc_parse_thread.cpp create mode 100644 irc_parse_thread.hpp delete mode 100644 irc_thread.hpp create mode 100644 ping_thread.cpp create mode 100644 ping_thread.hpp create mode 100644 registration_thread.cpp create mode 100644 registration_thread.hpp create mode 100644 thread.cpp create mode 100644 thread.hpp create mode 100644 watchdog_thread.cpp create mode 100644 watchdog_thread.hpp create mode 100644 write_irc.cpp create mode 100644 write_irc.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 19ffbf7..0768815 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,5 +20,5 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(tomlplusplus) -add_executable(xbot main.cpp ircmsg.cpp settings.cpp connection.cpp) +add_executable(xbot main.cpp ircmsg.cpp settings.cpp connection.cpp thread.cpp watchdog_thread.cpp irc_parse_thread.cpp write_irc.cpp ping_thread.cpp registration_thread.cpp) target_link_libraries(xbot PRIVATE Boost::headers OpenSSL::SSL tomlplusplus_tomlplusplus) diff --git a/connection.cpp b/connection.cpp index 817718b..984a090 100644 --- a/connection.cpp +++ b/connection.cpp @@ -1,5 +1,21 @@ #include "connection.hpp" +Connection::Connection(boost::asio::io_context & io) +: stream_{io} +, write_timer_{io, std::chrono::steady_clock::time_point::max()} +{ +} + +auto Connection::add_thread(std::shared_ptr thread) -> void +{ + dispatcher_.add_thread(std::move(thread)); +} + +auto Connection::add_event(std::shared_ptr event) -> void +{ + dispatcher_.dispatch(std::move(event)); +} + auto Connection::writer_() -> void { std::vector buffers; @@ -48,20 +64,24 @@ auto Connection::writer() -> void auto Connection::connect( boost::asio::io_context & io, - Settings settings) --> boost::asio::awaitable + std::string host, + std::string port +) -> boost::asio::awaitable { - auto self = shared_from_this(); + using namespace std::placeholders; + + // keep connection alive while coroutine is active + auto const self = shared_from_this(); { auto resolver = boost::asio::ip::tcp::resolver{io}; - auto const endpoints = co_await resolver.async_resolve(settings.host, settings.service, boost::asio::use_awaitable); + auto const endpoints = co_await resolver.async_resolve(host, port, boost::asio::use_awaitable); auto const endpoint = co_await boost::asio::async_connect(stream_, endpoints, boost::asio::use_awaitable); - self->writer(); - dispatch(&IrcThread::on_connect); + make_event(); } + self->writer(); for(LineBuffer buffer{32'768};;) { boost::system::error_code error; @@ -71,38 +91,25 @@ auto Connection::connect( break; } buffer.add_bytes(n, [this](char * line) { - dispatch(&IrcThread::on_msg, parse_irc_message(line)); + make_event(line); }); } - dispatch(&IrcThread::on_disconnect); + make_event(); } -auto Connection::write(std::string message) -> void +auto Connection::write_raw(std::string message) -> void { + std::cout << "Writing " << message; auto const need_cancel = write_strings_.empty(); - - message += "\r\n"; write_strings_.push_back(std::move(message)); - if (need_cancel) { write_timer_.cancel_one(); } } -auto Connection::write(std::string front, std::string_view last) -> void - { - auto const is_invalid = [](char x) -> bool { - return x == '\0' || x == '\r' || x == '\n'; - }; - - if (last.end() != std::find_if(last.begin(), last.end(), is_invalid)) - { - throw std::runtime_error{"bad irc argument"}; - } - - front += " :"; - front += last; - write(std::move(front)); - } +auto Connection::close() -> void +{ + stream_.close(); +} \ No newline at end of file diff --git a/connection.hpp b/connection.hpp index cbc325e..aa02679 100644 --- a/connection.hpp +++ b/connection.hpp @@ -1,14 +1,13 @@ #pragma once -#include "irc_thread.hpp" -#include "ircmsg.hpp" #include "linebuffer.hpp" #include "settings.hpp" +#include "thread.hpp" #include #include -#include +#include #include #include #include @@ -19,79 +18,49 @@ #include #include +struct ConnectEvent : Event +{ +}; + +struct DisconnectEvent : Event +{ +}; + +struct LineEvent : Event +{ + explicit LineEvent(char * line) : line{line} {} + char * line; +}; + class Connection : public std::enable_shared_from_this { boost::asio::ip::tcp::socket stream_; boost::asio::steady_timer write_timer_; - std::list write_strings_; - std::vector> threads_; + Dispatcher dispatcher_; auto writer() -> void; auto writer_() -> void; - template - auto dispatch( - IrcThread::callback_result (IrcThread::* method)(Args...), - Args... args - ) -> void - { - std::vector> work; - work.swap(threads_); - std::sort(work.begin(), work.end(), [](auto const& a, auto const& b) { return a->priority() < b->priority(); }); - - std::size_t const n = work.size(); - for (std::size_t i = 0; i < n; i++) - { - auto const [thread_outcome, msg_outcome] = (work[i].get()->*method)(args...); - if (thread_outcome == ThreadOutcome::Continue) - { - threads_.push_back(std::move(work[i])); - } - if (msg_outcome == EventOutcome::Consume) - { - std::move(work.begin() + i + 1, work.end(), std::back_inserter(threads_)); - break; - } - } - } - public: - Connection(boost::asio::io_context & io) - : stream_{io} - , write_timer_{io, std::chrono::steady_clock::time_point::max()} - { + Connection(boost::asio::io_context & io); + auto add_thread(std::shared_ptr thread) -> void; + auto add_event(std::shared_ptr event) -> void; + + template + auto make_event(Args&& ... args) { + add_event(std::make_shared(std::forward(args)...)); } - auto listen(std::unique_ptr thread) -> void - { - threads_.push_back(std::move(thread)); - } - - auto write(std::string front, std::string_view last) -> void; - - template - auto write(std::string front, std::string_view next, Args ...rest) -> void - { - auto const is_invalid = [](char x) -> bool { - return x == '\0' || x == '\r' || x == '\n' || x == ' '; - }; - - if (next.empty() || next.end() != std::find_if(next.begin(), next.end(), is_invalid)) - { - throw std::runtime_error{"bad irc argument"}; - } - - front += " "; - front += next; - write(std::move(front), rest...); - } - - auto write(std::string message) -> void; + /// Write bytes into the socket. Messages should be properly newline terminated. + auto write_raw(std::string message) -> void; auto connect( boost::asio::io_context & io, - Settings settings + std::string host, + std::string port ) -> boost::asio::awaitable; + + auto close() -> void; }; diff --git a/irc_parse_thread.cpp b/irc_parse_thread.cpp new file mode 100644 index 0000000..e9c104b --- /dev/null +++ b/irc_parse_thread.cpp @@ -0,0 +1,21 @@ +#include "irc_parse_thread.hpp" + +#include "connection.hpp" + +IrcParseThread::IrcParseThread(Connection * connection) noexcept +: connection_{connection} {} + +auto IrcParseThread::priority() const -> priority_type +{ + return 0; +} + +auto IrcParseThread::on_event(Event const& event) -> callback_result +{ + if (auto line_event = dynamic_cast(&event)) + { + connection_->make_event(parse_irc_message(line_event->line)); + return { ThreadOutcome::Continue, EventOutcome::Consume }; + } + return {}; +} \ No newline at end of file diff --git a/irc_parse_thread.hpp b/irc_parse_thread.hpp new file mode 100644 index 0000000..47a0b51 --- /dev/null +++ b/irc_parse_thread.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "thread.hpp" + +class Connection; + +struct IrcMsgEvent : Event +{ + IrcMsgEvent(IrcMsg irc) : irc{irc} {} + IrcMsg irc; +}; + +struct IrcParseThread : Thread +{ + Connection * connection_; + + IrcParseThread(Connection * connection) noexcept; + auto priority() const -> priority_type override; + auto on_event(Event const& event) -> callback_result override; +}; \ No newline at end of file diff --git a/irc_thread.hpp b/irc_thread.hpp deleted file mode 100644 index 7da32f5..0000000 --- a/irc_thread.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include "ircmsg.hpp" - -enum class EventOutcome -{ - Pass, - Consume, -}; - -enum class ThreadOutcome -{ - Continue, - Finish, -}; - -struct IrcThread -{ - using priority_type = std::uint64_t; - using callback_result = std::pair; - virtual ~IrcThread() {} - - virtual auto on_connect() -> callback_result { return {}; } - virtual auto on_disconnect() -> callback_result { return {}; }; - virtual auto on_msg(ircmsg const&) -> callback_result { return {}; }; - virtual auto priority() const -> priority_type = 0; -}; diff --git a/ircmsg.cpp b/ircmsg.cpp index fa3542f..0b819d6 100644 --- a/ircmsg.cpp +++ b/ircmsg.cpp @@ -4,7 +4,7 @@ #include "ircmsg.hpp" namespace { -class parser { +class Parser { char* msg_; inline static char empty[1]; @@ -13,7 +13,7 @@ class parser { } public: - parser(char* msg) : msg_(msg) { + Parser(char* msg) : msg_(msg) { if (msg_ == nullptr) { msg_ = empty; } else { @@ -101,10 +101,10 @@ auto parse_irc_tags(char* str) -> std::vector return tags; } -auto parse_irc_message(char* msg) -> ircmsg +auto parse_irc_message(char* msg) -> IrcMsg { - parser p {msg}; - ircmsg out; + Parser p {msg}; + IrcMsg out; /* MESSAGE TAGS */ if (p.match('@')) { @@ -134,7 +134,7 @@ auto parse_irc_message(char* msg) -> ircmsg return out; } -auto ircmsg::hassource() const -> bool +auto IrcMsg::hassource() const -> bool { return source.data() != nullptr; } diff --git a/ircmsg.hpp b/ircmsg.hpp index e50b385..15c1904 100644 --- a/ircmsg.hpp +++ b/ircmsg.hpp @@ -14,16 +14,16 @@ struct irctag friend auto operator==(irctag const&, irctag const&) -> bool = default; }; -struct ircmsg +struct IrcMsg { std::vector tags; std::vector args; std::string_view source; std::string_view command; - ircmsg() = default; + IrcMsg() = default; - ircmsg( + IrcMsg( std::vector && tags, std::string_view source, std::string_view command, @@ -35,7 +35,7 @@ struct ircmsg bool hassource() const; - friend bool operator==(ircmsg const&, ircmsg const&) = default; + friend bool operator==(IrcMsg const&, IrcMsg const&) = default; }; enum class irc_error_code { @@ -57,6 +57,6 @@ struct irc_parse_error : public std::exception { * * Returns zero for success, non-zero for parse error. */ -auto parse_irc_message(char* msg) -> ircmsg; +auto parse_irc_message(char* msg) -> IrcMsg; auto parse_irc_tags(char* msg) -> std::vector; diff --git a/main.cpp b/main.cpp index 3c39a7b..b66229a 100644 --- a/main.cpp +++ b/main.cpp @@ -1,11 +1,15 @@ #include #include -#include "linebuffer.hpp" -#include "ircmsg.hpp" -#include "settings.hpp" -#include "irc_thread.hpp" #include "connection.hpp" +#include "ircmsg.hpp" +#include "linebuffer.hpp" +#include "settings.hpp" +#include "thread.hpp" + +#include "irc_parse_thread.hpp" +#include "ping_thread.hpp" +#include "registration_thread.hpp" #include #include @@ -27,234 +31,47 @@ using namespace std::chrono_literals; -struct ChatThread : public IrcThread +struct ChatThread : public Thread { auto priority() const -> priority_type override { return 100; } - auto on_msg(ircmsg const& irc) -> std::pair override + auto on_event(Event const& event) -> callback_result override { - if (irc.command == "PRIVMSG" && 2 == irc.args.size()) + if (auto const* irc_event = dynamic_cast(&event)) { - std::cout << "Chat from " << irc.source << ": " << irc.args[1] << std::endl; - return {ThreadOutcome::Continue, EventOutcome::Pass}; + auto const& irc = irc_event->irc; + if (irc.command == "PRIVMSG" && 2 == irc.args.size()) + { + std::cout << "Chat from " << irc.source << ": " << irc.args[1] << std::endl; + return { ThreadOutcome::Continue, EventOutcome::Consume }; + } } - else - { - return {ThreadOutcome::Continue, EventOutcome::Pass}; - } - } -}; - -struct UnhandledThread : public IrcThread -{ - auto priority() const -> priority_type override - { - return std::numeric_limits::max(); - } - - auto on_msg(ircmsg const& irc) -> std::pair override - { - std::cout << "Unhandled message " << irc.command; - for (auto const arg : irc.args) - { - std::cout << " " << arg; - } - std::cout << "\n"; - return {ThreadOutcome::Continue, EventOutcome::Pass}; - } -}; - -class PingThread : public IrcThread -{ - Connection * connection_; - -public: - PingThread(Connection * connection) noexcept : connection_{connection} {} - - auto priority() const -> priority_type override - { - return 0; - } - - auto on_msg(ircmsg const& irc) -> std::pair override - { - if (irc.command == "PING" && 1 == irc.args.size()) - { - connection_->write("PONG", irc.args[0]); - return {ThreadOutcome::Continue, EventOutcome::Consume}; - } - else - { - return {}; - } - } -}; - -struct RegistrationThread : IrcThread -{ - Connection * connection_; - std::string password_; - std::string username_; - std::string realname_; - std::string nickname_; - - std::unordered_map caps; - std::unordered_set outstanding; - - enum class Stage - { - LsReply, - AckReply, - }; - - Stage stage_; - - RegistrationThread( - Connection * connection_, - std::string password, - std::string username, - std::string realname, - std::string nickname - ) - : connection_{connection_} - , password_{password} - , username_{username} - , realname_{realname} - , nickname_{nickname} - , stage_{Stage::LsReply} - {} - - auto priority() const -> priority_type override { return 1; } - auto on_connect() -> IrcThread::callback_result override - { - connection_->write("CAP", "LS", "302"); - connection_->write("PASS", password_); - connection_->write("USER", username_, "*", "*", realname_); - connection_->write("NICK", nickname_); return {}; } +}; - auto send_req() -> IrcThread::callback_result +struct UnhandledThread : public Thread +{ + auto priority() const -> priority_type override { - std::string request; - char const* want[] = { "extended-join", "account-notify", "draft/chathistory", "batch", "soju.im/no-implicit-names", "chghost", "setname", "account-tag", "solanum.chat/oper", "solanum.chat/identify-msg", "solanum.chat/realhost", "server-time", "invite-notify", "extended-join" }; - for (auto cap : want) - { - if (caps.contains(cap)) - { - request.append(cap); - request.push_back(' '); - outstanding.insert(cap); - } - } - if (not outstanding.empty()) - { - request.pop_back(); - connection_->write("CAP", "REQ", request); - stage_ = Stage::AckReply; - return {ThreadOutcome::Continue, EventOutcome::Consume}; - } - else - { - connection_->write("CAP", "END"); - return {ThreadOutcome::Finish, EventOutcome::Consume}; - } + return std::numeric_limits::max(); } - auto capack(ircmsg const& msg) -> IrcThread::callback_result + auto on_event(Event const& event) -> callback_result override { - auto const n = msg.args.size(); - if ("CAP" == msg.command && n >= 2 && "*" == msg.args[0] && "ACK" == msg.args[1]) + if (auto irc_event = dynamic_cast(&event)) { - auto in = std::istringstream{std::string{msg.args[2]}}; - std::for_each( - std::istream_iterator{in}, - std::istream_iterator{}, - [this](std::string x) { - outstanding.erase(x); - } - ); - if (outstanding.empty()) + auto& irc = irc_event->irc; + std::cout << "Unhandled message " << irc.command; + for (auto const arg : irc.args) { - connection_->write("CAP","END"); - return {ThreadOutcome::Finish, EventOutcome::Consume}; - } - else - { - return {ThreadOutcome::Continue, EventOutcome::Consume}; + std::cout << " " << arg; } + std::cout << "\n"; } - else - { - return {}; - } - } - - auto capls(ircmsg const& msg) -> IrcThread::callback_result - { - auto const n = msg.args.size(); - if ("CAP" == msg.command && n >= 2 && "*" == msg.args[0] && "LS" == msg.args[1]) - { - std::string_view const* kvs; - bool last; - - if (3 == n) - { - kvs = &msg.args[2]; - last = true; - } - else if (4 == n && "*" == msg.args[2]) - { - kvs = &msg.args[3]; - last = false; - } - else - { - return {}; - } - - auto in = std::istringstream{std::string{*kvs}}; - - std::for_each( - std::istream_iterator{in}, - std::istream_iterator{}, - [this](std::string x) { - auto const eq = x.find('='); - if (eq == x.npos) - { - caps.emplace(x, std::string{}); - } - else - { - caps.emplace(std::string{x, 0, eq}, std::string{x, eq+1, x.npos}); - } - } - ); - - if (last) - { - return send_req(); - } - - return {ThreadOutcome::Continue, EventOutcome::Consume}; - } - else - { - return {}; - } - - } - - auto on_msg(ircmsg const& msg) -> IrcThread::callback_result override - { - switch (stage_) - { - case Stage::LsReply: return capls(msg); - case Stage::AckReply: return capack(msg); - default: return {}; - } + return {}; } }; @@ -262,19 +79,20 @@ auto start(boost::asio::io_context & io, Settings const& settings) -> void { auto connection = std::make_shared(io); - connection->listen(std::make_unique(connection.get())); - connection->listen(std::make_unique(connection.get(), settings.password, settings.username, settings.realname, settings.nickname)); - connection->listen(std::make_unique()); - connection->listen(std::make_unique()); + connection->add_thread(std::make_shared(connection.get())); + connection->add_thread(std::make_shared(connection.get())); + connection->add_thread(std::make_shared(connection.get(), settings.password, settings.username, settings.realname, settings.nickname)); + connection->add_thread(std::make_shared()); + connection->add_thread(std::make_shared()); boost::asio::co_spawn( io, - connection->connect(io, settings), + connection->connect(io, settings.host, settings.service), [&io, &settings](std::exception_ptr e) { - auto timer = boost::asio::steady_timer{io}; - timer.expires_from_now(5s); - timer.async_wait([&io, &settings](auto) { + auto timer = std::make_shared(io); + timer->expires_from_now(5s); + timer->async_wait([&io, &settings, timer](auto) { start(io, settings); }); }); diff --git a/ping_thread.cpp b/ping_thread.cpp new file mode 100644 index 0000000..35bf745 --- /dev/null +++ b/ping_thread.cpp @@ -0,0 +1,25 @@ +#include "ping_thread.hpp" + +#include "irc_parse_thread.hpp" +#include "write_irc.hpp" + +PingThread::PingThread(Connection * connection) noexcept : connection_{connection} {} + +auto PingThread::priority() const -> priority_type +{ + return 1; +} + +auto PingThread::on_event(Event const& event) -> std::pair +{ + if (auto const irc_event = dynamic_cast(&event)) + { + auto& irc = irc_event->irc; + if ("PING" == irc.command && 1 == irc.args.size()) + { + write_irc(*connection_, "PONG", irc.args[0]); + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + } + return {}; +} diff --git a/ping_thread.hpp b/ping_thread.hpp new file mode 100644 index 0000000..b7eae77 --- /dev/null +++ b/ping_thread.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "connection.hpp" +#include "thread.hpp" + +class PingThread : public Thread +{ + Connection * connection_; + +public: + PingThread(Connection * connection) noexcept; + + auto priority() const -> priority_type override; + auto on_event(Event const& event) -> std::pair override; +}; diff --git a/registration_thread.cpp b/registration_thread.cpp new file mode 100644 index 0000000..dc1279b --- /dev/null +++ b/registration_thread.cpp @@ -0,0 +1,164 @@ +#include "registration_thread.hpp" + +RegistrationThread::RegistrationThread( + Connection * connection_, + std::string password, + std::string username, + std::string realname, + std::string nickname +) + : connection_{connection_} + , password_{password} + , username_{username} + , realname_{realname} + , nickname_{nickname} + , stage_{Stage::LsReply} +{} + +auto RegistrationThread::priority() const -> priority_type +{ + return 2; +} + +auto RegistrationThread::on_connect() -> Thread::callback_result +{ + write_irc(*connection_, "CAP", "LS", "302"); + write_irc(*connection_, "PASS", password_); + write_irc(*connection_, "USER", username_, "*", "*", realname_); + write_irc(*connection_, "NICK", nickname_); + return {}; +} + +auto RegistrationThread::send_req() -> Thread::callback_result +{ + std::string request; + char const* want[] = { "extended-join", "account-notify", "draft/chathistory", "batch", "soju.im/no-implicit-names", "chghost", "setname", "account-tag", "solanum.chat/oper", "solanum.chat/identify-msg", "solanum.chat/realhost", "server-time", "invite-notify", "extended-join" }; + for (auto cap : want) + { + if (caps.contains(cap)) + { + request.append(cap); + request.push_back(' '); + outstanding.insert(cap); + } + } + if (not outstanding.empty()) + { + request.pop_back(); + write_irc(*connection_, "CAP", "REQ", request); + stage_ = Stage::AckReply; + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + else + { + write_irc(*connection_, "CAP", "END"); + return {ThreadOutcome::Finish, EventOutcome::Consume}; + } +} + +auto RegistrationThread::capack(IrcMsg const& msg) -> Thread::callback_result +{ + auto const n = msg.args.size(); + if ("CAP" == msg.command && n >= 2 && "*" == msg.args[0] && "ACK" == msg.args[1]) + { + auto in = std::istringstream{std::string{msg.args[2]}}; + std::for_each( + std::istream_iterator{in}, + std::istream_iterator{}, + [this](std::string x) { + outstanding.erase(x); + } + ); + if (outstanding.empty()) + { + write_irc(*connection_, "CAP", "END"); + return {ThreadOutcome::Finish, EventOutcome::Consume}; + } + else + { + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + } + else + { + return {}; + } +} + +auto RegistrationThread::capls(IrcMsg const& msg) -> Thread::callback_result +{ + auto const n = msg.args.size(); + if ("CAP" == msg.command && n >= 2 && "*" == msg.args[0] && "LS" == msg.args[1]) + { + std::string_view const* kvs; + bool last; + + if (3 == n) + { + kvs = &msg.args[2]; + last = true; + } + else if (4 == n && "*" == msg.args[2]) + { + kvs = &msg.args[3]; + last = false; + } + else + { + return {}; + } + + auto in = std::istringstream{std::string{*kvs}}; + + std::for_each( + std::istream_iterator{in}, + std::istream_iterator{}, + [this](std::string x) { + auto const eq = x.find('='); + if (eq == x.npos) + { + caps.emplace(x, std::string{}); + } + else + { + caps.emplace(std::string{x, 0, eq}, std::string{x, eq+1, x.npos}); + } + } + ); + + if (last) + { + return send_req(); + } + + return {ThreadOutcome::Continue, EventOutcome::Consume}; + } + else + { + return {}; + } + +} + +auto RegistrationThread::on_msg(IrcMsg const& msg) -> Thread::callback_result +{ + switch (stage_) + { + case Stage::LsReply: return capls(msg); + case Stage::AckReply: return capack(msg); + default: return {}; + } +} + +auto RegistrationThread::on_event(Event const& event) -> Thread::callback_result +{ + if (auto const irc_event = dynamic_cast(&event)) + { + return on_msg(irc_event->irc); + } + if (auto const connect_event = dynamic_cast(&event)) + { + return on_connect(); + } + return {}; +} diff --git a/registration_thread.hpp b/registration_thread.hpp new file mode 100644 index 0000000..afd1050 --- /dev/null +++ b/registration_thread.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "thread.hpp" +#include "connection.hpp" +#include "irc_parse_thread.hpp" +#include "write_irc.hpp" + +#include +#include +#include + +struct RegistrationThread : Thread +{ + Connection * connection_; + std::string password_; + std::string username_; + std::string realname_; + std::string nickname_; + + std::unordered_map caps; + std::unordered_set outstanding; + + enum class Stage + { + LsReply, + AckReply, + }; + + Stage stage_; + + RegistrationThread( + Connection * connection_, + std::string password, + std::string username, + std::string realname, + std::string nickname + ); + + auto priority() const -> priority_type override; + auto on_connect() -> Thread::callback_result; + + auto send_req() -> Thread::callback_result; + + auto capack(IrcMsg const& msg) -> Thread::callback_result; + + auto capls(IrcMsg const& msg) -> Thread::callback_result; + + auto on_msg(IrcMsg const& msg) -> Thread::callback_result; + + auto on_event(Event const& event) -> Thread::callback_result override; +}; \ No newline at end of file diff --git a/thread.cpp b/thread.cpp new file mode 100644 index 0000000..1f1c4bf --- /dev/null +++ b/thread.cpp @@ -0,0 +1,46 @@ +#include "thread.hpp" + +auto Dispatcher::add_thread(std::shared_ptr thread) -> void +{ + threads_.push_back(std::move(thread)); +} + +auto Dispatcher::dispatch(std::shared_ptr event) -> void +{ + if (dispatching_) + { + events_.push_back(std::move(event)); + return; + } + + dispatching_ = true; + std::vector> events{std::move(event)}; + while (not events.empty()) + { + for (auto && event : events) + { + std::vector> work; + work.swap(threads_); + std::sort(work.begin(), work.end(), [](auto const& a, auto const& b) { return a->priority() < b->priority(); }); + + std::size_t const n = work.size(); + for (std::size_t i = 0; i < n; i++) + { + auto const [thread_outcome, msg_outcome] = work[i]->on_event(*event); + if (thread_outcome == ThreadOutcome::Continue) + { + threads_.push_back(std::move(work[i])); + } + if (msg_outcome == EventOutcome::Consume) + { + std::move(work.begin() + i + 1, work.end(), std::back_inserter(threads_)); + break; + } + } + } + + events = std::move(events_); + events_.clear(); + } + dispatching_ = false; +} \ No newline at end of file diff --git a/thread.hpp b/thread.hpp new file mode 100644 index 0000000..0f696d9 --- /dev/null +++ b/thread.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "ircmsg.hpp" + +#include +#include + +enum class EventOutcome +{ + Pass, + Consume, +}; + +enum class ThreadOutcome +{ + Continue, + Finish, +}; + +struct Event { + virtual ~Event() {} +}; + +struct Thread +{ + using priority_type = std::uint64_t; + using callback_result = std::pair; + virtual ~Thread() {} + virtual auto on_event(Event const& event) -> callback_result { return {}; }; + virtual auto priority() const -> priority_type = 0; +}; + +struct Dispatcher +{ + std::vector> threads_; + std::vector> events_; + bool dispatching_; + + /// Apply a function to all the threads in priority order + auto dispatch(std::shared_ptr event) -> void; + auto add_thread(std::shared_ptr thread) -> void; +}; \ No newline at end of file diff --git a/watchdog_thread.cpp b/watchdog_thread.cpp new file mode 100644 index 0000000..4d479d6 --- /dev/null +++ b/watchdog_thread.cpp @@ -0,0 +1,42 @@ +#include "watchdog_thread.hpp" + +#include "connection.hpp" +#include "irc_parse_thread.hpp" + +#include + +using namespace std::chrono_literals; + +WatchdogThread::WatchdogThread(Connection * connection) noexcept +: connection_{connection} +{ +} + +auto WatchdogThread::priority() const -> priority_type +{ + return 0; +} + +auto WatchdogThread::on_event(Event const& event) -> std::pair +{ + if (auto const irc_event = dynamic_cast(&event)) + { + timer_.expires_from_now(30s); + return {}; + } + if (auto const connect_event = dynamic_cast(&event)) + { + timer_.expires_from_now(30s); + timer_.async_wait([weak = weak_from_this()](auto error) + { + if (not error) + { + if (auto self = weak.lock()) + { + self->connection_->close(); + } + } + }); + } + return {}; +} diff --git a/watchdog_thread.hpp b/watchdog_thread.hpp new file mode 100644 index 0000000..5923876 --- /dev/null +++ b/watchdog_thread.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "thread.hpp" +#include + +class Connection; + +class WatchdogThread : public Thread, public std::enable_shared_from_this +{ + Connection * connection_; + boost::asio::steady_timer timer_; + +public: + WatchdogThread(Connection * connection) noexcept; + + auto priority() const -> priority_type override; + auto on_event(Event const& event) -> std::pair override; +}; diff --git a/write_irc.cpp b/write_irc.cpp new file mode 100644 index 0000000..3ad5c8f --- /dev/null +++ b/write_irc.cpp @@ -0,0 +1,23 @@ +#include "write_irc.hpp" + +auto write_irc(Connection& connection, std::string message) -> void +{ + message += "\r\n"; + connection.write_raw(std::move(message)); +} + +auto write_irc(Connection& connection, std::string front, std::string_view last) -> void +{ + auto const is_invalid = [](char x) -> bool { + return x == '\0' || x == '\r' || x == '\n'; + }; + + if (last.end() != std::find_if(last.begin(), last.end(), is_invalid)) + { + throw std::runtime_error{"bad irc argument"}; + } + + front += " :"; + front += last; + write_irc(connection, std::move(front)); +} diff --git a/write_irc.hpp b/write_irc.hpp new file mode 100644 index 0000000..f47977d --- /dev/null +++ b/write_irc.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "connection.hpp" + +auto write_irc(Connection& connection, std::string message) -> void; + +auto write_irc(Connection& connection, std::string front, std::string_view last) -> void; + +template +auto write_irc(Connection& connection, std::string front, std::string_view next, Args ...rest) -> void +{ + auto const is_invalid = [](char x) -> bool { + return x == '\0' || x == '\r' || x == '\n' || x == ' '; + }; + + if (next.empty() || next.end() != std::find_if(next.begin(), next.end(), is_invalid)) + { + throw std::runtime_error{"bad irc argument"}; + } + + front += " "; + front += next; + write_irc(connection, std::move(front), rest...); +} \ No newline at end of file