more sasl

This commit is contained in:
Eric Mertens 2025-01-25 12:25:38 -08:00
parent 1a1deb03b7
commit 4d298b7eec
15 changed files with 269 additions and 207 deletions

View File

@ -34,6 +34,7 @@ add_executable(xbot
registration_thread.cpp registration_thread.cpp
snote.cpp snote.cpp
self_thread.cpp self_thread.cpp
sasl_mechanism.cpp
irc_coroutine.cpp irc_coroutine.cpp
) )

View File

@ -2,6 +2,8 @@
#include "linebuffer.hpp" #include "linebuffer.hpp"
#include <mybase64.hpp>
#include <boost/log/trivial.hpp> #include <boost/log/trivial.hpp>
namespace { namespace {
@ -56,7 +58,8 @@ auto Connection::connect(
{ {
auto resolver = boost::asio::ip::tcp::resolver{io}; auto resolver = boost::asio::ip::tcp::resolver{io};
auto const endpoints = co_await resolver.async_resolve(host, port, boost::asio::use_awaitable); auto const endpoints = co_await resolver.async_resolve(host, port, boost::asio::use_awaitable);
co_await boost::asio::async_connect(stream_, endpoints, boost::asio::use_awaitable); auto const endpoint = co_await boost::asio::async_connect(stream_, endpoints, boost::asio::use_awaitable);
BOOST_LOG_TRIVIAL(debug) << "CONNECTED: " << endpoint;
sig_connect(); sig_connect();
} }
@ -136,6 +139,10 @@ auto Connection::dispatch_line(char *line) -> void
BOOST_LOG_TRIVIAL(warning) << "Unrecognized command: " << msg.command << " " << msg.args.size(); BOOST_LOG_TRIVIAL(warning) << "Unrecognized command: " << msg.command << " " << msg.args.size();
break; break;
case IrcCommand::AUTHENTICATE:
on_authenticate(msg.args[0]);
break;
// Server notice generate snote events but not IRC command events // Server notice generate snote events but not IRC command events
case IrcCommand::NOTICE: case IrcCommand::NOTICE:
if (auto match = snoteCore.match(msg)) { if (auto match = snoteCore.match(msg)) {
@ -251,3 +258,56 @@ auto Connection::send_authenticate(std::string_view message) -> void
{ {
write_irc("AUTHENTICATE", message); write_irc("AUTHENTICATE", message);
} }
auto Connection::on_authenticate(const std::string_view chunk) -> void
{
if (chunk != "+"sv) {
authenticate_buffer_ += chunk;
}
if (chunk.size() != 400)
{
std::string decoded;
decoded.resize(mybase64::decoded_size(authenticate_buffer_.size()));
std::size_t len;
if (mybase64::decode(authenticate_buffer_, decoded.data(), &len))
{
decoded.resize(len);
sig_authenticate(decoded);
} else {
BOOST_LOG_TRIVIAL(debug) << "Invalid AUTHENTICATE base64"sv;
send_authenticate("*"sv); // abort SASL
}
authenticate_buffer_.clear();
}
else if (authenticate_buffer_.size() > 1024)
{
BOOST_LOG_TRIVIAL(debug) << "AUTHENTICATE buffer overflow"sv;
authenticate_buffer_.clear();
send_authenticate("*"sv); // abort SASL
}
}
auto Connection::send_authenticate_abort() -> void
{
send_authenticate("*");
}
auto Connection::send_authenticate_encoded(std::string_view body) -> void
{
std::string encoded(mybase64::encoded_size(body.size()), 0);
mybase64::encode(body, encoded.data());
for (size_t lo = 0; lo < encoded.size(); lo += 400) {
const auto hi = std::min(lo + 400, encoded.size());
const std::string_view chunk { encoded.begin() + lo, encoded.begin() + hi };
send_authenticate(chunk);
}
if (encoded.size() % 400 == 0)
{
send_authenticate("+"sv);
}
}

View File

@ -23,6 +23,9 @@ private:
// Set false when message received. // Set false when message received.
bool stalled_; bool stalled_;
// AUTHENTICATE support
std::string authenticate_buffer_;
auto write_buffers() -> void; auto write_buffers() -> void;
auto dispatch_line(char * line) -> void; auto dispatch_line(char * line) -> void;
@ -46,6 +49,7 @@ public:
boost::signals2::signal<void()> sig_disconnect; boost::signals2::signal<void()> sig_disconnect;
boost::signals2::signal<void(IrcCommand, const IrcMsg &)> sig_ircmsg; boost::signals2::signal<void(IrcCommand, const IrcMsg &)> sig_ircmsg;
boost::signals2::signal<void(SnoteMatch &)> sig_snote; boost::signals2::signal<void(SnoteMatch &)> sig_snote;
boost::signals2::signal<void(std::string_view)> sig_authenticate;
auto get_executor() -> boost::asio::any_io_executor { auto get_executor() -> boost::asio::any_io_executor {
return stream_.get_executor(); return stream_.get_executor();
@ -59,6 +63,8 @@ public:
auto close() -> void; auto close() -> void;
auto on_authenticate(std::string_view) -> void;
auto send_ping(std::string_view) -> void; auto send_ping(std::string_view) -> void;
auto send_pong(std::string_view) -> void; auto send_pong(std::string_view) -> void;
auto send_pass(std::string_view) -> void; auto send_pass(std::string_view) -> void;
@ -70,6 +76,8 @@ public:
auto send_privmsg(std::string_view, std::string_view) -> void; auto send_privmsg(std::string_view, std::string_view) -> void;
auto send_notice(std::string_view, std::string_view) -> void; auto send_notice(std::string_view, std::string_view) -> void;
auto send_authenticate(std::string_view message) -> void; auto send_authenticate(std::string_view message) -> void;
auto send_authenticate_encoded(std::string_view message) -> void;
auto send_authenticate_abort() -> void;
}; };
template <typename... Args> template <typename... Args>

View File

@ -8,7 +8,7 @@ struct RecognizedCommand {
001, IrcCommand::RPL_WELCOME, 2, 2 001, IrcCommand::RPL_WELCOME, 2, 2
002, IrcCommand::RPL_YOURHOST, 2, 2 002, IrcCommand::RPL_YOURHOST, 2, 2
003, IrcCommand::RPL_CREATED, 2, 2 003, IrcCommand::RPL_CREATED, 2, 2
004, IrcCommand::RPL_MYINFO, 5, 5 004, IrcCommand::RPL_MYINFO, 5, 6
005, IrcCommand::RPL_ISUPPORT, 2, 15 005, IrcCommand::RPL_ISUPPORT, 2, 15
008, IrcCommand::RPL_SNOMASK, 3, 3 008, IrcCommand::RPL_SNOMASK, 3, 3
010, IrcCommand::RPL_REDIR, 4, 4 010, IrcCommand::RPL_REDIR, 4, 4
@ -47,12 +47,12 @@ struct RecognizedCommand {
247, IrcCommand::RPL_STATSXLINE 247, IrcCommand::RPL_STATSXLINE
248, IrcCommand::RPL_STATSULINE 248, IrcCommand::RPL_STATSULINE
249, IrcCommand::RPL_STATSDEBUG 249, IrcCommand::RPL_STATSDEBUG
250, IrcCommand::RPL_STATSCONN 250, IrcCommand::RPL_STATSCONN, 2, 2
251, IrcCommand::RPL_LUSERCLIENT 251, IrcCommand::RPL_LUSERCLIENT, 2, 2
252, IrcCommand::RPL_LUSEROP 252, IrcCommand::RPL_LUSEROP
253, IrcCommand::RPL_LUSERUNKNOWN 253, IrcCommand::RPL_LUSERUNKNOWN
254, IrcCommand::RPL_LUSERCHANNELS 254, IrcCommand::RPL_LUSERCHANNELS
255, IrcCommand::RPL_LUSERME 255, IrcCommand::RPL_LUSERME, 2, 2
256, IrcCommand::RPL_ADMINME, 3, 3 256, IrcCommand::RPL_ADMINME, 3, 3
257, IrcCommand::RPL_ADMINLOC1, 2, 2 257, IrcCommand::RPL_ADMINLOC1, 2, 2
258, IrcCommand::RPL_ADMINLOC2, 2, 2 258, IrcCommand::RPL_ADMINLOC2, 2, 2
@ -248,15 +248,15 @@ struct RecognizedCommand {
744, IrcCommand::ERR_TOPICLOCK 744, IrcCommand::ERR_TOPICLOCK
750, IrcCommand::RPL_SCANMATCHED 750, IrcCommand::RPL_SCANMATCHED
751, IrcCommand::RPL_SCANUMODES 751, IrcCommand::RPL_SCANUMODES
900, IrcCommand::RPL_LOGGEDIN 900, IrcCommand::RPL_LOGGEDIN, 4, 4
901, IrcCommand::RPL_LOGGEDOUT 901, IrcCommand::RPL_LOGGEDOUT, 3, 3
902, IrcCommand::ERR_NICKLOCKED 902, IrcCommand::ERR_NICKLOCKED, 2, 2
903, IrcCommand::RPL_SASLSUCCESS 903, IrcCommand::RPL_SASLSUCCESS, 2, 2
904, IrcCommand::ERR_SASLFAIL 904, IrcCommand::ERR_SASLFAIL, 2, 2
905, IrcCommand::ERR_SASLTOOLONG 905, IrcCommand::ERR_SASLTOOLONG, 2, 2
906, IrcCommand::ERR_SASLABORTED 906, IrcCommand::ERR_SASLABORTED, 2, 2
907, IrcCommand::ERR_SASLALREADY 907, IrcCommand::ERR_SASLALREADY, 2, 2
908, IrcCommand::RPL_SASLMECHS 908, IrcCommand::RPL_SASLMECHS, 3, 3
ACCOUNT, IrcCommand::ACCOUNT, 1, 1 ACCOUNT, IrcCommand::ACCOUNT, 1, 1
AUTHENTICATE, IrcCommand::AUTHENTICATE, 1, 1 AUTHENTICATE, IrcCommand::AUTHENTICATE, 1, 1
AWAY, IrcCommand::AWAY, 0, 1 AWAY, IrcCommand::AWAY, 0, 1

View File

@ -17,8 +17,8 @@ auto start(boost::asio::io_context & io, Settings const& settings) -> void
{ {
auto const connection = std::make_shared<Connection>(io); auto const connection = std::make_shared<Connection>(io);
RegistrationThread::start(*connection, settings.password, settings.username, settings.realname, settings.nickname); auto const selfThread = SelfThread::start(*connection);
auto selfThread = SelfThread::start(*connection); RegistrationThread::start(*connection, settings, selfThread);
connection->sig_snote.connect([](auto &match) { connection->sig_snote.connect([](auto &match) {
std::cout << "SNOTE " << static_cast<int>(match.get_tag()) << std::endl; std::cout << "SNOTE " << static_cast<int>(match.get_tag()) << std::endl;

View File

@ -2,6 +2,7 @@
#include "connection.hpp" #include "connection.hpp"
#include "ircmsg.hpp" #include "ircmsg.hpp"
#include "sasl_mechanism.hpp"
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
@ -9,31 +10,27 @@
RegistrationThread::RegistrationThread( RegistrationThread::RegistrationThread(
Connection& connection, Connection& connection,
std::string password, const Settings &settings,
std::string username, std::shared_ptr<SelfThread> self
std::string realname,
std::string nickname
) )
: connection_{connection} : connection_{connection}
, password_{std::move(password)} , settings_{settings}
, username_{std::move(username)} , self_{std::move(self)}
, realname_{std::move(realname)}
, nickname_{std::move(nickname)}
{ {
} }
auto RegistrationThread::on_connect() -> void auto RegistrationThread::on_connect() -> void
{ {
connection_.send_cap_ls(); connection_.send_cap_ls();
connection_.send_pass(password_); connection_.send_pass(settings_.password);
connection_.send_user(username_, realname_); connection_.send_user(settings_.username, settings_.realname);
connection_.send_nick(nickname_); connection_.send_nick(settings_.nickname);
} }
auto RegistrationThread::send_req() -> void auto RegistrationThread::send_req() -> void
{ {
std::string request; std::string request;
char const* const want[] = { std::vector<char const*> want {
"account-notify", "account-notify",
"account-tag", "account-tag",
"batch", "batch",
@ -49,6 +46,10 @@ auto RegistrationThread::send_req() -> void
"solanum.chat/realhost", "solanum.chat/realhost",
}; };
if (settings_.sasl_mechanism == "PLAIN") {
want.push_back("sasl");
}
for (auto const cap : want) for (auto const cap : want)
{ {
if (caps.contains(cap)) if (caps.contains(cap))
@ -85,7 +86,21 @@ auto RegistrationThread::on_msg_cap_ack(IrcMsg const& msg) -> void
if (outstanding.empty()) if (outstanding.empty())
{ {
message_handle_.disconnect(); message_handle_.disconnect();
connection_.send_cap_end();
if (settings_.sasl_mechanism.empty()) {
connection_.send_cap_end();
} else {
self_->start_sasl(std::make_unique<SaslPlain>(settings_.sasl_authcid, settings_.sasl_authzid, settings_.sasl_password));
connection_.sig_ircmsg.connect_extended([thread = shared_from_this()](auto &slot, auto cmd, auto &msg) {
switch (cmd) {
default: break;
case IrcCommand::RPL_SASLSUCCESS:
case IrcCommand::ERR_SASLFAIL:
thread->connection_.send_cap_end();
slot.disconnect();
}
});
}
} }
} }
@ -136,13 +151,11 @@ auto RegistrationThread::on_msg_cap_ls(IrcMsg const& msg) -> void
auto RegistrationThread::start( auto RegistrationThread::start(
Connection& connection, Connection& connection,
std::string password, const Settings &settings,
std::string username, std::shared_ptr<SelfThread> self
std::string realname,
std::string nickname
) -> std::shared_ptr<RegistrationThread> ) -> std::shared_ptr<RegistrationThread>
{ {
auto const thread = std::make_shared<RegistrationThread>(connection, password, username, realname, nickname); auto const thread = std::make_shared<RegistrationThread>(connection, std::move(settings), std::move(self));
thread->listen_for_cap_ls(); thread->listen_for_cap_ls();
@ -159,7 +172,7 @@ auto RegistrationThread::listen_for_cap_ack() -> void
{ {
message_handle_ = connection_.sig_ircmsg.connect([thread = shared_from_this()](IrcCommand cmd, IrcMsg const& msg) message_handle_ = connection_.sig_ircmsg.connect([thread = shared_from_this()](IrcCommand cmd, IrcMsg const& msg)
{ {
if (IrcCommand::CAP == cmd && msg.args.size() >= 2 && "*" == msg.args[0] && "ACK" == msg.args[1]) if (IrcCommand::CAP == cmd && msg.args.size() >= 2 && "ACK" == msg.args[1])
{ {
thread->on_msg_cap_ack(msg); thread->on_msg_cap_ack(msg);
} }
@ -170,7 +183,7 @@ auto RegistrationThread::listen_for_cap_ls() -> void
{ {
message_handle_ = connection_.sig_ircmsg.connect([thread = shared_from_this()](IrcCommand cmd, IrcMsg const& msg) message_handle_ = connection_.sig_ircmsg.connect([thread = shared_from_this()](IrcCommand cmd, IrcMsg const& msg)
{ {
if (IrcCommand::CAP == cmd && msg.args.size() >= 2 && "*" == msg.args[0] && "LS" == msg.args[1]) if (IrcCommand::CAP == cmd && msg.args.size() >= 2 && "LS" == msg.args[1])
{ {
thread->on_msg_cap_ls(msg); thread->on_msg_cap_ls(msg);
} }

View File

@ -1,6 +1,8 @@
#pragma once #pragma once
#include "connection.hpp" #include "connection.hpp"
#include "settings.hpp"
#include "self_thread.hpp"
#include <memory> #include <memory>
#include <string> #include <string>
@ -10,10 +12,8 @@
class RegistrationThread : public std::enable_shared_from_this<RegistrationThread> class RegistrationThread : public std::enable_shared_from_this<RegistrationThread>
{ {
Connection& connection_; Connection& connection_;
std::string password_; const Settings &settings_;
std::string username_; std::shared_ptr<SelfThread> self_;
std::string realname_;
std::string nickname_;
std::unordered_map<std::string, std::string> caps; std::unordered_map<std::string, std::string> caps;
std::unordered_set<std::string> outstanding; std::unordered_set<std::string> outstanding;
@ -32,17 +32,13 @@ class RegistrationThread : public std::enable_shared_from_this<RegistrationThrea
public: public:
RegistrationThread( RegistrationThread(
Connection& connection_, Connection& connection_,
std::string password, const Settings &,
std::string username, std::shared_ptr<SelfThread> self
std::string realname,
std::string nickname
); );
static auto start( static auto start(
Connection& connection, Connection& connection,
std::string password, const Settings &,
std::string username, std::shared_ptr<SelfThread> self
std::string realname,
std::string nickname
) -> std::shared_ptr<RegistrationThread>; ) -> std::shared_ptr<RegistrationThread>;
}; };

19
sasl_mechanism.cpp Normal file
View File

@ -0,0 +1,19 @@
#include "sasl_mechanism.hpp"
auto SaslPlain::step(std::string_view msg) -> std::optional<std::string> {
if (complete_) {
return std::nullopt;
} else {
std::string reply;
reply += authzid_;
reply += '\0';
reply += authcid_;
reply += '\0';
reply += password_;
complete_ = true;
return {std::move(reply)};
}
}

50
sasl_mechanism.hpp Normal file
View File

@ -0,0 +1,50 @@
#pragma once
#include <boost/signals2.hpp>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include "event.hpp"
struct Connection;
class SaslMechanism
{
public:
virtual ~SaslMechanism() {}
virtual auto mechanism_name() const -> std::string = 0;
virtual auto step(std::string_view msg) -> std::optional<std::string> = 0;
virtual auto is_complete() const -> bool = 0;
};
class SaslPlain final : public SaslMechanism
{
std::string authcid_;
std::string authzid_;
std::string password_;
bool complete_;
public:
SaslPlain(std::string authcid, std::string authzid, std::string password)
: authcid_{std::move(authcid)}
, authzid_{std::move(authzid)}
, password_{std::move(password)}
, complete_{false}
{}
auto mechanism_name() const -> std::string override
{
return "PLAIN";
}
auto step(std::string_view msg) -> std::optional<std::string> override;
auto is_complete() const -> bool override
{
return complete_;
}
};

View File

@ -1,55 +0,0 @@
#include "sasl_thread.hpp"
#include <mybase64.hpp>
#include <boost/log/trivial.hpp>
#include "connection.hpp"
#include "write_irc.hpp"
#include "irc_parse_thread.hpp"
#include "ircmsg.hpp"
auto SaslThread::start(Connection& connection) -> std::shared_ptr<SaslThread>
{
auto thread = std::make_shared<SaslThread>(connection);
connection.add_listener<IrcMsgEvent>([thread](IrcMsgEvent const& event){
if (event.command == IrcCommand::AUTHENTICATE)
{
thread->on_authenticate(event.irc.args[0]);
}
});
return thread;
}
auto SaslThread::on_authenticate(std::string_view chunk) -> void
{
if (chunk != "+") {
buffer_ += chunk;
}
if (chunk.size() != 400)
{
std::string decoded;
decoded.resize(mybase64::decoded_size(buffer_.size()));
std::size_t len;
if (mybase64::decode(buffer_, decoded.data(), &len))
{
decoded.resize(len);
connection_.make_event<SaslMessage>(std::move(decoded));
} else {
BOOST_LOG_TRIVIAL(debug) << "Invalid AUTHENTICATE base64";
send_authenticate(connection_, "*"); // abort SASL
}
buffer_.clear();
}
else if (buffer_.size() > MAX_BUFFER)
{
BOOST_LOG_TRIVIAL(debug) << "AUTHENTICATE buffer overflow";
buffer_.clear();
send_authenticate(connection_, "*"); // abort SASL
}
}

View File

@ -1,89 +0,0 @@
#pragma once
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include "event.hpp"
struct Connection;
class SaslMechanism
{
public:
virtual ~SaslMechanism() {}
virtual auto mechanism_name() const -> std::string = 0;
virtual auto step(std::string_view msg) -> std::optional<std::string> = 0;
virtual auto is_complete() const -> bool = 0;
};
class SaslPlain final : public SaslMechanism
{
std::string authcid_;
std::string authzid_;
std::string password_;
bool complete_;
public:
SaslPlain(std::string username, std::string password)
: authcid_{std::move(username)}
, password_{std::move(password)}
, complete_{false}
{}
auto mechanism_name() const -> std::string override
{
return "PLAIN";
}
auto step(std::string_view msg) -> std::optional<std::string> override {
if (complete_) {
return {};
} else {
std::string reply;
reply += authzid_;
reply += '\0';
reply += authcid_;
reply += '\0';
reply += password_;
complete_ = true;
return {std::move(reply)};
}
}
auto is_complete() const -> bool override
{
return complete_;
}
};
struct SaslComplete : Event
{
bool success;
};
struct SaslMessage : Event
{
SaslMessage(std::string message) : message{std::move(message)} {}
std::string message;
};
class SaslThread
{
Connection& connection_;
std::string buffer_;
const std::size_t MAX_BUFFER = 1024;
public:
SaslThread(Connection& connection) : connection_{connection} {}
static auto start(Connection& connection) -> std::shared_ptr<SaslThread>;
auto on_authenticate(std::string_view) -> void;
};

View File

@ -2,7 +2,12 @@
#include "connection.hpp" #include "connection.hpp"
#include <mybase64.hpp>
#include <boost/container/flat_map.hpp> #include <boost/container/flat_map.hpp>
#include <boost/log/trivial.hpp>
using namespace std::literals;
auto SelfThread::on_welcome(IrcMsg const& irc) -> void auto SelfThread::on_welcome(IrcMsg const& irc) -> void
{ {
@ -86,12 +91,14 @@ auto SelfThread::on_isupport(const IrcMsg &msg) -> void
for (int i = 1; i < hi; ++i) for (int i = 1; i < hi; ++i)
{ {
auto &entry = msg.args[i]; auto &entry = msg.args[i];
// Leading minus means to stop support
if (entry.starts_with("-")) { if (entry.starts_with("-")) {
auto key = std::string{entry.substr(1)}; auto const key = std::string{entry.substr(1)};
if (auto cursor = isupport_.find(key); cursor != isupport_.end()) { if (auto cursor = isupport_.find(key); cursor != isupport_.end()) {
isupport_.erase(cursor); isupport_.erase(cursor);
} }
} else if (auto cursor = entry.find('='); cursor != entry.npos) { } else if (auto const cursor = entry.find('='); cursor != entry.npos) {
isupport_.emplace(entry.substr(0, cursor), entry.substr(cursor+1)); isupport_.emplace(entry.substr(0, cursor), entry.substr(cursor+1));
} else { } else {
isupport_.emplace(entry, std::string{}); isupport_.emplace(entry, std::string{});
@ -107,18 +114,22 @@ auto SelfThread::start(Connection& connection) -> std::shared_ptr<SelfThread>
{ {
switch (cmd) switch (cmd)
{ {
case IrcCommand::RPL_WELCOME: thread->on_welcome(msg); break;
case IrcCommand::RPL_ISUPPORT: thread->on_isupport(msg); break;
case IrcCommand::RPL_UMODEIS: thread->on_umodeis(msg); break;
case IrcCommand::NICK: thread->on_nick(msg); break;
case IrcCommand::JOIN: thread->on_join(msg); break; case IrcCommand::JOIN: thread->on_join(msg); break;
case IrcCommand::KICK: thread->on_kick(msg); break; case IrcCommand::KICK: thread->on_kick(msg); break;
case IrcCommand::PART: thread->on_part(msg); break;
case IrcCommand::MODE: thread->on_mode(msg); break; case IrcCommand::MODE: thread->on_mode(msg); break;
case IrcCommand::NICK: thread->on_nick(msg); break;
case IrcCommand::PART: thread->on_part(msg); break;
case IrcCommand::RPL_ISUPPORT: thread->on_isupport(msg); break;
case IrcCommand::RPL_UMODEIS: thread->on_umodeis(msg); break;
case IrcCommand::RPL_WELCOME: thread->on_welcome(msg); break;
default: break; default: break;
} }
}); });
connection.sig_authenticate.connect([thread](auto msg) {
thread->on_authenticate(msg);
});
return thread; return thread;
} }
@ -147,3 +158,34 @@ auto SelfThread::is_my_mask(std::string_view mask) const -> bool
auto const bang = mask.find('!'); auto const bang = mask.find('!');
return bang != std::string_view::npos && nickname_ == mask.substr(0, bang); return bang != std::string_view::npos && nickname_ == mask.substr(0, bang);
} }
auto SelfThread::on_authenticate(const std::string_view body) -> void
{
if (not sasl_mechanism_)
{
BOOST_LOG_TRIVIAL(warning) << "Unexpected AUTHENTICATE from server"sv;
connection_.send_authenticate_abort();
return;
}
if (auto reply = sasl_mechanism_->step(body)) {
connection_.send_authenticate_encoded(*reply);
// Clean up completed SASL transactions
if (sasl_mechanism_->is_complete())
{
sasl_mechanism_.reset();
}
}
}
auto SelfThread::start_sasl(std::unique_ptr<SaslMechanism> mechanism) -> void
{
if (sasl_mechanism_) {
connection_.send_authenticate("*"sv); // abort SASL
}
sasl_mechanism_ = std::move(mechanism);
connection_.send_authenticate(sasl_mechanism_->mechanism_name());
}

View File

@ -1,8 +1,7 @@
#pragma once #pragma once
#include "connection.hpp" #include "connection.hpp"
#include "sasl_mechanism.hpp"
#include <boost/container/flat_map.hpp>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
@ -17,11 +16,16 @@ struct IrcMsg;
class SelfThread class SelfThread
{ {
Connection& connection_; Connection& connection_;
std::string nickname_; std::string nickname_;
std::string mode_; std::string mode_;
std::unordered_set<std::string> channels_; std::unordered_set<std::string> channels_;
// RPL_ISUPPORT state
std::unordered_map<std::string, std::string> isupport_; std::unordered_map<std::string, std::string> isupport_;
std::unique_ptr<SaslMechanism> sasl_mechanism_;
auto on_welcome(IrcMsg const& irc) -> void; auto on_welcome(IrcMsg const& irc) -> void;
auto on_isupport(IrcMsg const& irc) -> void; auto on_isupport(IrcMsg const& irc) -> void;
auto on_nick(IrcMsg const& irc) -> void; auto on_nick(IrcMsg const& irc) -> void;
@ -30,15 +34,19 @@ class SelfThread
auto on_kick(IrcMsg const& irc) -> void; auto on_kick(IrcMsg const& irc) -> void;
auto on_part(IrcMsg const& irc) -> void; auto on_part(IrcMsg const& irc) -> void;
auto on_mode(IrcMsg const& irc) -> void; auto on_mode(IrcMsg const& irc) -> void;
auto on_authenticate(std::string_view) -> void;
public: public:
SelfThread(Connection& connection) : connection_{connection} {} SelfThread(Connection& connection) : connection_{connection} {}
static auto start(Connection&) -> std::shared_ptr<SelfThread>; static auto start(Connection&) -> std::shared_ptr<SelfThread>;
auto start_sasl(std::unique_ptr<SaslMechanism> mechanism) -> void;
auto get_my_nickname() const -> std::string const&; auto get_my_nickname() const -> std::string const&;
auto get_my_mode() const -> std::string const&; auto get_my_mode() const -> std::string const&;
auto get_my_channels() const -> std::unordered_set<std::string> const&; auto get_my_channels() const -> std::unordered_set<std::string> const&;
auto is_my_nick(std::string_view nick) const -> bool; auto is_my_nick(std::string_view nick) const -> bool;
auto is_my_mask(std::string_view nick) const -> bool; auto is_my_mask(std::string_view nick) const -> bool;
}; };

View File

@ -7,11 +7,15 @@ auto Settings::from_stream(std::istream & in) -> Settings
{ {
auto const config = toml::parse(in); auto const config = toml::parse(in);
return Settings{ return Settings{
.host = config["host"].value_or(std::string{"*"}), .host = config["host"].value_or(std::string{}),
.service = config["service"].value_or(std::string{"*"}), .service = config["service"].value_or(std::string{}),
.password = config["password"].value_or(std::string{"*"}), .password = config["password"].value_or(std::string{}),
.username = config["username"].value_or(std::string{"*"}), .username = config["username"].value_or(std::string{}),
.realname = config["realname"].value_or(std::string{"*"}), .realname = config["realname"].value_or(std::string{}),
.nickname = config["nickname"].value_or(std::string{"*"}) .nickname = config["nickname"].value_or(std::string{}),
.sasl_mechanism = config["sasl_mechanism"].value_or(std::string{}),
.sasl_authcid = config["sasl_authcid"].value_or(std::string{}),
.sasl_authzid = config["sasl_authzid"].value_or(std::string{}),
.sasl_password = config["sasl_password"].value_or(std::string{})
}; };
} }

View File

@ -12,6 +12,11 @@ struct Settings
std::string realname; std::string realname;
std::string nickname; std::string nickname;
std::string sasl_authcid;
std::string sasl_authzid;
std::string sasl_password;
std::string sasl_mechanism;
static auto from_stream(std::istream & in) -> Settings; static auto from_stream(std::istream & in) -> Settings;
}; };