diff --git a/driver/main.cpp b/driver/main.cpp index 569dad8..bf5c0b7 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -5,6 +5,7 @@ #include "openssl_utils.hpp" #include "registration.hpp" #include "settings.hpp" +#include "ref.hpp" #include "irc_coroutine.hpp" #include @@ -36,6 +37,9 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void const auto connection = std::make_shared(io); const auto client = Client::start(*connection); + Ref sasl_key; + if (not settings.sasl_key_file.empty()) + sasl_key = key_from_file(settings.sasl_key_file, settings.sasl_key_password); Registration::start({ .nickname = settings.nickname, .realname = settings.realname, @@ -45,6 +49,7 @@ static auto start(boost::asio::io_context &io, const Settings &settings) -> void .sasl_authcid = settings.sasl_authcid, .sasl_authzid = settings.sasl_authzid, .sasl_password = settings.sasl_password, + .sasl_key = std::move(sasl_key), }, client); const auto bot = Bot::start(client); diff --git a/driver/settings.cpp b/driver/settings.cpp index 226e958..45914c5 100644 --- a/driver/settings.cpp +++ b/driver/settings.cpp @@ -17,6 +17,8 @@ auto Settings::from_stream(std::istream &in) -> Settings .sasl_authcid = config["sasl_authcid"].value_or(std::string{}), .sasl_authzid = config["sasl_authzid"].value_or(std::string{}), .sasl_password = config["sasl_password"].value_or(std::string{}), + .sasl_key_file = config["sasl_key_file"].value_or(std::string{}), + .sasl_key_password = config["sasl_key_password"].value_or(std::string{}), .tls_hostname = config["tls_hostname"].value_or(std::string{}), .tls_certfile = config["tls_certfile"].value_or(std::string{}), .tls_keyfile = config["tls_keyfile"].value_or(std::string{}), diff --git a/driver/settings.hpp b/driver/settings.hpp index cfeaa77..a021b59 100644 --- a/driver/settings.hpp +++ b/driver/settings.hpp @@ -16,6 +16,8 @@ struct Settings std::string sasl_authcid; std::string sasl_authzid; std::string sasl_password; + std::string sasl_key_file; + std::string sasl_key_password; std::string tls_hostname; std::string tls_certfile; diff --git a/myirc/include/ref.hpp b/myirc/include/ref.hpp index a4518e6..6128ce3 100644 --- a/myirc/include/ref.hpp +++ b/myirc/include/ref.hpp @@ -41,8 +41,10 @@ struct Ref : std::unique_ptr> explicit Ref(T *x) noexcept : base{x} {} Ref(Ref &&ref) noexcept = default; - Ref(const Ref &ref) noexcept { - *this = ref; + Ref(const Ref &ref) noexcept : base{ref.get()} { + if (*this) { + RefTraits::UpRef(this->get()); + } } Ref &operator=(Ref&&) noexcept = default; diff --git a/myirc/include/sasl_mechanism.hpp b/myirc/include/sasl_mechanism.hpp index ca693a3..2a66183 100644 --- a/myirc/include/sasl_mechanism.hpp +++ b/myirc/include/sasl_mechanism.hpp @@ -1,5 +1,7 @@ #pragma once +#include "ref.hpp" + #include #include @@ -74,3 +76,35 @@ public: return complete_; } }; + + +class SaslEcdsa final : public SaslMechanism +{ + std::string message1_; + Ref key_; + int stage_; + +public: + SaslEcdsa(std::string authcid, std::string authzid, Ref key) + : message1_{std::move(authcid)} + , key_{std::move(key)} + , stage_{0} + { + if (not authzid.empty()) { + message1_.push_back(0); + message1_.append(authzid); + } + } + + auto mechanism_name() const -> std::string override + { + return "ECDSA-NIST256P-CHALLENGE"; + } + + auto step(std::string_view msg) -> StepResult override; + + auto is_complete() const -> bool override + { + return stage_ == 2;; + } +}; diff --git a/myirc/registration.cpp b/myirc/registration.cpp index d116201..e2130d3 100644 --- a/myirc/registration.cpp +++ b/myirc/registration.cpp @@ -3,6 +3,7 @@ #include "connection.hpp" #include "ircmsg.hpp" #include "sasl_mechanism.hpp" +#include "openssl_utils.hpp" #include #include @@ -89,6 +90,11 @@ auto Registration::on_cap_list(const std::unordered_mapstart_sasl(std::make_unique(settings_.sasl_authzid)); + } else if (do_sasl && settings_.sasl_mechanism == "ECDSA") { + client_->start_sasl(std::make_unique( + settings_.sasl_authcid, + settings_.sasl_authzid, + settings_.sasl_key)); } else { client_->get_connection().send_cap_end(); } diff --git a/myirc/sasl_mechanism.cpp b/myirc/sasl_mechanism.cpp index 65b6429..71b65e9 100644 --- a/myirc/sasl_mechanism.cpp +++ b/myirc/sasl_mechanism.cpp @@ -1,4 +1,7 @@ #include "sasl_mechanism.hpp" +#include "openssl_utils.hpp" + +#include auto SaslPlain::step(std::string_view msg) -> StepResult { if (complete_) { @@ -13,7 +16,6 @@ auto SaslPlain::step(std::string_view msg) -> StepResult { reply += password_; complete_ = true; - return std::move(reply); } } @@ -22,6 +24,51 @@ auto SaslExternal::step(std::string_view msg) -> StepResult { if (complete_) { return Failure{}; } else { + complete_ = true; return std::move(authzid_); } } + +auto SaslEcdsa::step(std::string_view msg) -> StepResult { + switch (stage_) { + case 0: + stage_ = 1; + return std::move(message1_); + case 1: + { + stage_ = 2; + Ref ctx {EVP_PKEY_CTX_new(key_.get(), nullptr)}; + if (not ctx) { + log_openssl_errors("ECDSA new context: "); + return Failure{}; + } + + if (0 >= EVP_PKEY_sign_init(ctx.get())) + { + log_openssl_errors("ECDSA init: "); + return Failure{}; + } + + const auto input = reinterpret_cast(msg.data()); + size_t siglen; + if (0 >= EVP_PKEY_sign(ctx.get(), nullptr, &siglen, input, msg.size())) + { + log_openssl_errors("ECDSA signature (presize): "); + return Failure{}; + } + + std::string result(siglen, '\0'); + const auto output = reinterpret_cast(result.data()); + if (0 >= EVP_PKEY_sign(ctx.get(), output, &siglen, input, msg.size())) + { + log_openssl_errors("ECDSA signature: "); + return Failure{}; + } + result.resize(siglen); + + return std::move(result); + } + default: + return Failure{}; + } +}