diff --git a/driver/main.cpp b/driver/main.cpp index c2371be..522f345 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -37,7 +37,6 @@ auto configure_sasl(const Settings &settings) -> std::unique_ptr settings.sasl_mechanism == "ECDSA" && not settings.sasl_authcid.empty() && not settings.sasl_key_file.empty() - ) { if (auto sasl_key = key_from_file(settings.sasl_key_file, settings.sasl_key_password)) return std::make_unique( @@ -62,18 +61,17 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void tls_key = key_from_file(settings.tls_key_file, settings.tls_key_password); } - auto sasl_mech = configure_sasl(settings); - const auto connection = std::make_shared(io); const auto client = Client::start(connection); + const auto bot = Bot::start(client); + Registration::start({ .nickname = settings.nickname, .realname = settings.realname, .username = settings.username, .password = settings.password, - .sasl_mechanism = std::move(sasl_mech), + .sasl_mechanism = configure_sasl(settings), }, client); - const auto bot = Bot::start(client); // Configure CHALLENGE on registration if applicable if (not settings.challenge_username.empty() && not settings.challenge_key_file.empty()) { @@ -84,7 +82,7 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void } } - // On disconnect tear down the various layers and reconnect in 5 seconds + // On disconnect reconnect in 5 seconds // connection is captured in the disconnect handler so it can keep itself alive connection->sig_disconnect.connect( [&io, &settings, connection]() { diff --git a/myirc/CMakeLists.txt b/myirc/CMakeLists.txt index 1a31bbe..a1b40f8 100644 --- a/myirc/CMakeLists.txt +++ b/myirc/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(myirc STATIC ircmsg.cpp openssl_utils.cpp registration.cpp + ratelimit.cpp sasl_mechanism.cpp snote.cpp ) diff --git a/myirc/connection.cpp b/myirc/connection.cpp index a518905..84f7793 100644 --- a/myirc/connection.cpp +++ b/myirc/connection.cpp @@ -1,5 +1,6 @@ #include "connection.hpp" +#include "boost/asio/steady_timer.hpp" #include "linebuffer.hpp" #include @@ -25,16 +26,51 @@ Connection::Connection(boost::asio::io_context &io) auto Connection::write_buffers() -> void { + const auto available = write_strings_.size(); + const auto [delay, count] + = rate_limit + ? rate_limit->query(available) + : std::pair{0ms, available}; + + if (delay > 0ms) { + auto timer = std::make_shared(stream_.get_executor(), delay); + timer->async_wait([timer, count, self = weak_from_this()](auto) { + if (auto lock = self.lock()) { + lock->write_buffers(count); + } + }); + } else { + write_buffers(count); + } +} + +auto Connection::write_buffers(size_t n) -> void +{ + std::list strings; std::vector buffers; - buffers.reserve(write_strings_.size()); - for (const auto &elt : write_strings_) + + if (n == write_strings_.size()) { + strings = std::move(write_strings_); + write_strings_.clear(); + } else { + strings.splice( + strings.begin(), // insert at + write_strings_, // remove from + write_strings_.begin(), // start removing at + std::next(write_strings_.begin(), n) // stop removing at + ); + } + + buffers.reserve(n); + for (const auto &elt : strings) { buffers.push_back(boost::asio::buffer(elt)); } + boost::asio::async_write( stream_, buffers, - [this, strings = std::move(write_strings_)](const boost::system::error_code &error, std::size_t) { + [this, strings = std::move(strings)](const boost::system::error_code &error, std::size_t) { if (not error) { if (write_strings_.empty()) @@ -48,7 +84,6 @@ auto Connection::write_buffers() -> void } } ); - write_strings_.clear(); } auto Connection::watchdog() -> void @@ -154,14 +189,17 @@ auto Connection::write_irc(std::string message) -> void auto Connection::write_irc(std::string front, std::string_view last) -> void { - if (last.find_first_of("\r\n\0"sv) != last.npos) - { - throw std::runtime_error{"bad irc argument"}; + bool colon = last.starts_with(":"); + for (const auto c : last) { + switch (c) { + case '\r': case '\n': case '\0': throw std::runtime_error{"bad irc argument"}; + case ' ': colon = true; + default: break; + } } - - front += " :"; + front += colon ? " :" : " "; front += last; - write_irc(std::move(front)); + write_line(std::move(front)); } auto Connection::send_ping(std::string_view txt) -> void diff --git a/myirc/include/connection.hpp b/myirc/include/connection.hpp index fb8d35a..59608bf 100644 --- a/myirc/include/connection.hpp +++ b/myirc/include/connection.hpp @@ -2,6 +2,7 @@ #include "irc_command.hpp" #include "ircmsg.hpp" +#include "ratelimit.hpp" #include "ref.hpp" #include "snote.hpp" #include "stream.hpp" @@ -46,7 +47,12 @@ private: // AUTHENTICATE support std::string authenticate_buffer_; + /// write buffers after consulting with rate limit auto write_buffers() -> void; + + /// write a specific number of messages now + auto write_buffers(size_t) -> void; + auto dispatch_line(char *line) -> void; static constexpr std::chrono::seconds watchdog_duration = std::chrono::seconds{30}; @@ -68,6 +74,7 @@ public: boost::signals2::signal sig_ircmsg; boost::signals2::signal sig_snote; boost::signals2::signal sig_authenticate; + std::unique_ptr rate_limit; Connection(boost::asio::io_context &io); diff --git a/myirc/include/ratelimit.hpp b/myirc/include/ratelimit.hpp new file mode 100644 index 0000000..2815024 --- /dev/null +++ b/myirc/include/ratelimit.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +struct RateLimit { + virtual ~RateLimit(); + auto virtual query(size_t want_to_send) -> std::pair = 0; +}; + +struct Rfc1459RateLimit final : RateLimit +{ + using clock = std::chrono::steady_clock; + + std::chrono::milliseconds cost_ {2'000}; + std::chrono::milliseconds allowance_ {10'000}; + clock::time_point horizon_{}; + + auto query(size_t want_to_send) -> std::pair override; +}; diff --git a/myirc/ratelimit.cpp b/myirc/ratelimit.cpp new file mode 100644 index 0000000..2f25813 --- /dev/null +++ b/myirc/ratelimit.cpp @@ -0,0 +1,23 @@ +#include "ratelimit.hpp" +#include + +using namespace std::literals; +using ms = std::chrono::milliseconds; + +auto Rfc1459RateLimit::query(size_t want_to_send) -> std::pair +{ + const auto now = clock::now(); + if (horizon_ < now) horizon_ = now; + + auto gap = std::chrono::floor(now + allowance_ - horizon_); + auto send = gap / cost_; + if (std::cmp_greater(send, want_to_send)) send = want_to_send; + + if (send > 0) { + horizon_ += send * cost_; + return {0ms, send}; + } else { + horizon_ += cost_; + return {cost_ - gap, 1}; + } +}