From 3b48ff7c7e18fc77544ecadec8675ee2490346e9 Mon Sep 17 00:00:00 2001 From: Eric Mertens Date: Fri, 24 Jan 2025 14:48:15 -0800 Subject: [PATCH] checkpoint --- CMakeLists.txt | 1 - connection.cpp | 53 +++++++++++--- connection.hpp | 11 ++- irc_coroutine.cpp | 22 ++---- irc_coroutine.hpp | 173 +++++++++++++++++++++++++++++++++++--------- main.cpp | 10 +-- self_thread.cpp | 61 ++++++++-------- self_thread.hpp | 7 +- snote.cpp | 9 +-- snote.hpp | 26 +++---- watchdog_thread.cpp | 78 -------------------- watchdog_thread.hpp | 47 ------------ 12 files changed, 253 insertions(+), 245 deletions(-) delete mode 100644 watchdog_thread.cpp delete mode 100644 watchdog_thread.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b95ec0..7782d6d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,7 +35,6 @@ add_executable(xbot snote.cpp self_thread.cpp irc_coroutine.cpp - watchdog_thread.cpp ) target_include_directories(xbot PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/connection.cpp b/connection.cpp index f8746df..fa30e1e 100644 --- a/connection.cpp +++ b/connection.cpp @@ -8,9 +8,13 @@ namespace { #include "irc_commands.inc" } // namespace +using namespace std::literals; + Connection::Connection(boost::asio::io_context & io) : stream_{io} , write_timer_{io, std::chrono::steady_clock::time_point::max()} +, watchdog_timer_{io} +, stalled_{false} { } @@ -25,16 +29,11 @@ auto Connection::writer_immediate() -> void boost::asio::async_write( stream_, buffers, - [weak = weak_from_this() - ,strings = std::move(write_strings_) - ](boost::system::error_code const& error, std::size_t) + [this, strings = std::move(write_strings_)](boost::system::error_code const& error, std::size_t) { if (not error) { - if (auto self = weak.lock()) - { - self->writer(); - } + writer(); } }); write_strings_.clear(); @@ -45,6 +44,10 @@ auto Connection::writer() -> void if (write_strings_.empty()) { write_timer_.async_wait([weak = weak_from_this()](auto){ + // This wait will always trigger on a cancellation. That + // cancellation might be from write_line or it might be from + // the connection being destroyed. The weak pointer will fail + // to lock in the case that the object is being destructed. if (auto self = weak.lock()) { if (not self->write_strings_.empty()) @@ -80,7 +83,8 @@ auto Connection::connect( } // Start the queue writer after connection - self->writer(); + writer(); + watchdog(); for(LineBuffer buffer{32'768};;) { @@ -92,13 +96,44 @@ auto Connection::connect( } buffer.add_bytes(n, [this](char * line) { BOOST_LOG_TRIVIAL(debug) << "RECV: " << line; + watchdog_activity(); dispatch_line(line); }); } + watchdog_timer_.cancel(); + stream_.close(); sig_disconnect(); } +auto Connection::watchdog() -> void +{ + watchdog_timer_.expires_after(watchdog_duration); + watchdog_timer_.async_wait([this](auto const& error) + { + if (not error) + { + if (stalled_) + { + BOOST_LOG_TRIVIAL(debug) << "Watchdog timer elapsed, closing stream"; + close(); + } + else + { + send_ping("watchdog"); + stalled_ = true; + watchdog(); + } + } + }); +} + +auto Connection::watchdog_activity() -> void +{ + stalled_ = false; + watchdog_timer_.expires_after(watchdog_duration); +} + /// Parse IRC message line and dispatch it to the ircmsg slot. auto Connection::dispatch_line(char *line) -> void { @@ -126,7 +161,7 @@ auto Connection::dispatch_line(char *line) -> void // Server notice generate snote events but not IRC command events case IrcCommand::NOTICE: if (auto match = snoteCore.match(msg)) { - sig_snote(match->first, match->second); + sig_snote(*match); break; } /* FALLTHROUGH */ diff --git a/connection.hpp b/connection.hpp index a98d00f..162fce4 100644 --- a/connection.hpp +++ b/connection.hpp @@ -16,12 +16,21 @@ class Connection : public std::enable_shared_from_this private: boost::asio::ip::tcp::socket stream_; boost::asio::steady_timer write_timer_; + boost::asio::steady_timer watchdog_timer_; std::list write_strings_; + // Set true when watchdog triggers. + // Set false when message received. + bool stalled_; + auto writer() -> void; auto writer_immediate() -> void; auto dispatch_line(char * line) -> void; + static constexpr std::chrono::seconds watchdog_duration = std::chrono::seconds{30}; + auto watchdog() -> void; + auto watchdog_activity() -> void; + /// Write bytes into the socket. Messages should be properly newline terminated. auto write_line(std::string message) -> void; @@ -37,7 +46,7 @@ public: boost::signals2::signal sig_connect; boost::signals2::signal sig_disconnect; boost::signals2::signal sig_ircmsg; - boost::signals2::signal sig_snote; + boost::signals2::signal sig_snote; auto get_executor() -> boost::asio::any_io_executor { return stream_.get_executor(); diff --git a/irc_coroutine.cpp b/irc_coroutine.cpp index 11a070a..9269dca 100644 --- a/irc_coroutine.cpp +++ b/irc_coroutine.cpp @@ -12,20 +12,10 @@ auto irc_coroutine::start(Connection& connection) -> void { resume(); } -void wait_command::await_suspend(std::coroutine_handle handle) -{ - auto &connection = *handle.promise().connection_; - ircmsg_connection_ = connection.sig_ircmsg.connect([this, handle](auto cmd, auto &msg) { - auto const wanted = std::find(want_cmds_.begin(), want_cmds_.end(), cmd) != want_cmds_.end(); - if (wanted) { - unsubscribe(); - resultCmd = cmd; - resultMsg = &msg; - handle.resume(); - } - }); - disconnect_connection_ = connection.sig_disconnect.connect([this, handle]() { - unsubscribe(); - handle.resume(); - }); +void wait_ircmsg::stop() { + ircmsg_slot_.disconnect(); +} + +void wait_timeout::stop() { + timer_.reset(); } diff --git a/irc_coroutine.hpp b/irc_coroutine.hpp index 22097ae..4866326 100644 --- a/irc_coroutine.hpp +++ b/irc_coroutine.hpp @@ -2,16 +2,25 @@ #include "connection.hpp" +#include #include +#include +#include #include struct irc_promise; +/// A coroutine that can co_await on various IRC events struct irc_coroutine : std::coroutine_handle { using promise_type = irc_promise; + /// Start the coroutine and associate it with a specific connection. auto start(Connection &connection) -> void; + + /// Returns true when this coroutine is still waiting on events auto is_running() -> bool; + + /// Returns the exception that terminated this coroutine, if there is one. auto exception() -> std::exception_ptr; }; @@ -23,75 +32,169 @@ struct irc_promise // Pointer to exception that terminated this coroutine if there is one. std::exception_ptr exception_; - irc_coroutine get_return_object() + auto get_return_object() -> irc_coroutine { return {irc_coroutine::from_promise(*this)}; } // Suspend waiting for start() to initialize connection_ - std::suspend_always initial_suspend() noexcept { return {}; } + auto initial_suspend() noexcept -> std::suspend_always { return {}; } // Suspend so that is_running() and exception() work - std::suspend_always final_suspend() noexcept { return {}; } + auto final_suspend() noexcept -> std::suspend_always { return {}; } // Normal termination - void return_void() { + auto return_void() -> void { connection_.reset(); } // Abnormal termination - remember the exception - void unhandled_exception() { + auto unhandled_exception() -> void { connection_.reset(); exception_ = std::current_exception(); } }; -/* -struct wait_ircmsg { - using result_type = std::pair +template +class Wait; + +/// Argument to a Wait that expects one or more IRC messages +class wait_ircmsg { + // Vector of commands this wait is expecting. Leave empty to accept all messages. std::vector want_cmds_; + + // Slot for the ircmsg event + boost::signals2::scoped_connection ircmsg_slot_; + +public: + using result_type = std::pair; + + wait_ircmsg(std::initializer_list want_cmds) : want_cmds_{want_cmds} {} + + template auto start(Wait& command) -> void; + auto stop() -> void; }; -struct wait_timeout { +class wait_timeout { + std::optional timer_; + std::chrono::milliseconds timeout_; + +public: struct result_type {}; - std::vector want_cmds_; + wait_timeout(std::chrono::milliseconds timeout) : timeout_{timeout} {} + + template auto start(Wait& command) -> void; + auto stop() -> void; }; -*/ +template +class Wait { -class wait_command { - std::vector want_cmds_; + // State associated with each wait mode + std::tuple modes_; - IrcCommand resultCmd; - const IrcMsg *resultMsg; + // Result from any one of the wait modes + std::optional> result_; - boost::signals2::scoped_connection ircmsg_connection_; - boost::signals2::scoped_connection disconnect_connection_; + // Handle of the continuation to be resumed when one of the wait + // modes is ready. + std::coroutine_handle handle_; - void unsubscribe() { - ircmsg_connection_.disconnect(); - disconnect_connection_.disconnect(); + // Slot for tearing down the irc_coroutine in case the connection + // fails before any wait modes complete. + boost::signals2::scoped_connection disconnect_slot_; + + template + auto start_mode() -> void { + std::get(modes_).template start(*this); + } + + template + auto start_modes(std::index_sequence) -> void { + (start_mode(), ...); + } + + template + auto stop_modes(std::index_sequence) -> void { + (std::get(modes_).stop(), ...); } public: - wait_command(std::initializer_list want_cmds) - : want_cmds_(want_cmds) - , resultMsg{} - {} + Wait(Ts &&...modes) : modes_{std::forward(modes)...} {} - /// The coroutine always needs to wait for a message. It will never - /// be ready immediately. - bool await_ready() noexcept { return false; } + // Get the connection that this coroutine was started with. + auto get_connection() const -> Connection & { + return *handle_.promise().connection_; + } + + // Store a successful result and resume the coroutine + template + auto complete(Args &&...args) -> void { + result_.emplace(std::in_place_index, std::forward(args)...); + handle_.resume(); + } + + // The coroutine always needs to wait for a message. It will never + // be ready immediately. + auto await_ready() noexcept -> bool { return false; } /// Install event handles in the connection that will resume this coroutine. - void await_suspend(std::coroutine_handle handle); + auto await_suspend(std::coroutine_handle handle) -> void; - auto await_resume() -> std::pair { - if (resultMsg) { - return std::make_pair(resultCmd, std::cref(*resultMsg)); - } else { - throw std::runtime_error{"connection terminated"}; - } - } + auto await_resume() -> std::variant; }; +template +auto wait_ircmsg::start(Wait &command) -> void +{ + ircmsg_slot_ = command.get_connection().sig_ircmsg.connect([this, &command](auto cmd, auto &msg) { + auto const wanted = + want_cmds_.empty() || + std::find(want_cmds_.begin(), want_cmds_.end(), cmd) != want_cmds_.end(); + if (wanted) { + command.template complete(cmd, msg); + } + }); +} + +template +auto wait_timeout::start(Wait& command) -> void +{ + timer_.emplace(command.get_connection().get_executor()); + timer_->expires_after(timeout_); + timer_->async_wait([this, &command](auto const& error) + { + if (not error) { + timer_.reset(); + command.template complete(); + } + }); +} + +template +auto Wait::await_suspend(std::coroutine_handle handle) -> void +{ + handle_ = handle; + + auto const tuple_size = std::tuple_size_v; + start_modes(std::make_index_sequence{}); + + disconnect_slot_ = get_connection().sig_disconnect.connect([this]() { + handle_.resume(); + }); +} + +template +auto Wait::await_resume() -> std::variant +{ + auto const tuple_size = std::tuple_size_v; + stop_modes(std::make_index_sequence{}); + + disconnect_slot_.disconnect(); + + if (result_) { + return std::move(*result_); + } else { + throw std::runtime_error{"connection terminated"}; + } +} diff --git a/main.cpp b/main.cpp index 8af2f60..5500b8f 100644 --- a/main.cpp +++ b/main.cpp @@ -1,5 +1,4 @@ #include "connection.hpp" -#include "ircmsg.hpp" #include "settings.hpp" #include @@ -8,13 +7,10 @@ #include #include #include -#include #include "registration_thread.hpp" #include "self_thread.hpp" -#include "irc_coroutine.hpp" - using namespace std::chrono_literals; auto start(boost::asio::io_context & io, Settings const& settings) -> void @@ -22,10 +18,10 @@ auto start(boost::asio::io_context & io, Settings const& settings) -> void auto const connection = std::make_shared(io); RegistrationThread::start(*connection, settings.password, settings.username, settings.realname, settings.nickname); - SelfThread::start(*connection); + auto selfThread = SelfThread::start(*connection); - connection->sig_snote.connect([](auto tag, auto &match) { - std::cout << "SNOTE " << static_cast(tag) << std::endl; + connection->sig_snote.connect([](auto &match) { + std::cout << "SNOTE " << static_cast(match.get_tag()) << std::endl; for (auto c : match.get_results()) { std::cout << " " << std::string_view{c.first, c.second} << std::endl; diff --git a/self_thread.cpp b/self_thread.cpp index 6681ed0..39bf5ef 100644 --- a/self_thread.cpp +++ b/self_thread.cpp @@ -2,6 +2,8 @@ #include "connection.hpp" +#include + auto SelfThread::on_welcome(IrcMsg const& irc) -> void { nickname_ = irc.args[0]; @@ -78,6 +80,25 @@ auto SelfThread::on_mode(IrcMsg const& irc) -> void } } +auto SelfThread::on_isupport(const IrcMsg &msg) -> void +{ + auto const hi = msg.args.size() - 1; + for (int i = 1; i < hi; ++i) + { + auto &entry = msg.args[i]; + if (entry.starts_with("-")) { + auto key = std::string{entry.substr(1)}; + if (auto cursor = isupport_.find(key); cursor != isupport_.end()) { + isupport_.erase(cursor); + } + } else if (auto cursor = entry.find('='); cursor != entry.npos) { + isupport_.emplace(entry.substr(0, cursor), entry.substr(cursor+1)); + } else { + isupport_.emplace(entry, std::string{}); + } + } +} + auto SelfThread::start(Connection& connection) -> std::shared_ptr { auto thread = std::make_shared(connection); @@ -86,38 +107,14 @@ auto SelfThread::start(Connection& connection) -> std::shared_ptr { switch (cmd) { - // Learn nickname from 001 - case IrcCommand::RPL_WELCOME: - thread->on_welcome(msg); - break; - - // Track changes to our nickname - case IrcCommand::NICK: - thread->on_nick(msg); - break; - - // Re-establish user modes - case IrcCommand::RPL_UMODEIS: - thread->on_umodeis(msg); - break; - - case IrcCommand::JOIN: - thread->on_join(msg); - break; - - case IrcCommand::KICK: - thread->on_kick(msg); - break; - - case IrcCommand::PART: - thread->on_part(msg); - break; - - // Interpret self mode changes - case IrcCommand::MODE: - thread->on_mode(msg); - break; - + case IrcCommand::RPL_WELCOME: thread->on_welcome(msg); break; + case IrcCommand::RPL_ISUPPORT: thread->on_isupport(msg); break; + case IrcCommand::RPL_UMODEIS: thread->on_umodeis(msg); break; + case IrcCommand::NICK: thread->on_nick(msg); break; + case IrcCommand::JOIN: thread->on_join(msg); break; + case IrcCommand::KICK: thread->on_kick(msg); break; + case IrcCommand::PART: thread->on_part(msg); break; + case IrcCommand::MODE: thread->on_mode(msg); break; default: break; } }); diff --git a/self_thread.hpp b/self_thread.hpp index e27394c..7c154e5 100644 --- a/self_thread.hpp +++ b/self_thread.hpp @@ -1,6 +1,9 @@ #pragma once -#include +#include "connection.hpp" + +#include + #include #include @@ -17,8 +20,10 @@ class SelfThread std::string nickname_; std::string mode_; std::unordered_set channels_; + std::unordered_map isupport_; auto on_welcome(IrcMsg const& irc) -> void; + auto on_isupport(IrcMsg const& irc) -> void; auto on_nick(IrcMsg const& irc) -> void; auto on_umodeis(IrcMsg const& irc) -> void; auto on_join(IrcMsg const& irc) -> void; diff --git a/snote.cpp b/snote.cpp index 67b02f3..eec93f1 100644 --- a/snote.cpp +++ b/snote.cpp @@ -114,7 +114,7 @@ static auto setup_database() -> hs_database_t* } // namespace -SnoteCore::SnoteCore() noexcept +SnoteCore::SnoteCore() { db_.reset(setup_database()); @@ -126,7 +126,7 @@ SnoteCore::SnoteCore() noexcept scratch_.reset(scratch); } -auto SnoteCore::match(const IrcMsg &msg) -> std::optional> +auto SnoteCore::match(const IrcMsg &msg) -> std::optional { static char const* const prefix = "*** Notice -- "; @@ -162,7 +162,7 @@ auto SnoteCore::match(const IrcMsg &msg) -> std::optional std::match_results(components_).first; - auto message = std::get<0>(components_).second; + auto [regex, message] = std::get<0>(components_); auto& results = components_.emplace<1>(); if (not std::regex_match(message.begin(), message.end(), results, regex)) { diff --git a/snote.hpp b/snote.hpp index d88acfa..28b8e71 100644 --- a/snote.hpp +++ b/snote.hpp @@ -3,13 +3,12 @@ #include "ircmsg.hpp" #include +#include #include #include +#include #include -#include -#include -class Connection; struct hs_database; struct hs_scratch; @@ -34,16 +33,17 @@ enum class SnoteTag class SnoteMatch { -public: - SnoteMatch(std::regex const& regex, std::string_view full) - : components_{std::make_pair(std::ref(regex), full)} - {} - - auto get_results() -> std::match_results const&; - -private: SnoteTag tag_; std::variant, std::match_results> components_; + +public: + SnoteMatch(SnoteTag tag, std::regex const& regex, std::string_view full) + : tag_{tag} + , components_{std::make_pair(std::ref(regex), full)} + {} + + auto get_tag() -> SnoteTag { return tag_; } + auto get_results() -> std::match_results const&; }; struct SnoteCore @@ -64,8 +64,8 @@ struct SnoteCore /// @brief HyperScan scratch space std::unique_ptr scratch_; - SnoteCore() noexcept; - auto match(const IrcMsg &msg) -> std::optional>; + SnoteCore(); + auto match(const IrcMsg &msg) -> std::optional; }; extern SnoteCore snoteCore; diff --git a/watchdog_thread.cpp b/watchdog_thread.cpp deleted file mode 100644 index 544722c..0000000 --- a/watchdog_thread.cpp +++ /dev/null @@ -1,78 +0,0 @@ -#include "watchdog_thread.hpp" - -#include "connection.hpp" - -#include - -#include -#include - -WatchdogThread::WatchdogThread(Connection& connection) -: connection_{connection} -, timer_{connection.get_executor()} -, stalled_{false} -{ -} - -auto WatchdogThread::on_activity() -> void -{ - stalled_ = false; - timer_.expires_after(WatchdogThread::TIMEOUT); -} - -auto WatchdogThread::start_timer() -{ - timer_.async_wait([weak = weak_from_this()](auto const& error) - { - if (not error) - { - if (auto self = weak.lock()) - { - self->on_timeout(); - } - } - }); -} - -auto WatchdogThread::on_timeout() -> void -{ - if (stalled_) - { - connection_.close(); - } - else - { - connection_.send_ping("watchdog"); - stalled_ = true; - timer_.expires_after(WatchdogThread::TIMEOUT); - start_timer(); - } -} - -auto WatchdogThread::on_connect() -> void -{ - on_activity(); - start_timer(); -} - -auto WatchdogThread::on_disconnect() -> void -{ - timer_.cancel(); -} - -auto WatchdogThread::start(Connection& connection) -> void -{ - auto const thread = std::make_shared(connection); - connection.sig_connect.connect([thread]() - { - thread->on_connect(); - }); - connection.sig_disconnect.connect([thread]() - { - thread->on_disconnect(); - }); - connection.sig_ircmsg.connect([thread](auto, auto&) - { - thread->on_activity(); - }); -} diff --git a/watchdog_thread.hpp b/watchdog_thread.hpp deleted file mode 100644 index cba66d8..0000000 --- a/watchdog_thread.hpp +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once - -#include - -#include -#include - -class Connection; - -/** - * @brief Watch for connection activity and disconnect on stall - * - * The thread will send a ping if no message is received in the - * last TIMEOUT seconds. After another period of no messages - * the thread will disconnect the connection. - * - */ -class WatchdogThread : std::enable_shared_from_this -{ - Connection& connection_; - boost::asio::steady_timer timer_; - - /// @brief Set true and ping sent and false when reply received - bool stalled_; - - const std::chrono::steady_clock::duration TIMEOUT = std::chrono::seconds{30}; - - /// @brief Start the timer - /// @return - auto start_timer(); - - /// @brief - auto on_activity() -> void; - - /// @brief - auto on_timeout() -> void; - - /// @brief callback for ConnectEvent event - auto on_connect() -> void; - - /// @brief callback for DisconnectEvent event - auto on_disconnect() -> void; - -public: - WatchdogThread(Connection& connection); - static auto start(Connection& connection) -> void; -};