consolidate/simplify Registration's event loop

This commit is contained in:
Eric Mertens 2025-01-27 09:30:09 -08:00
parent eb7b27ebe3
commit b9d88bbd0c
7 changed files with 81 additions and 84 deletions

View File

@ -6,7 +6,7 @@ auto Bot::start(std::shared_ptr<Client> self) -> std::shared_ptr<Bot>
{
const auto thread = std::make_shared<Bot>(std::move(self));
thread->self_->get_connection()->sig_ircmsg.connect([thread](const auto cmd, auto &msg) {
thread->self_->get_connection().sig_ircmsg.connect([thread](const auto cmd, auto &msg) {
thread->on_ircmsg(cmd, msg);
});

View File

@ -56,6 +56,8 @@ public:
{
}
auto get_connection() -> Connection & { return connection_; }
static auto start(Connection &) -> std::shared_ptr<Client>;
auto start_sasl(std::unique_ptr<SaslMechanism> mechanism) -> void;

View File

@ -147,11 +147,6 @@ auto Connection::close() -> void
stream_.close();
}
static auto is_invalid_last(char x) -> bool
{
return x == '\0' || x == '\r' || x == '\n';
}
auto Connection::write_irc(std::string message) -> void
{
write_line(std::move(message));
@ -159,7 +154,7 @@ auto Connection::write_irc(std::string message) -> void
auto Connection::write_irc(std::string front, std::string_view last) -> void
{
if (last.end() != std::find_if(last.begin(), last.end(), is_invalid_last))
if (last.find_first_of("\r\n\0"sv) != last.npos)
{
throw std::runtime_error{"bad irc argument"};
}

View File

@ -2,7 +2,6 @@
#include "irc_command.hpp"
#include "ircmsg.hpp"
#include "settings.hpp"
#include "snote.hpp"
#include "stream.hpp"
@ -39,6 +38,7 @@ struct ConnectSettings
std::string socks_user;
std::string socks_pass;
};
class Connection : public std::enable_shared_from_this<Connection>
{
private:
@ -62,36 +62,34 @@ private:
auto watchdog_activity() -> void;
auto connect(ConnectSettings settings) -> boost::asio::awaitable<void>;
auto on_authenticate(std::string_view) -> void;
public:
/// Build and send well-formed IRC message from individual parameters
auto write_irc(std::string) -> void;
auto write_irc(std::string, std::string_view) -> void;
template <typename... Args>
auto write_irc(std::string front, std::string_view next, Args... rest) -> void;
/// Write bytes into the socket.
auto write_line(std::string message) -> void;
Connection(boost::asio::io_context &io);
public:
boost::signals2::signal<void()> sig_connect;
boost::signals2::signal<void()> sig_disconnect;
boost::signals2::signal<void(IrcCommand, const IrcMsg &)> sig_ircmsg;
boost::signals2::signal<void(SnoteMatch &)> sig_snote;
boost::signals2::signal<void(std::string_view)> sig_authenticate;
Connection(boost::asio::io_context &io);
/// Write bytes into the socket.
auto write_line(std::string message) -> void;
auto get_executor() -> boost::asio::any_io_executor
{
return stream_.get_executor();
}
auto start(ConnectSettings) -> void;
auto close() -> void;
auto on_authenticate(std::string_view) -> void;
auto send_ping(std::string_view) -> void;
auto send_pong(std::string_view) -> void;
auto send_pass(std::string_view) -> void;
@ -111,13 +109,9 @@ public:
template <typename... Args>
auto Connection::write_irc(std::string front, std::string_view next, Args... rest) -> void
{
const auto is_invalid = [](const char x) -> bool {
return x == '\0' || x == '\r' || x == '\n' || x == ' ';
};
using namespace std::literals;
if (next.empty()
|| next.front() == ':'
|| next.end() != std::find_if(next.begin(), next.end(), is_invalid))
if (next.empty() || next.front() == ':' || next.find_first_of("\r\n \0"sv) != next.npos)
{
throw std::runtime_error{"bad irc argument"};
}

View File

@ -7,7 +7,6 @@
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "bot.hpp"
#include "client.hpp"
@ -19,7 +18,7 @@ auto start(boost::asio::io_context &io, const Settings &settings) -> void
{
const auto connection = std::make_shared<Connection>(io);
const auto client = Client::start(*connection);
Registration::start(*connection, settings, client);
Registration::start(settings, client);
const auto bot = Bot::start(client);

View File

@ -5,31 +5,35 @@
#include "sasl_mechanism.hpp"
#include <memory>
#include <random>
#include <unordered_map>
#include <unordered_set>
Registration::Registration(
Connection &connection,
const Settings &settings,
std::shared_ptr<Client> self
std::shared_ptr<Client> client
)
: connection_{connection}
, settings_{settings}
, self_{std::move(self)}
: settings_{settings}
, client_{std::move(client)}
{
}
auto Registration::on_connect() -> void
{
connection_.send_cap_ls();
listen_for_cap_ls();
client_->get_connection().send_cap_ls();
slot_ = client_->get_connection().sig_ircmsg.connect(
[self = shared_from_this()](const auto cmd, auto &msg)
{
self->on_ircmsg(cmd, msg);
}
);
if (not settings_.password.empty())
{
connection_.send_pass(settings_.password);
client_->get_connection().send_pass(settings_.password);
}
connection_.send_user(settings_.username, settings_.realname);
connection_.send_nick(settings_.nickname);
client_->get_connection().send_user(settings_.username, settings_.realname);
client_->get_connection().send_nick(settings_.nickname);
}
auto Registration::send_req() -> void
@ -69,13 +73,11 @@ auto Registration::send_req() -> void
if (not outstanding.empty())
{
request.pop_back();
connection_.send_cap_req(request);
listen_for_cap_ack();
client_->get_connection().send_cap_req(request);
}
else
{
connection_.send_cap_end();
client_->get_connection().send_cap_end();
}
}
@ -94,23 +96,11 @@ auto Registration::on_msg_cap_ack(const IrcMsg &msg) -> void
if (settings_.sasl_mechanism.empty())
{
slot_.disconnect();
connection_.send_cap_end();
client_->get_connection().send_cap_end();
}
else
{
self_->start_sasl(std::make_unique<SaslPlain>(settings_.sasl_authcid, settings_.sasl_authzid, settings_.sasl_password));
slot_ = connection_.sig_ircmsg.connect([thread = shared_from_this()](auto cmd, auto &msg) {
switch (cmd)
{
default:
break;
case IrcCommand::RPL_SASLSUCCESS:
case IrcCommand::ERR_SASLFAIL:
thread->connection_.send_cap_end();
thread->slot_.disconnect();
}
});
client_->start_sasl(std::make_unique<SaslPlain>(settings_.sasl_authcid, settings_.sasl_authzid, settings_.sasl_password));
}
}
}
@ -155,20 +145,18 @@ auto Registration::on_msg_cap_ls(const IrcMsg &msg) -> void
if (last)
{
slot_.disconnect();
send_req();
}
}
auto Registration::start(
Connection &connection,
const Settings &settings,
std::shared_ptr<Client> self
std::shared_ptr<Client> client
) -> std::shared_ptr<Registration>
{
const auto thread = std::make_shared<Registration>(connection, std::move(settings), std::move(self));
const auto thread = std::make_shared<Registration>(std::move(settings), std::move(client));
thread->slot_ = connection.sig_connect.connect([thread]() {
thread->slot_ = thread->client_->get_connection().sig_connect.connect([thread]() {
thread->slot_.disconnect();
thread->on_connect();
});
@ -176,27 +164,48 @@ auto Registration::start(
return thread;
}
auto Registration::listen_for_cap_ack() -> void
auto Registration::randomize_nick() -> void
{
slot_ = connection_.sig_ircmsg.connect([thread = shared_from_this()](IrcCommand cmd, const IrcMsg &msg) {
if (IrcCommand::CAP == cmd && msg.args.size() >= 2 && "ACK" == msg.args[1])
{
thread->on_msg_cap_ack(msg);
}
});
std::string new_nick;
new_nick += settings_.nickname.substr(0, 8);
std::random_device rd;
std::mt19937 gen{rd()};
std::uniform_int_distribution<> distrib(0, 35);
for (int i = 0; i < 8; ++i) {
const auto x = distrib(gen);
new_nick += x < 10 ? '0' + x : 'A' + (x-10);
}
auto Registration::listen_for_cap_ls() -> void
{
slot_ = connection_.sig_ircmsg.connect([thread = shared_from_this()](IrcCommand cmd, const IrcMsg &msg) {
if (IrcCommand::CAP == cmd && msg.args.size() >= 2 && "LS" == msg.args[1])
{
thread->on_msg_cap_ls(msg);
client_->get_connection().send_nick(new_nick);
}
else if (IrcCommand::RPL_WELCOME == cmd)
auto Registration::on_ircmsg(const IrcCommand cmd, const IrcMsg &msg) -> void
{
// Server doesn't support CAP negotiation
thread->slot_.disconnect();
switch (cmd)
{
default: break;
case IrcCommand::CAP:
if (msg.args.size() >= 2 && "LS" == msg.args[1]) {
on_msg_cap_ls(msg);
} else if (msg.args.size() >= 2 && "ACK" == msg.args[1]) {
on_msg_cap_ack(msg);
}
break;
case IrcCommand::ERR_NICKNAMEINUSE:
randomize_nick();
break;
case IrcCommand::RPL_WELCOME:
slot_.disconnect();
break;
case IrcCommand::RPL_SASLSUCCESS:
case IrcCommand::ERR_SASLFAIL:
client_->get_connection().send_cap_end();
break;
}
});
}

View File

@ -11,9 +11,8 @@
class Registration : public std::enable_shared_from_this<Registration>
{
Connection &connection_;
const Settings &settings_;
std::shared_ptr<Client> self_;
std::shared_ptr<Client> client_;
std::unordered_map<std::string, std::string> caps;
std::unordered_set<std::string> outstanding;
@ -21,22 +20,21 @@ class Registration : public std::enable_shared_from_this<Registration>
boost::signals2::scoped_connection slot_;
auto on_connect() -> void;
auto send_req() -> void;
auto on_msg_cap_ls(const IrcMsg &msg) -> void;
auto on_msg_cap_ack(const IrcMsg &msg) -> void;
auto listen_for_cap_ack() -> void;
auto listen_for_cap_ls() -> void;
auto on_ircmsg(IrcCommand, const IrcMsg &msg) -> void;
auto send_req() -> void;
auto randomize_nick() -> void;
public:
Registration(
Connection &,
const Settings &,
std::shared_ptr<Client>
);
static auto start(
Connection &,
const Settings &,
std::shared_ptr<Client>
) -> std::shared_ptr<Registration>;