diff --git a/CMakeLists.txt b/CMakeLists.txt index c93b300..9c1e781 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,6 @@ add_executable(xbot challenge.cpp client.cpp connection.cpp - irc_coroutine.cpp ircmsg.cpp openssl_errors.cpp registration.cpp diff --git a/challenge.cpp b/challenge.cpp index 9de6a64..161edda 100644 --- a/challenge.cpp +++ b/challenge.cpp @@ -39,9 +39,11 @@ auto Challenge::on_ircmsg(IrcCommand cmd, const IrcMsg &msg) -> void { break; case IrcCommand::RPL_ENDOFRSACHALLENGE2: { + slot_.disconnect(); + EVP_PKEY_CTX_Ref ctx; - unsigned int digestlen = 160; - unsigned char digest[160]; + unsigned int digestlen = EVP_MAX_MD_SIZE; + unsigned char digest[EVP_MAX_MD_SIZE]; size_t len = mybase64::decoded_size(buffer_.size()); std::vector ciphertext(len, 0); @@ -52,14 +54,15 @@ auto Challenge::on_ircmsg(IrcCommand cmd, const IrcMsg &msg) -> void { ctx.reset(EVP_PKEY_CTX_new(key_.get(), nullptr)); if (ctx.get() == nullptr) goto error; if (1 != EVP_PKEY_decrypt_init(ctx.get())) goto error; - if (0 <= EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING)) goto error; + if (0 >= EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING)) goto error; // Determine output size if (1 != EVP_PKEY_decrypt(ctx.get(), nullptr, &len, ciphertext.data(), ciphertext.size())) goto error; buffer_.resize(len); // Decrypt ciphertext - EVP_PKEY_decrypt(ctx.get(), reinterpret_cast(buffer_.data()), &len, ciphertext.data(), ciphertext.size()); + if (1 != EVP_PKEY_decrypt(ctx.get(), reinterpret_cast(buffer_.data()), &len, ciphertext.data(), ciphertext.size())) goto error; + buffer_.resize(len); // Hash the decrypted message if (1 != EVP_Digest(buffer_.data(), buffer_.size(), digest, &digestlen, EVP_sha1(), nullptr)) goto error; @@ -70,13 +73,10 @@ auto Challenge::on_ircmsg(IrcCommand cmd, const IrcMsg &msg) -> void { mybase64::encode(std::string_view{(char*)digest, digestlen}, buffer_.data() + 1); connection_.send_challenge(buffer_); - - stop(); return; error: log_openssl_errors("Challenge: "); - stop(); } } } diff --git a/irc_coroutine.cpp b/irc_coroutine.cpp deleted file mode 100644 index f2308c9..0000000 --- a/irc_coroutine.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "irc_coroutine.hpp" - -auto irc_coroutine::is_running() -> bool -{ - return promise().connection_ != nullptr; -} -auto irc_coroutine::exception() -> std::exception_ptr -{ - return promise().exception_; -} - -auto irc_coroutine::start(Connection &connection) -> void -{ - promise().connection_ = connection.shared_from_this(); - resume(); -} - -void wait_ircmsg::stop() { ircmsg_slot_.disconnect(); } - -void wait_timeout::stop() { timer_.reset(); } diff --git a/irc_coroutine.hpp b/irc_coroutine.hpp index f5fd617..30b16fd 100644 --- a/irc_coroutine.hpp +++ b/irc_coroutine.hpp @@ -80,7 +80,28 @@ public: template auto start(Wait &command) -> void; - auto stop() -> void; + auto stop() -> void { ircmsg_slot_.disconnect(); } +}; + +class wait_snote +{ + // Vector of tags this wait is expecting. Leave empty to accept all messages. + std::vector want_tags_; + + // Slot for the snote event + boost::signals2::scoped_connection snote_slot_; + +public: + using result_type = SnoteMatch; + + wait_snote(std::initializer_list want_tags) + : want_tags_{want_tags} + { + } + + template + auto start(Wait &command) -> void; + auto stop() -> void { snote_slot_.disconnect(); } }; class wait_timeout @@ -99,7 +120,7 @@ public: template auto start(Wait &command) -> void; - auto stop() -> void; + auto stop() -> void { timer_->cancel(); } }; template @@ -180,6 +201,19 @@ auto wait_ircmsg::start(Wait &command) -> void }); } +template +auto wait_snote::start(Wait &command) -> void +{ + snote_slot_ = command.get_connection().sig_snote.connect([this, &command](auto &match) { + const auto tag = match.get_tag(); + const auto wanted = want_tags_.empty() || std::find(want_tags_.begin(), want_tags_.end(), tag) != want_tags_.end(); + if (wanted) + { + command.template complete(match); + } + }); +} + template auto wait_timeout::start(Wait &command) -> void { @@ -224,3 +258,21 @@ auto Wait::await_resume() -> std::variant throw std::runtime_error{"connection terminated"}; } } + +/// Start the coroutine and associate it with a specific connection. +inline auto irc_coroutine::start(Connection &connection) -> void +{ + promise().connection_ = connection.shared_from_this(); + resume(); +} + +/// Returns true when this coroutine is still waiting on events +inline auto irc_coroutine::is_running() -> bool +{ + return promise().connection_ != nullptr; +} + +inline auto irc_coroutine::exception() -> std::exception_ptr +{ + return promise().exception_; +} diff --git a/main.cpp b/main.cpp index ad4d7c5..76346de 100644 --- a/main.cpp +++ b/main.cpp @@ -6,6 +6,7 @@ #include "openssl_errors.hpp" #include "registration.hpp" #include "settings.hpp" +#include "irc_coroutine.hpp" #include #include @@ -130,22 +131,26 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void }); } -static auto get_settings() -> Settings +static auto get_settings(const char *filename) -> Settings { - if (auto config_stream = std::ifstream{"config.toml"}) + if (auto config_stream = std::ifstream{filename}) { return Settings::from_stream(config_stream); } else { - BOOST_LOG_TRIVIAL(error) << "Unable to open config.toml"; + BOOST_LOG_TRIVIAL(error) << "Unable to open configuration"; std::exit(1); } } -auto main() -> int +auto main(int argc, char *argv[]) -> int { - const auto settings = get_settings(); + if (argc != 2) { + BOOST_LOG_TRIVIAL(error) << "Bad arguments"; + return 1; + } + const auto settings = get_settings(argv[1]); auto io = boost::asio::io_context{}; start(io, settings); io.run(); diff --git a/registration.cpp b/registration.cpp index f927111..68695da 100644 --- a/registration.cpp +++ b/registration.cpp @@ -81,12 +81,14 @@ auto Registration::on_cap_list(const std::unordered_mapget_connection().send_cap_req(request); } - if (do_sasl) { + if (do_sasl && settings_.sasl_mechanism == "PLAIN") { client_->start_sasl( std::make_unique( - settings_.sasl_authcid, - settings_.sasl_authzid, + settings_.sasl_authcid, + settings_.sasl_authzid, settings_.sasl_password)); + } else if (do_sasl && settings_.sasl_mechanism == "EXTERNAL") { + client_->start_sasl(std::make_unique(settings_.sasl_authzid)); } else { client_->get_connection().send_cap_end(); } diff --git a/sasl_mechanism.cpp b/sasl_mechanism.cpp index f01b003..d7ff9e4 100644 --- a/sasl_mechanism.cpp +++ b/sasl_mechanism.cpp @@ -17,3 +17,11 @@ auto SaslPlain::step(std::string_view msg) -> std::optional { return {std::move(reply)}; } } + +auto SaslExternal::step(std::string_view msg) -> std::optional { + if (complete_) { + return std::nullopt; + } else { + return {std::move(authzid_)}; + } +} diff --git a/sasl_mechanism.hpp b/sasl_mechanism.hpp index 91eac7c..7e97841 100644 --- a/sasl_mechanism.hpp +++ b/sasl_mechanism.hpp @@ -44,3 +44,27 @@ public: return complete_; } }; + +class SaslExternal final : public SaslMechanism +{ + std::string authzid_; + bool complete_; + +public: + SaslExternal(std::string authzid) + : authzid_{std::move(authzid)} + , complete_{false} + {} + + auto mechanism_name() const -> std::string override + { + return "EXTERNAL"; + } + + auto step(std::string_view msg) -> std::optional override; + + auto is_complete() const -> bool override + { + return complete_; + } +}; diff --git a/snote.cpp b/snote.cpp index 196eb47..85e6ee9 100644 --- a/snote.cpp +++ b/snote.cpp @@ -74,6 +74,12 @@ const SnotePattern static patterns[] = { {SnoteTag::SetVhostOnMarkedAccount, "^\x02([^ ]+)\x02 set vhost ([^ ]+) on the \x02MARKED\x02 account ([^ ]+).$"}, + + {SnoteTag::IsNowOper, + R"(^([^ ]+) \(([^ ]+)!([^ ]+)@([^ ]+)\) is now an operator$)"}, + + {SnoteTag::NickCollision, + R"(^Nick collision due to services forced nick change on ([^ ]+)$)"}, }; static auto setup_database() -> hs_database_t * diff --git a/snote.hpp b/snote.hpp index 15519f4..2e8ead3 100644 --- a/snote.hpp +++ b/snote.hpp @@ -29,6 +29,8 @@ enum class SnoteTag Killed, TooManyGlobalConnections, SetVhostOnMarkedAccount, + IsNowOper, + NickCollision, }; class SnoteMatch