use shared-ptr more consistently

This commit is contained in:
Eric Mertens 2025-01-31 08:38:14 -08:00
parent eb01b304e3
commit 7728bc6aee
8 changed files with 64 additions and 64 deletions

View File

@ -23,23 +23,26 @@ using namespace std::literals;
static auto start(boost::asio::io_context &io, const Settings &settings) -> void static auto start(boost::asio::io_context &io, const Settings &settings) -> void
{ {
Ref<X509> cert; Ref<X509> tls_cert;
if (settings.use_tls && not settings.tls_certfile.empty()) if (settings.use_tls && not settings.tls_cert_file.empty())
{ {
cert = cert_from_file(settings.tls_certfile); tls_cert = cert_from_file(settings.tls_cert_file);
} }
Ref<EVP_PKEY> key; Ref<EVP_PKEY> tls_key;
if (settings.use_tls && not settings.tls_keyfile.empty()) if (settings.use_tls && not settings.tls_key_file.empty())
{ {
key = key_from_file(settings.tls_keyfile, ""); tls_key = key_from_file(settings.tls_key_file, settings.tls_key_password);
}
Ref<EVP_PKEY> sasl_key;
if (not settings.sasl_key_file.empty())
{
sasl_key = key_from_file(settings.sasl_key_file, settings.sasl_key_password);
} }
const auto connection = std::make_shared<Connection>(io); const auto connection = std::make_shared<Connection>(io);
const auto client = Client::start(*connection); const auto client = Client::start(connection);
Ref<EVP_PKEY> sasl_key;
if (not settings.sasl_key_file.empty())
sasl_key = key_from_file(settings.sasl_key_file, settings.sasl_key_password);
Registration::start({ Registration::start({
.nickname = settings.nickname, .nickname = settings.nickname,
.realname = settings.realname, .realname = settings.realname,
@ -51,32 +54,33 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void
.sasl_password = settings.sasl_password, .sasl_password = settings.sasl_password,
.sasl_key = std::move(sasl_key), .sasl_key = std::move(sasl_key),
}, client); }, client);
const auto bot = Bot::start(client); const auto bot = Bot::start(client);
// Configure CHALLENGE on registration if applicable
if (not settings.challenge_username.empty() && not settings.challenge_key_file.empty()) { if (not settings.challenge_username.empty() && not settings.challenge_key_file.empty()) {
if (auto key = key_from_file(settings.challenge_key_file, settings.challenge_key_password)) { if (auto key = key_from_file(settings.challenge_key_file, settings.challenge_key_password)) {
client->sig_registered.connect([&settings, connection, key = std::move(key)]() { client->sig_registered.connect([&settings, connection, key = std::move(key)]() {
Challenge::start(*connection, settings.challenge_username, key); Challenge::start(connection, settings.challenge_username, key);
}); });
} }
} }
// On disconnect tear down the various layers and reconnect in 5 seconds
// connection is captured in the disconnect handler so it can keep itself alive
connection->sig_disconnect.connect( connection->sig_disconnect.connect(
[&io, &settings, client, bot]() { [&io, &settings, connection]() {
client->shutdown();
bot->shutdown();
auto timer = std::make_shared<boost::asio::steady_timer>(io); auto timer = std::make_shared<boost::asio::steady_timer>(io);
timer->expires_after(5s); timer->expires_after(5s);
timer->async_wait([&io, &settings, timer](auto) { start(io, settings); }); timer->async_wait([&io, &settings, timer](auto) { start(io, settings); });
} }
); );
bot->sig_command.connect([connection](auto &cmd) { // Simple example of a command handler
std::cout << "COMMAND " << cmd.command << " from " << cmd.account << std::endl; bot->sig_command.connect([connection](const Bot::Command &cmd) {
if (cmd.oper == "glguy" && cmd.command == "ping") { if (cmd.oper == "glguy" && cmd.command == "ping") {
connection->send_notice("glguy", cmd.arguments); if (auto bang = cmd.source.find('!'); bang != cmd.source.npos) {
connection->send_notice(cmd.source.substr(0, bang), cmd.arguments);
}
} }
}); });
@ -85,12 +89,12 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void
.host = settings.host, .host = settings.host,
.port = settings.service, .port = settings.service,
.verify = settings.tls_hostname, .verify = settings.tls_hostname,
.client_cert = std::move(cert), .client_cert = std::move(tls_cert),
.client_key = std::move(key), .client_key = std::move(tls_key),
}); });
} }
static auto get_settings(const char *filename) -> Settings static auto get_settings(const char * const filename) -> Settings
{ {
if (auto config_stream = std::ifstream{filename}) if (auto config_stream = std::ifstream{filename})
{ {

View File

@ -20,8 +20,9 @@ auto Settings::from_stream(std::istream &in) -> Settings
.sasl_key_file = config["sasl_key_file"].value_or(std::string{}), .sasl_key_file = config["sasl_key_file"].value_or(std::string{}),
.sasl_key_password = config["sasl_key_password"].value_or(std::string{}), .sasl_key_password = config["sasl_key_password"].value_or(std::string{}),
.tls_hostname = config["tls_hostname"].value_or(std::string{}), .tls_hostname = config["tls_hostname"].value_or(std::string{}),
.tls_certfile = config["tls_certfile"].value_or(std::string{}), .tls_cert_file = config["tls_cert_file"].value_or(std::string{}),
.tls_keyfile = config["tls_keyfile"].value_or(std::string{}), .tls_key_file = config["tls_key_file"].value_or(std::string{}),
.tls_key_password = config["tls_key_password"].value_or(std::string{}),
.challenge_username = config["challenge_username"].value_or(std::string{}), .challenge_username = config["challenge_username"].value_or(std::string{}),
.challenge_key_file = config["challenge_key_file"].value_or(std::string{}), .challenge_key_file = config["challenge_key_file"].value_or(std::string{}),
.challenge_key_password = config["challenge_key_password"].value_or(std::string{}), .challenge_key_password = config["challenge_key_password"].value_or(std::string{}),

View File

@ -20,8 +20,9 @@ struct Settings
std::string sasl_key_password; std::string sasl_key_password;
std::string tls_hostname; std::string tls_hostname;
std::string tls_certfile; std::string tls_cert_file;
std::string tls_keyfile; std::string tls_key_file;
std::string tls_key_password;
std::string challenge_username; std::string challenge_username;
std::string challenge_key_file; std::string challenge_key_file;

View File

@ -12,9 +12,9 @@
#include <memory> #include <memory>
#include <string> #include <string>
Challenge::Challenge(Ref<EVP_PKEY> key, Connection & connection) Challenge::Challenge(Ref<EVP_PKEY> key, std::shared_ptr<Connection> connection)
: key_{std::move(key)} : key_{std::move(key)}
, connection_{connection} , connection_{std::move(connection)}
{} {}
auto Challenge::on_ircmsg(IrcCommand cmd, const IrcMsg &msg) -> void { auto Challenge::on_ircmsg(IrcCommand cmd, const IrcMsg &msg) -> void {
@ -29,7 +29,7 @@ auto Challenge::on_ircmsg(IrcCommand cmd, const IrcMsg &msg) -> void {
break; break;
case IrcCommand::RPL_YOUREOPER: case IrcCommand::RPL_YOUREOPER:
slot_.disconnect(); slot_.disconnect();
connection_.send_ping("mitigation"); connection_->send_ping("mitigation");
break; break;
case IrcCommand::RPL_ENDOFRSACHALLENGE2: case IrcCommand::RPL_ENDOFRSACHALLENGE2:
finish_challenge(); finish_challenge();
@ -78,14 +78,14 @@ auto Challenge::finish_challenge() -> void
buffer_[0] = '+'; buffer_[0] = '+';
mybase64::encode(std::string_view{(char*)digest, digestlen}, buffer_.data() + 1); mybase64::encode(std::string_view{(char*)digest, digestlen}, buffer_.data() + 1);
connection_.send_challenge(buffer_); connection_->send_challenge(buffer_);
buffer_.clear(); buffer_.clear();
} }
auto Challenge::start(Connection &connection, const std::string_view user, Ref<EVP_PKEY> ref) -> std::shared_ptr<Challenge> auto Challenge::start(std::shared_ptr<Connection> connection, const std::string_view user, Ref<EVP_PKEY> ref) -> std::shared_ptr<Challenge>
{ {
auto self = std::make_shared<Challenge>(std::move(ref), connection); auto self = std::make_shared<Challenge>(std::move(ref), connection);
self->slot_ = connection.sig_ircmsg.connect([self](auto cmd, auto &msg) { self->on_ircmsg(cmd, msg); }); self->slot_ = connection->sig_ircmsg.connect([self](auto cmd, auto &msg) { self->on_ircmsg(cmd, msg); });
connection.send_challenge(user); connection->send_challenge(user);
return self; return self;
} }

View File

@ -137,11 +137,11 @@ auto Client::on_chat(bool notice, const IrcMsg &irc) -> void
}); });
} }
auto Client::start(Connection &connection) -> std::shared_ptr<Client> auto Client::start(std::shared_ptr<Connection> connection) -> std::shared_ptr<Client>
{ {
auto thread = std::make_shared<Client>(connection); auto thread = std::make_shared<Client>(connection);
connection.sig_ircmsg.connect([thread](auto cmd, auto &msg) { connection->sig_ircmsg.connect([thread](auto cmd, auto &msg) {
switch (cmd) switch (cmd)
{ {
case IrcCommand::PRIVMSG: case IrcCommand::PRIVMSG:
@ -186,7 +186,7 @@ auto Client::start(Connection &connection) -> std::shared_ptr<Client>
} }
}); });
connection.sig_authenticate.connect([thread](auto msg) { connection->sig_authenticate.connect([thread](auto msg) {
thread->on_authenticate(msg); thread->on_authenticate(msg);
}); });
@ -239,20 +239,20 @@ auto Client::on_authenticate(const std::string_view body) -> void
if (not sasl_mechanism_) if (not sasl_mechanism_)
{ {
BOOST_LOG_TRIVIAL(warning) << "Unexpected AUTHENTICATE from server"sv; BOOST_LOG_TRIVIAL(warning) << "Unexpected AUTHENTICATE from server"sv;
connection_.send_authenticate_abort(); connection_->send_authenticate_abort();
return; return;
} }
std::visit( std::visit(
overloaded{ overloaded{
[this](const std::string &reply) { [this](const std::string &reply) {
connection_.send_authenticate_encoded(reply); connection_->send_authenticate_encoded(reply);
}, },
[this](SaslMechanism::NoReply) { [this](SaslMechanism::NoReply) {
connection_.send_authenticate("*"sv); connection_->send_authenticate("*"sv);
}, },
[this](SaslMechanism::Failure) { [this](SaslMechanism::Failure) {
connection_.send_authenticate_abort(); connection_->send_authenticate_abort();
}, },
}, },
sasl_mechanism_->step(body)); sasl_mechanism_->step(body));
@ -267,11 +267,11 @@ auto Client::start_sasl(std::unique_ptr<SaslMechanism> mechanism) -> void
{ {
if (sasl_mechanism_) if (sasl_mechanism_)
{ {
connection_.send_authenticate("*"sv); // abort SASL connection_->send_authenticate("*"sv); // abort SASL
} }
sasl_mechanism_ = std::move(mechanism); sasl_mechanism_ = std::move(mechanism);
connection_.send_authenticate(sasl_mechanism_->mechanism_name()); connection_->send_authenticate(sasl_mechanism_->mechanism_name());
} }
static auto tolower_rfc1459(int c) -> int static auto tolower_rfc1459(int c) -> int
@ -328,16 +328,10 @@ auto Client::casemap_compare(std::string_view lhs, std::string_view rhs) const -
} }
} }
auto Client::shutdown() -> void
{
sig_registered.disconnect_all_slots();
sig_cap_ls.disconnect_all_slots();
}
auto Client::list_caps() -> void auto Client::list_caps() -> void
{ {
caps_available_.clear(); caps_available_.clear();
connection_.send_cap_ls(); connection_->send_cap_ls();
} }
auto Client::on_cap(const IrcMsg &msg) -> void auto Client::on_cap(const IrcMsg &msg) -> void

View File

@ -556,13 +556,13 @@ auto Connection::start(Settings settings) -> void
BOOST_LOG_TRIVIAL(debug) << "TERMINATED: " << e.what(); BOOST_LOG_TRIVIAL(debug) << "TERMINATED: " << e.what();
} }
self->sig_disconnect();
// Disconnect all slots to avoid circular references // Disconnect all slots to avoid circular references
self->sig_connect.disconnect_all_slots(); self->sig_connect.disconnect_all_slots();
self->sig_ircmsg.disconnect_all_slots(); self->sig_ircmsg.disconnect_all_slots();
self->sig_disconnect.disconnect_all_slots();
self->sig_snote.disconnect_all_slots(); self->sig_snote.disconnect_all_slots();
self->sig_authenticate.disconnect_all_slots(); self->sig_authenticate.disconnect_all_slots();
self->sig_disconnect();
self->sig_disconnect.disconnect_all_slots();
}); });
} }

View File

@ -12,7 +12,7 @@
class Challenge : std::enable_shared_from_this<Challenge> class Challenge : std::enable_shared_from_this<Challenge>
{ {
Ref<EVP_PKEY> key_; Ref<EVP_PKEY> key_;
Connection &connection_; std::shared_ptr<Connection> connection_;
boost::signals2::scoped_connection slot_; boost::signals2::scoped_connection slot_;
std::string buffer_; std::string buffer_;
@ -20,12 +20,12 @@ class Challenge : std::enable_shared_from_this<Challenge>
auto finish_challenge() -> void; auto finish_challenge() -> void;
public: public:
Challenge(Ref<EVP_PKEY>, Connection &); Challenge(Ref<EVP_PKEY>, std::shared_ptr<Connection>);
/// @brief Starts the CHALLENGE protocol. /// @brief Starts the CHALLENGE protocol.
/// @param connection Registered connection. /// @param connection Registered connection.
/// @param user Operator username /// @param user Operator username
/// @param key Operator private RSA key /// @param key Operator private RSA key
/// @return Handle to the challenge object. /// @return Handle to the challenge object.
static auto start(Connection &, std::string_view user, Ref<EVP_PKEY> key) -> std::shared_ptr<Challenge>; static auto start(std::shared_ptr<Connection>, std::string_view user, Ref<EVP_PKEY> key) -> std::shared_ptr<Challenge>;
}; };

View File

@ -32,7 +32,7 @@ struct Chat {
*/ */
class Client class Client
{ {
Connection &connection_; std::shared_ptr<Connection> connection_;
std::string nickname_; std::string nickname_;
std::string mode_; std::string mode_;
@ -67,23 +67,23 @@ public:
boost::signals2::signal<void(const std::unordered_map<std::string, std::string> &)> sig_cap_ls; boost::signals2::signal<void(const std::unordered_map<std::string, std::string> &)> sig_cap_ls;
boost::signals2::signal<void(const Chat &)> sig_chat; boost::signals2::signal<void(const Chat &)> sig_chat;
Client(Connection &connection) Client(std::shared_ptr<Connection> connection)
: connection_{connection} : connection_{std::move(connection)}
, casemap_{Casemap::Rfc1459} , casemap_{Casemap::Rfc1459}
, channel_prefix_{"#&"} , channel_prefix_{"#&"}
, status_msg_{"+@"} , status_msg_{"+@"}
{ {
} }
auto get_connection() -> Connection & { return connection_; } auto get_connection() -> Connection & { return *connection_; }
static auto start(Connection &) -> std::shared_ptr<Client>; static auto start(std::shared_ptr<Connection>) -> std::shared_ptr<Client>;
auto start_sasl(std::unique_ptr<SaslMechanism> mechanism) -> void; auto start_sasl(std::unique_ptr<SaslMechanism> mechanism) -> void;
auto get_connection() const -> std::shared_ptr<Connection> auto get_connection() const -> std::shared_ptr<Connection>
{ {
return connection_.shared_from_this(); return connection_->shared_from_this();
} }
auto get_my_nick() const -> const std::string &; auto get_my_nick() const -> const std::string &;