diff --git a/CMakeLists.txt b/CMakeLists.txt index cfaccba..c93b300 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,12 +43,14 @@ add_executable(xbot main.cpp irc_commands.inc bot.cpp + challenge.cpp + client.cpp connection.cpp irc_coroutine.cpp ircmsg.cpp + openssl_errors.cpp registration.cpp sasl_mechanism.cpp - client.cpp settings.cpp snote.cpp ) diff --git a/connection.cpp b/connection.cpp index 3fe36e7..1ab3cd1 100644 --- a/connection.cpp +++ b/connection.cpp @@ -224,13 +224,19 @@ auto Connection::send_join(std::string_view channel) -> void write_irc("JOIN", channel); } -auto Connection::send_whois(std::string_view arg1, std::string_view arg2) -> void +auto Connection::send_challenge(std::string_view message) -> void { - if (arg2.empty()) { - write_irc("WHOIS", arg1); - } else { - write_irc("WHOIS", arg1, arg2); - } + write_irc("CHALLENGE", message); +} + +auto Connection::send_whois(std::string_view arg1) -> void +{ + write_irc("WHOIS", arg1); +} + +auto Connection::send_whois_remote(std::string_view arg1, std::string_view arg2) -> void +{ + write_irc("WHOIS", arg1, arg2); } auto Connection::on_authenticate(const std::string_view chunk) -> void diff --git a/connection.hpp b/connection.hpp index cb022fb..a297a75 100644 --- a/connection.hpp +++ b/connection.hpp @@ -2,6 +2,7 @@ #include "irc_command.hpp" #include "ircmsg.hpp" +#include "ref.hpp" #include "snote.hpp" #include "stream.hpp" @@ -12,21 +13,8 @@ #include #include -template -class Ref { - struct Deleter { auto operator()(auto ptr) { Free(ptr); }}; - std::unique_ptr obj; -public: - Ref() = default; - Ref(T* t) : obj{t} { if (t) UpRef(t); } - auto get() const -> T* { return obj.get(); } -}; - struct ConnectSettings { - using X509_Ref = Ref; - using EVP_PKEY_Ref = Ref; - bool tls; std::string host; std::uint16_t port; @@ -107,7 +95,9 @@ public: auto send_authenticate(std::string_view message) -> void; auto send_authenticate_encoded(std::string_view message) -> void; auto send_authenticate_abort() -> void; - auto send_whois(std::string_view, std::string_view = {}) -> void; + auto send_whois(std::string_view) -> void; + auto send_whois_remote(std::string_view, std::string_view) -> void; + auto send_challenge(std::string_view) -> void; }; diff --git a/main.cpp b/main.cpp index 682f9a5..ad4d7c5 100644 --- a/main.cpp +++ b/main.cpp @@ -1,14 +1,15 @@ #include "bot.hpp" #include "c_callback.hpp" +#include "challenge.hpp" #include "client.hpp" #include "connection.hpp" +#include "openssl_errors.hpp" #include "registration.hpp" #include "settings.hpp" #include #include -#include #include #include @@ -18,21 +19,12 @@ using namespace std::literals; -static auto log_openssl_errors(const std::string_view prefix) -> void +static auto cert_from_file(const std::string &filename) -> X509_Ref { - auto err_cb = [prefix](const char *str, size_t len) -> int { - BOOST_LOG_TRIVIAL(error) << prefix << std::string_view{str, len}; - return 0; - }; - ERR_print_errors_cb(CCallback::invoke, &err_cb); -} - -static auto cert_from_file(const std::string &filename) -> ConnectSettings::X509_Ref -{ - ConnectSettings::X509_Ref cert; + X509_Ref cert; if (const auto fp = fopen(filename.c_str(), "r")) { - cert = PEM_read_X509(fp, nullptr, nullptr, nullptr); + cert.reset(PEM_read_X509(fp, nullptr, nullptr, nullptr)); if (cert.get() == nullptr) { log_openssl_errors("Reading certificate: "sv); @@ -47,12 +39,18 @@ static auto cert_from_file(const std::string &filename) -> ConnectSettings::X509 return cert; } -static auto key_from_file(const std::string &filename) -> ConnectSettings::EVP_PKEY_Ref +static auto key_from_file(const std::string &filename, const std::string_view password) -> EVP_PKEY_Ref { - ConnectSettings::EVP_PKEY_Ref key; + EVP_PKEY_Ref key; if (const auto fp = fopen(filename.c_str(), "r")) { - key = PEM_read_PrivateKey(fp, nullptr, nullptr, nullptr); + auto cb = [password](char * const buf, int const size, int) -> int { + if (size < password.size()) { return -1; } + std::copy(password.begin(), password.end(), buf); + return password.size(); + }; + + key.reset(PEM_read_PrivateKey(fp, nullptr, CCallback::invoke, &cb)); if (key.get() == nullptr) { log_openssl_errors("Reading private key: "sv); @@ -69,16 +67,16 @@ static auto key_from_file(const std::string &filename) -> ConnectSettings::EVP_P static auto start(boost::asio::io_context &io, const Settings &settings) -> void { - ConnectSettings::X509_Ref cert; + X509_Ref cert; if (settings.use_tls && not settings.tls_certfile.empty()) { cert = cert_from_file(settings.tls_certfile); } - ConnectSettings::EVP_PKEY_Ref key; + EVP_PKEY_Ref key; if (settings.use_tls && not settings.tls_keyfile.empty()) { - key = key_from_file(settings.tls_keyfile); + key = key_from_file(settings.tls_keyfile, ""); } const auto connection = std::make_shared(io); @@ -96,9 +94,15 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void } }); */ - client->sig_registered.connect([connection, client]() { + client->sig_registered.connect([&settings, connection, client]() { connection->send_join("##glguy"sv); connection->send_whois(client->get_my_nick()); + + if (not settings.challenge_username.empty() && + not settings.challenge_key_file.empty()) { + auto key = key_from_file(settings.challenge_key_file, settings.challenge_key_password); + Challenge::start(*connection, settings.challenge_username, std::move(key)); + } }); connection->sig_disconnect.connect( diff --git a/openssl_errors.cpp b/openssl_errors.cpp new file mode 100644 index 0000000..ca8d898 --- /dev/null +++ b/openssl_errors.cpp @@ -0,0 +1,14 @@ +#include "openssl_errors.hpp" + +#include "c_callback.hpp" + +#include + +auto log_openssl_errors(const std::string_view prefix) -> void +{ + auto err_cb = [prefix](const char *str, size_t len) -> int { + BOOST_LOG_TRIVIAL(error) << prefix << std::string_view{str, len}; + return 0; + }; + ERR_print_errors_cb(CCallback::invoke, &err_cb); +} diff --git a/settings.cpp b/settings.cpp index dd879ed..226e958 100644 --- a/settings.cpp +++ b/settings.cpp @@ -20,6 +20,9 @@ auto Settings::from_stream(std::istream &in) -> Settings .tls_hostname = config["tls_hostname"].value_or(std::string{}), .tls_certfile = config["tls_certfile"].value_or(std::string{}), .tls_keyfile = config["tls_keyfile"].value_or(std::string{}), + . challenge_username = config["challenge_username"].value_or(std::string{}), + . challenge_key_file = config["challenge_key_file"].value_or(std::string{}), + . challenge_key_password = config["challenge_key_password"].value_or(std::string{}), .use_tls = config["use_tls"].value_or(false), }; } diff --git a/settings.hpp b/settings.hpp index 5b57abc..cfeaa77 100644 --- a/settings.hpp +++ b/settings.hpp @@ -21,6 +21,10 @@ struct Settings std::string tls_certfile; std::string tls_keyfile; + std::string challenge_username; + std::string challenge_key_file; + std::string challenge_key_password; + bool use_tls; static auto from_stream(std::istream &in) -> Settings;