copy over snowcone's stream

This commit is contained in:
Eric Mertens 2025-01-26 14:38:13 -08:00
parent 8be5332692
commit ebe884e9d5
10 changed files with 860 additions and 60 deletions

View File

@ -6,10 +6,11 @@ project(xbot
)
find_package(PkgConfig REQUIRED)
find_package(OpenSSL REQUIRED)
pkg_check_modules(LIBHS libhs REQUIRED IMPORTED_TARGET)
set(BOOST_INCLUDE_LIBRARIES asio log signals2)
set(BOOST_INCLUDE_LIBRARIES asio log signals2 endian)
set(BOOST_ENABLE_CMAKE ON)
include(FetchContent)
FetchContent_Declare(
@ -36,6 +37,7 @@ add_custom_command(
VERBATIM)
add_subdirectory(mybase64)
add_subdirectory(mysocks5)
add_executable(xbot
main.cpp
@ -52,4 +54,9 @@ add_executable(xbot
)
target_include_directories(xbot PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(xbot PRIVATE Boost::signals2 Boost::log Boost::asio tomlplusplus_tomlplusplus PkgConfig::LIBHS mybase64)
target_link_libraries(xbot PRIVATE
OpenSSL::SSL
Boost::signals2 Boost::log Boost::asio
tomlplusplus_tomlplusplus
PkgConfig::LIBHS
mysocks5 mybase64)

View File

@ -8,6 +8,9 @@
namespace {
#include "irc_commands.inc"
using tcp_type = boost::asio::ip::tcp::socket;
using tls_type = boost::asio::ssl::stream<tcp_type>;
} // namespace
using namespace std::literals;
@ -48,50 +51,6 @@ auto Connection::write_buffers() -> void
write_strings_.clear();
}
auto Connection::connect(
boost::asio::io_context &io,
std::string host,
std::string port
) -> boost::asio::awaitable<void>
{
using namespace std::placeholders;
// keep connection alive while coroutine is active
const auto self = shared_from_this();
{
auto resolver = boost::asio::ip::tcp::resolver{io};
const auto endpoints = co_await resolver.async_resolve(host, port, boost::asio::use_awaitable);
const auto endpoint = co_await boost::asio::async_connect(stream_, endpoints, boost::asio::use_awaitable);
BOOST_LOG_TRIVIAL(debug) << "CONNECTED: " << endpoint;
sig_connect();
}
watchdog();
for (LineBuffer buffer{32'768};;)
{
boost::system::error_code error;
const auto n = co_await stream_.async_read_some(buffer.get_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, error));
if (error)
{
break;
}
buffer.add_bytes(n, [this](char *line) {
BOOST_LOG_TRIVIAL(debug) << "RECV: " << line;
watchdog_activity();
dispatch_line(line);
});
}
watchdog_timer_.cancel();
stream_.close();
BOOST_LOG_TRIVIAL(debug) << "DISCONNECTED";
sig_disconnect();
}
auto Connection::watchdog() -> void
{
watchdog_timer_.expires_after(watchdog_duration);
@ -326,3 +285,177 @@ auto Connection::send_authenticate_encoded(std::string_view body) -> void
send_authenticate("+"sv);
}
}
static
auto set_buffer_size(tls_type& stream, std::size_t const n) -> void
{
auto const ssl = stream.native_handle();
BIO_set_buffer_size(SSL_get_rbio(ssl), n);
BIO_set_buffer_size(SSL_get_wbio(ssl), n);
}
static
auto set_buffer_size(tcp_type& socket, std::size_t const n) -> void
{
socket.set_option(tcp_type::send_buffer_size{static_cast<int>(n)});
socket.set_option(tcp_type::receive_buffer_size{static_cast<int>(n)});
}
static
auto set_cloexec(int const fd) -> void
{
auto const flags = fcntl(fd, F_GETFD);
if (-1 == flags)
{
throw std::system_error{errno, std::generic_category(), "failed to get file descriptor flags"};
}
if (-1 == fcntl(fd, F_SETFD, flags | FD_CLOEXEC))
{
throw std::system_error{errno, std::generic_category(), "failed to set file descriptor flags"};
}
}
template <std::size_t... Ns>
static
auto constexpr sum() -> std::size_t { return (0 + ... + Ns); }
/**
* @brief Build's the string format required for the ALPN extension
*
* @tparam Ns sizes of each protocol name
* @param protocols array of the names of the supported protocols
* @return encoded protocol names
*/
template <std::size_t... Ns>
static
auto constexpr alpn_encode(char const (&... protocols)[Ns]) -> std::array<unsigned char, sum<Ns...>()>
{
auto result = std::array<unsigned char, sum<Ns...>()>{};
auto cursor = std::begin(result);
auto const encode = [&cursor]<std::size_t N>(char const(&protocol)[N]) {
static_assert(N > 0, "Protocol name must be null-terminated");
static_assert(N < 256, "Protocol name too long");
if (protocol[N - 1] != '\0')
throw "Protocol name not null-terminated";
// Prefixed length byte
*cursor++ = N - 1;
// Add string skipping null terminator
cursor = std::copy(std::begin(protocol), std::end(protocol) - 1, cursor);
};
(encode(protocols), ...);
return result;
}
/**
* @brief Configure the TLS stream to request the IRC protocol.
*
* @param stream TLS stream
*/
static
auto set_alpn(tls_type& stream) -> void
{
auto constexpr protos = alpn_encode("irc");
SSL_set_alpn_protos(stream.native_handle(), protos.data(), protos.size());
}
static
auto build_ssl_context(
X509* client_cert,
EVP_PKEY* client_key
) -> boost::asio::ssl::context
{
boost::asio::ssl::context ssl_context{boost::asio::ssl::context::method::tls_client};
ssl_context.set_default_verify_paths();
if (nullptr != client_cert)
{
if (1 != SSL_CTX_use_certificate(ssl_context.native_handle(), client_cert))
{
throw std::runtime_error{"certificate file"};
}
}
if (nullptr != client_key)
{
if (1 != SSL_CTX_use_PrivateKey(ssl_context.native_handle(), client_key))
{
throw std::runtime_error{"private key"};
}
}
return ssl_context;
}
auto Connection::connect(
boost::asio::io_context &io,
ConnectSettings settings
) -> boost::asio::awaitable<void>
{
using namespace std::placeholders;
// keep connection alive while coroutine is active
const auto self = shared_from_this();
const size_t irc_buffer_size = 32'768;
auto& socket = stream_.reset();
{
auto resolver = boost::asio::ip::tcp::resolver{io};
const auto endpoints = co_await resolver.async_resolve(settings.host, std::to_string(settings.port), boost::asio::use_awaitable);
const auto endpoint = co_await boost::asio::async_connect(socket, endpoints, boost::asio::use_awaitable);
BOOST_LOG_TRIVIAL(debug) << "CONNECTED: " << endpoint;
socket.set_option(boost::asio::ip::tcp::no_delay(true));
set_buffer_size(socket, irc_buffer_size);
set_cloexec(socket.native_handle());
}
if (settings.tls)
{
auto cxt = build_ssl_context(settings.client_cert.get(), settings.client_key.get());
// Upgrade stream_ to use TLS and invalidate socket
auto& stream = stream_.upgrade(cxt);
set_buffer_size(stream, irc_buffer_size);
set_alpn(stream);
if (not settings.verify.empty())
{
stream.set_verify_mode(boost::asio::ssl::verify_peer);
stream.set_verify_callback(boost::asio::ssl::host_name_verification(settings.verify));
}
if (not settings.sni.empty())
{
SSL_set_tlsext_host_name(stream.native_handle(), settings.sni.c_str());
}
co_await stream.async_handshake(stream.client, boost::asio::use_awaitable);
}
sig_connect();
watchdog();
for (LineBuffer buffer{irc_buffer_size};;)
{
boost::system::error_code error;
const auto n = co_await stream_.async_read_some(buffer.get_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, error));
if (error)
{
break;
}
buffer.add_bytes(n, [this](char *line) {
BOOST_LOG_TRIVIAL(debug) << "RECV: " << line;
watchdog_activity();
dispatch_line(line);
});
}
watchdog_timer_.cancel();
stream_.close();
BOOST_LOG_TRIVIAL(debug) << "DISCONNECTED";
sig_disconnect();
}

View File

@ -2,7 +2,9 @@
#include "irc_command.hpp"
#include "ircmsg.hpp"
#include "settings.hpp"
#include "snote.hpp"
#include "stream.hpp"
#include <boost/asio.hpp>
#include <boost/signals2.hpp>
@ -11,10 +13,36 @@
#include <memory>
#include <string>
template <typename T, int(*UpRef)(T*), void(*Free)(T*)>
class Ref {
struct Deleter { auto operator()(auto ptr) { Free(ptr); }};
std::unique_ptr<T, Deleter> obj;
public:
Ref() = default;
Ref(T* t) : obj{t} { if (t) UpRef(t); }
auto get() const -> T* { return obj.get(); }
};
struct ConnectSettings
{
bool tls;
std::string host;
std::uint16_t port;
Ref<X509, X509_up_ref, X509_free> client_cert;
Ref<EVP_PKEY, EVP_PKEY_up_ref, EVP_PKEY_free> client_key;
std::string verify;
std::string sni;
std::string socks_host;
std::uint16_t socks_port;
std::string socks_user;
std::string socks_pass;
};
class Connection : public std::enable_shared_from_this<Connection>
{
private:
boost::asio::ip::tcp::socket stream_;
Stream stream_;
boost::asio::steady_timer watchdog_timer_;
std::list<std::string> write_strings_;
bool write_posted_;
@ -58,8 +86,7 @@ public:
auto connect(
boost::asio::io_context &io,
std::string host,
std::string port
ConnectSettings settings
) -> boost::asio::awaitable<void>;
auto close() -> void;

View File

@ -2,15 +2,16 @@
#include "settings.hpp"
#include <boost/asio.hpp>
#include <boost/log/trivial.hpp>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "bot.hpp"
#include "registration_thread.hpp"
#include "self_thread.hpp"
#include "bot.hpp"
using namespace std::chrono_literals;
@ -31,16 +32,29 @@ auto start(boost::asio::io_context &io, const Settings &settings) -> void
});
boost::asio::co_spawn(
io, connection->connect(io, settings.host, settings.service),
[&io, &settings](std::exception_ptr e) {
io, connection->connect(io, ConnectSettings{
.tls = settings.use_tls,
.host = settings.host,
.port = settings.service,
.verify = settings.tls_hostname,
}),
[&io, &settings, connection](std::exception_ptr e) {
try
{
if (e)
std::rethrow_exception(e);
}
catch (const std::exception &e)
{
BOOST_LOG_TRIVIAL(debug) << "TERMINATED: " << e.what();
}
auto timer = std::make_shared<boost::asio::steady_timer>(io);
timer->expires_after(5s);
timer->async_wait(
[&io, &settings, timer](auto) { start(io, settings); }
);
}
);
[&io, &settings, timer](auto) { start(io, settings); });
});
}
auto get_settings() -> Settings
@ -51,7 +65,7 @@ auto get_settings() -> Settings
}
else
{
std::cerr << "Unable to open config.toml\n";
BOOST_LOG_TRIVIAL(error) << "Unable to open config.toml";
std::exit(1);
}
}

3
mysocks5/CMakeLists.txt Normal file
View File

@ -0,0 +1,3 @@
add_library(mysocks5 STATIC socks5.cpp)
target_include_directories(mysocks5 PUBLIC include)
target_link_libraries(mysocks5 PUBLIC Boost::asio Boost::endian)

404
mysocks5/include/socks5.hpp Normal file
View File

@ -0,0 +1,404 @@
#pragma once
#include <boost/asio.hpp>
#include <boost/endian.hpp>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <string_view>
#include <variant>
namespace socks5 {
struct SocksErrCategory : boost::system::error_category
{
char const* name() const noexcept override;
std::string message(int) const override;
};
extern SocksErrCategory const theSocksErrCategory;
enum class SocksErrc
{
// Errors from the server
Succeeded = 0,
GeneralFailure = 1,
NotAllowed = 2,
NetworkUnreachable = 3,
HostUnreachable = 4,
ConnectionRefused = 5,
TtlExpired = 6,
CommandNotSupported = 7,
AddressNotSupported = 8,
// Errors from the client
WrongVersion = 256,
NoAcceptableMethods,
AuthenticationFailed,
UnsupportedEndpointAddress,
DomainTooLong,
UsernameTooLong,
PasswordTooLong,
};
/// Either a hostname or an address. Hostnames are resolved locally on the proxy server
using Host = std::variant<std::string_view, boost::asio::ip::address>;
struct NoCredential
{
};
struct UsernamePasswordCredential
{
std::string_view username;
std::string_view password;
};
using Auth = std::variant<NoCredential, UsernamePasswordCredential>;
namespace detail {
template <class... Ts>
struct overloaded : Ts...
{
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
auto make_socks_error(SocksErrc const err) -> boost::system::error_code;
inline auto push_buffer(std::vector<std::uint8_t>& buffer, auto const& thing) -> void
{
buffer.push_back(thing.size());
buffer.insert(buffer.end(), thing.begin(), thing.end());
}
uint8_t const socks_version_tag = 5;
uint8_t const auth_version_tag = 1;
enum class AuthMethod
{
NoAuth = 0,
Gssapi = 1,
UsernamePassword = 2,
NoAcceptableMethods = 255,
};
enum class Command
{
Connect = 1,
Bind = 2,
UdpAssociate = 3,
};
enum class AddressType
{
IPv4 = 1,
DomainName = 3,
IPv6 = 4,
};
/// @brief Encode the given host into the end of the buffer
/// @param host host to encode
/// @param buffer target to push bytes onto
/// @return true for success and false for failure
auto push_host(Host const& host, std::vector<uint8_t>& buffer) -> void;
template <typename AsyncStream>
struct SocksImplementation
{
AsyncStream& socket_;
Host const host_;
boost::endian::big_uint16_t const port_;
Auth const auth_;
/// buffer used to back async read/write operations
std::vector<uint8_t> buffer_;
// Representations of states in the protocol
struct Start
{
};
struct HelloRecvd
{
static const std::size_t READ = 2; // version, method
};
struct AuthRecvd
{
static const std::size_t READ = 2; // subversion, status
};
struct ReplyRecvd
{
static const std::size_t READ = 4; // version, reply, reserved, address-tag
};
struct FinishIpv4
{
static const std::size_t READ = 6; // ipv4 + port = 6 bytes
};
struct FinishIpv6
{
static const std::size_t READ = 18; // ipv6 + port = 18 bytes
};
/// @brief State when the application needs to receive some bytes
/// @tparam Next State to transistion to after read is successful
template <typename Next>
struct Sent
{
};
/// @brief intermediate completion callback
/// @tparam Self type of enclosing intermediate completion handler
/// @tparam State protocol state tag type
/// @param self enclosing intermediate completion handler
/// @param state protocol state tag value
/// @param error error code of read/write operation
/// @param size bytes read or written
/// @param
template <typename Self, typename State = Start>
auto operator()(
Self& self,
State state = {},
boost::system::error_code const error = {},
std::size_t = 0
) -> void
{
if (error)
{
self.complete(error, {});
}
else
{
step(self, state);
}
}
/// @brief Write the buffer to the socket and then read N bytes back into the buffer
/// @tparam Next state to resume after read
/// @tparam N number of bytes to read
/// @tparam Self type of enclosing intermediate completion handler
/// @param self enclosing intermediate completion handler
template <typename Next, typename Self>
auto transact(Self& self) -> void
{
boost::asio::async_write(
socket_,
boost::asio::buffer(buffer_),
[self = std::move(self)](boost::system::error_code const err, std::size_t const n) mutable {
self(Sent<Next>{}, err, n);
}
);
}
/// @brief Notify the caller of a failure and terminate the protocol
/// @param self intermediate completion handler
/// @param err error code to return to the caller
auto failure(auto& self, SocksErrc const err) -> void
{
self.complete(make_socks_error(err), {});
}
/// @brief Read bytes needed by Next state from the socket and then proceed to Next state
/// @tparam Self type of enclosing intermediate completion handler
/// @tparam Next state to transition to after read
/// @param self enclosing intermediate completion handler
/// @param state protocol state tag
template <typename Self, typename Next>
auto step(Self& self, Sent<Next>) -> void
{
buffer_.resize(Next::READ);
boost::asio::async_read(
socket_,
boost::asio::buffer(buffer_),
[self = std::move(self)](boost::system::error_code const err, std::size_t n) mutable {
self(Next{}, err, n);
}
);
}
// Send hello and offer authentication methods
template <typename Self>
auto step(Self& self, Start) -> void
{
if (auto const* const host = std::get_if<std::string_view>(&host_))
{
if (host->size() >= 256)
{
return failure(self, SocksErrc::DomainTooLong);
}
}
if (auto const* const plain = std::get_if<UsernamePasswordCredential>(&auth_))
{
if (plain->username.size() >= 256)
{
return failure(self, SocksErrc::UsernameTooLong);
}
if (plain->password.size() >= 256)
{
return failure(self, SocksErrc::PasswordTooLong);
}
}
buffer_ = {socks_version_tag, 1 /* number of methods */, static_cast<uint8_t>(method_wanted())};
transact<HelloRecvd>(self);
}
// Send TCP connection request for the domain name and port
template <typename Self>
auto step(Self& self, HelloRecvd) -> void
{
if (socks_version_tag != buffer_[0])
{
return failure(self, SocksErrc::WrongVersion);
}
auto const wanted = method_wanted();
auto const selected = static_cast<AuthMethod>(buffer_[1]);
if (AuthMethod::NoAuth == wanted && wanted == selected)
{
send_connect(self);
}
else if (AuthMethod::UsernamePassword == wanted && wanted == selected)
{
send_usernamepassword(self);
}
else
{
failure(self, SocksErrc::NoAcceptableMethods);
}
}
/// @brief Transmit the username and password to the server
/// @tparam Self type of enclosing intermediate completion handler
/// @param self enclosing intermediate completion handler
template <typename Self>
auto send_usernamepassword(Self& self) -> void
{
buffer_ = {
auth_version_tag,
};
auto const [username, password] = std::get<1>(auth_);
push_buffer(buffer_, username);
push_buffer(buffer_, password);
transact<AuthRecvd>(self);
}
template <typename Self>
auto step(Self& self, AuthRecvd) -> void
{
if (auth_version_tag != buffer_[0])
{
return failure(self, SocksErrc::WrongVersion);
}
// STATUS zero is success, non-zero is failure
if (0 != buffer_[1])
{
return failure(self, SocksErrc::AuthenticationFailed);
}
send_connect(self);
}
template <typename Self>
auto send_connect(Self& self) -> void
{
buffer_ = {
socks_version_tag,
static_cast<uint8_t>(Command::Connect),
0 /* reserved */,
};
push_host(host_, buffer_);
buffer_.insert(buffer_.end(), port_.data(), port_.data() + 2);
transact<ReplyRecvd>(self);
}
// Waiting on the remaining variable-sized address portion of the response
template <typename Self>
auto step(Self& self, ReplyRecvd) -> void
{
if (socks_version_tag != buffer_[0])
{
return failure(self, SocksErrc::WrongVersion);
}
auto const reply = static_cast<SocksErrc>(buffer_[1]);
if (SocksErrc::Succeeded != reply)
{
return failure(self, reply);
}
switch (static_cast<AddressType>(buffer_[3]))
{
case AddressType::IPv4:
return step(self, Sent<FinishIpv4>{});
case AddressType::IPv6:
return step(self, Sent<FinishIpv6>{});
default:
return failure(self, SocksErrc::UnsupportedEndpointAddress);
}
}
// Protocol complete! Return the client's remote endpoint
template <typename Self>
void step(Self& self, FinishIpv4)
{
boost::asio::ip::address_v4::bytes_type bytes;
boost::endian::big_uint16_t port;
std::memcpy(bytes.data(), &buffer_[0], 4);
std::memcpy(port.data(), &buffer_[4], 2);
self.complete({}, {boost::asio::ip::make_address_v4(bytes), port});
}
// Protocol complete! Return the client's remote endpoint
template <typename Self>
void step(Self& self, FinishIpv6)
{
boost::asio::ip::address_v6::bytes_type bytes;
boost::endian::big_uint16_t port;
std::memcpy(bytes.data(), &buffer_[0], 16);
std::memcpy(port.data(), &buffer_[16], 2);
self.complete({}, {boost::asio::ip::make_address_v6(bytes), port});
}
auto method_wanted() const -> AuthMethod
{
return std::visit(
overloaded{
[](NoCredential) { return AuthMethod::NoAuth; },
[](UsernamePasswordCredential) { return AuthMethod::UsernamePassword; },
},
auth_
);
}
};
} // namespace detail
using Signature = void(boost::system::error_code, boost::asio::ip::tcp::endpoint);
/// @brief Asynchronous SOCKS5 connection request
/// @tparam AsyncStream Type of socket
/// @tparam CompletionToken Token accepting: error_code, address, port
/// @param socket Established connection to SOCKS5 server
/// @param host Connection target host
/// @param port Connection target port
/// @param token Completion token
/// @return Behavior determined by completion token type
template <
typename AsyncStream,
boost::asio::completion_token_for<Signature> CompletionToken>
auto async_connect(
AsyncStream& socket,
Host const host,
uint16_t const port,
Auth const auth,
CompletionToken&& token
)
{
return boost::asio::async_compose<CompletionToken, Signature>(detail::SocksImplementation<AsyncStream>{socket, host, port, auth, {}}, token, socket);
}
} // namespace socks5

93
mysocks5/socks5.cpp Normal file
View File

@ -0,0 +1,93 @@
#include "socks5.hpp"
#include <boost/asio.hpp>
#include <iterator>
#include <stdexcept>
#include <vector>
namespace socks5 {
SocksErrCategory const theSocksErrCategory;
char const* SocksErrCategory::name() const noexcept
{
return "socks5";
}
std::string SocksErrCategory::message(int ev) const
{
switch (static_cast<SocksErrc>(ev))
{
case SocksErrc::Succeeded:
return "succeeded";
case SocksErrc::GeneralFailure:
return "general SOCKS server failure";
case SocksErrc::NotAllowed:
return "connection not allowed by ruleset";
case SocksErrc::NetworkUnreachable:
return "network unreachable";
case SocksErrc::HostUnreachable:
return "host unreachable";
case SocksErrc::ConnectionRefused:
return "connection refused";
case SocksErrc::TtlExpired:
return "TTL expired";
case SocksErrc::CommandNotSupported:
return "command not supported";
case SocksErrc::AddressNotSupported:
return "address type not supported";
case SocksErrc::WrongVersion:
return "bad server protocol version";
case SocksErrc::NoAcceptableMethods:
return "server rejected authentication methods";
case SocksErrc::AuthenticationFailed:
return "server rejected authentication";
case SocksErrc::UnsupportedEndpointAddress:
return "server sent unknown endpoint address";
case SocksErrc::DomainTooLong:
return "domain name too long";
case SocksErrc::UsernameTooLong:
return "username too long";
case SocksErrc::PasswordTooLong:
return "password too long";
default:
return "(unrecognized error)";
}
}
namespace detail {
auto make_socks_error(SocksErrc const err) -> boost::system::error_code
{
return boost::system::error_code{int(err), theSocksErrCategory};
}
auto push_host(Host const& host, std::vector<uint8_t>& buffer) -> void
{
std::visit(overloaded{[&buffer](std::string_view const hostname) {
buffer.push_back(uint8_t(AddressType::DomainName));
push_buffer(buffer, hostname);
},
[&buffer](boost::asio::ip::address const& address) {
if (address.is_v4())
{
buffer.push_back(uint8_t(AddressType::IPv4));
push_buffer(buffer, address.to_v4().to_bytes());
}
else if (address.is_v6())
{
buffer.push_back(uint8_t(AddressType::IPv6));
push_buffer(buffer, address.to_v6().to_bytes());
}
else
{
throw std::logic_error{"unexpected address type"};
}
}},
host);
}
} // namespace detail
} // namespace socks5

View File

@ -8,7 +8,7 @@ auto Settings::from_stream(std::istream &in) -> Settings
const auto config = toml::parse(in);
return Settings{
.host = config["host"].value_or(std::string{}),
.service = config["service"].value_or(std::string{}),
.service = config["port"].value_or(std::uint16_t{6667}),
.password = config["password"].value_or(std::string{}),
.username = config["username"].value_or(std::string{}),
.realname = config["realname"].value_or(std::string{}),
@ -16,6 +16,8 @@ auto Settings::from_stream(std::istream &in) -> Settings
.sasl_mechanism = config["sasl_mechanism"].value_or(std::string{}),
.sasl_authcid = config["sasl_authcid"].value_or(std::string{}),
.sasl_authzid = config["sasl_authzid"].value_or(std::string{}),
.sasl_password = config["sasl_password"].value_or(std::string{})
.sasl_password = config["sasl_password"].value_or(std::string{}),
.tls_hostname = config["tls_hostname"].value_or(std::string{}),
.use_tls = config["use_tls"].value_or(false),
};
}

View File

@ -6,7 +6,7 @@
struct Settings
{
std::string host;
std::string service;
std::uint16_t service;
std::string password;
std::string username;
std::string realname;
@ -17,5 +17,8 @@ struct Settings
std::string sasl_authzid;
std::string sasl_password;
std::string tls_hostname;
bool use_tls;
static auto from_stream(std::istream &in) -> Settings;
};

114
stream.hpp Normal file
View File

@ -0,0 +1,114 @@
#pragma once
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include <cstddef>
#include <variant>
/// @brief Abstraction over plain-text and TLS streams.
class Stream : private
std::variant<
boost::asio::ip::tcp::socket,
boost::asio::ssl::stream<boost::asio::ip::tcp::socket>>
{
public:
using tcp_socket = boost::asio::ip::tcp::socket;
using tls_stream = boost::asio::ssl::stream<tcp_socket>;
/// @brief The type of the executor associated with the stream.
using executor_type = boost::asio::any_io_executor;
/// @brief Type of the lowest layer of this stream
using lowest_layer_type = tcp_socket::lowest_layer_type;
private:
using base_type = std::variant<tcp_socket, tls_stream>;
auto base() -> base_type& { return *this; }
auto base() const -> base_type const& { return *this; }
public:
/// @brief Initialize stream with a plain TCP socket
/// @param ioc IO context of stream
template <typename T>
Stream(T&& executor) : base_type{std::in_place_type<tcp_socket>, std::forward<T>(executor)} {}
/// @brief Reset stream to a plain TCP socket
/// @return Reference to internal socket object
auto reset() -> tcp_socket&
{
return base().emplace<tcp_socket>(get_executor());
}
/// @brief Upgrade a plain TCP socket into a TLS stream.
/// @param ctx TLS context used for handshake
/// @return Reference to internal stream object
auto upgrade(boost::asio::ssl::context& ctx) -> tls_stream&
{
auto socket = std::move(std::get<tcp_socket>(base()));
return base().emplace<tls_stream>(std::move(socket), ctx);
}
/// @brief Get underlying basic socket
/// @return Reference to underlying socket
auto lowest_layer() -> lowest_layer_type&
{
return std::visit([](auto&& x) -> decltype(auto) { return x.lowest_layer(); }, base());
}
/// @brief Get underlying basic socket
/// @return Reference to underlying socket
auto lowest_layer() const -> lowest_layer_type const&
{
return std::visit([](auto&& x) -> decltype(auto) { return x.lowest_layer(); }, base());
}
/// @brief Get the executor associated with this stream.
/// @return The executor associated with the stream.
auto get_executor() -> executor_type const&
{
return lowest_layer().get_executor();
}
/// @brief Initiates an asynchronous read operation.
/// @tparam MutableBufferSequence Type of the buffer sequence.
/// @tparam Token Type of the completion token.
/// @param buffers The buffer sequence into which data will be read.
/// @param token The completion token for the read operation.
/// @return The result determined by the completion token.
template <
typename MutableBufferSequence,
boost::asio::completion_token_for<void(boost::system::error_code, std::size_t)> Token>
auto async_read_some(MutableBufferSequence&& buffers, Token&& token) -> decltype(auto)
{
return std::visit([&buffers, &token](auto&& x) -> decltype(auto) {
return x.async_read_some(std::forward<MutableBufferSequence>(buffers), std::forward<Token>(token));
}, base());
}
/// @brief Initiates an asynchronous write operation.
/// @tparam ConstBufferSequence Type of the buffer sequence.
/// @tparam Token Type of the completion token.
/// @param buffers The buffer sequence from which data will be written.
/// @param token The completion token for the write operation.
/// @return The result determined by the completion token.
template <
typename ConstBufferSequence,
boost::asio::completion_token_for<void(boost::system::error_code, std::size_t)> Token>
auto async_write_some(ConstBufferSequence&& buffers, Token&& token) -> decltype(auto)
{
return std::visit([&buffers, &token](auto&& x) -> decltype(auto) {
return x.async_write_some(std::forward<ConstBufferSequence>(buffers), std::forward<Token>(token));
}, base());
}
/// @brief Tear down the network stream
auto close() -> void
{
boost::system::error_code err;
auto& socket = lowest_layer();
socket.shutdown(socket.shutdown_both, err);
socket.lowest_layer().close(err);
}
};