diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fe509f..fb8d07a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,10 +6,11 @@ project(xbot ) find_package(PkgConfig REQUIRED) +find_package(OpenSSL REQUIRED) pkg_check_modules(LIBHS libhs REQUIRED IMPORTED_TARGET) -set(BOOST_INCLUDE_LIBRARIES asio log signals2) +set(BOOST_INCLUDE_LIBRARIES asio log signals2 endian) set(BOOST_ENABLE_CMAKE ON) include(FetchContent) FetchContent_Declare( @@ -36,6 +37,7 @@ add_custom_command( VERBATIM) add_subdirectory(mybase64) +add_subdirectory(mysocks5) add_executable(xbot main.cpp @@ -52,4 +54,9 @@ add_executable(xbot ) target_include_directories(xbot PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -target_link_libraries(xbot PRIVATE Boost::signals2 Boost::log Boost::asio tomlplusplus_tomlplusplus PkgConfig::LIBHS mybase64) +target_link_libraries(xbot PRIVATE + OpenSSL::SSL + Boost::signals2 Boost::log Boost::asio + tomlplusplus_tomlplusplus + PkgConfig::LIBHS + mysocks5 mybase64) diff --git a/connection.cpp b/connection.cpp index 451f364..f5b16bd 100644 --- a/connection.cpp +++ b/connection.cpp @@ -8,6 +8,9 @@ namespace { #include "irc_commands.inc" + +using tcp_type = boost::asio::ip::tcp::socket; +using tls_type = boost::asio::ssl::stream; } // namespace using namespace std::literals; @@ -48,50 +51,6 @@ auto Connection::write_buffers() -> void write_strings_.clear(); } -auto Connection::connect( - boost::asio::io_context &io, - std::string host, - std::string port -) -> boost::asio::awaitable -{ - using namespace std::placeholders; - - // keep connection alive while coroutine is active - const auto self = shared_from_this(); - - { - auto resolver = boost::asio::ip::tcp::resolver{io}; - const auto endpoints = co_await resolver.async_resolve(host, port, boost::asio::use_awaitable); - const auto endpoint = co_await boost::asio::async_connect(stream_, endpoints, boost::asio::use_awaitable); - - BOOST_LOG_TRIVIAL(debug) << "CONNECTED: " << endpoint; - sig_connect(); - } - - watchdog(); - - for (LineBuffer buffer{32'768};;) - { - boost::system::error_code error; - const auto n = co_await stream_.async_read_some(buffer.get_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, error)); - if (error) - { - break; - } - buffer.add_bytes(n, [this](char *line) { - BOOST_LOG_TRIVIAL(debug) << "RECV: " << line; - watchdog_activity(); - dispatch_line(line); - }); - } - - watchdog_timer_.cancel(); - stream_.close(); - - BOOST_LOG_TRIVIAL(debug) << "DISCONNECTED"; - sig_disconnect(); -} - auto Connection::watchdog() -> void { watchdog_timer_.expires_after(watchdog_duration); @@ -326,3 +285,177 @@ auto Connection::send_authenticate_encoded(std::string_view body) -> void send_authenticate("+"sv); } } + +static +auto set_buffer_size(tls_type& stream, std::size_t const n) -> void +{ + auto const ssl = stream.native_handle(); + BIO_set_buffer_size(SSL_get_rbio(ssl), n); + BIO_set_buffer_size(SSL_get_wbio(ssl), n); +} + +static +auto set_buffer_size(tcp_type& socket, std::size_t const n) -> void +{ + socket.set_option(tcp_type::send_buffer_size{static_cast(n)}); + socket.set_option(tcp_type::receive_buffer_size{static_cast(n)}); +} + +static +auto set_cloexec(int const fd) -> void +{ + auto const flags = fcntl(fd, F_GETFD); + if (-1 == flags) + { + throw std::system_error{errno, std::generic_category(), "failed to get file descriptor flags"}; + } + if (-1 == fcntl(fd, F_SETFD, flags | FD_CLOEXEC)) + { + throw std::system_error{errno, std::generic_category(), "failed to set file descriptor flags"}; + } +} + +template +static +auto constexpr sum() -> std::size_t { return (0 + ... + Ns); } + +/** + * @brief Build's the string format required for the ALPN extension + * + * @tparam Ns sizes of each protocol name + * @param protocols array of the names of the supported protocols + * @return encoded protocol names + */ +template +static +auto constexpr alpn_encode(char const (&... protocols)[Ns]) -> std::array()> +{ + auto result = std::array()>{}; + auto cursor = std::begin(result); + auto const encode = [&cursor](char const(&protocol)[N]) { + static_assert(N > 0, "Protocol name must be null-terminated"); + static_assert(N < 256, "Protocol name too long"); + if (protocol[N - 1] != '\0') + throw "Protocol name not null-terminated"; + + // Prefixed length byte + *cursor++ = N - 1; + + // Add string skipping null terminator + cursor = std::copy(std::begin(protocol), std::end(protocol) - 1, cursor); + }; + (encode(protocols), ...); + return result; +} + +/** + * @brief Configure the TLS stream to request the IRC protocol. + * + * @param stream TLS stream + */ +static +auto set_alpn(tls_type& stream) -> void +{ + auto constexpr protos = alpn_encode("irc"); + SSL_set_alpn_protos(stream.native_handle(), protos.data(), protos.size()); +} + +static +auto build_ssl_context( + X509* client_cert, + EVP_PKEY* client_key +) -> boost::asio::ssl::context +{ + boost::asio::ssl::context ssl_context{boost::asio::ssl::context::method::tls_client}; + ssl_context.set_default_verify_paths(); + + if (nullptr != client_cert) + { + if (1 != SSL_CTX_use_certificate(ssl_context.native_handle(), client_cert)) + { + throw std::runtime_error{"certificate file"}; + } + } + if (nullptr != client_key) + { + if (1 != SSL_CTX_use_PrivateKey(ssl_context.native_handle(), client_key)) + { + throw std::runtime_error{"private key"}; + } + } + return ssl_context; +} + +auto Connection::connect( + boost::asio::io_context &io, + ConnectSettings settings +) -> boost::asio::awaitable +{ + using namespace std::placeholders; + + // keep connection alive while coroutine is active + const auto self = shared_from_this(); + const size_t irc_buffer_size = 32'768; + + auto& socket = stream_.reset(); + + { + auto resolver = boost::asio::ip::tcp::resolver{io}; + const auto endpoints = co_await resolver.async_resolve(settings.host, std::to_string(settings.port), boost::asio::use_awaitable); + const auto endpoint = co_await boost::asio::async_connect(socket, endpoints, boost::asio::use_awaitable); + BOOST_LOG_TRIVIAL(debug) << "CONNECTED: " << endpoint; + + socket.set_option(boost::asio::ip::tcp::no_delay(true)); + set_buffer_size(socket, irc_buffer_size); + set_cloexec(socket.native_handle()); + } + + if (settings.tls) + { + auto cxt = build_ssl_context(settings.client_cert.get(), settings.client_key.get()); + + // Upgrade stream_ to use TLS and invalidate socket + auto& stream = stream_.upgrade(cxt); + + set_buffer_size(stream, irc_buffer_size); + set_alpn(stream); + + if (not settings.verify.empty()) + { + stream.set_verify_mode(boost::asio::ssl::verify_peer); + stream.set_verify_callback(boost::asio::ssl::host_name_verification(settings.verify)); + } + + if (not settings.sni.empty()) + { + SSL_set_tlsext_host_name(stream.native_handle(), settings.sni.c_str()); + } + + co_await stream.async_handshake(stream.client, boost::asio::use_awaitable); + } + + sig_connect(); + + watchdog(); + + for (LineBuffer buffer{irc_buffer_size};;) + { + boost::system::error_code error; + const auto n = co_await stream_.async_read_some(buffer.get_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, error)); + if (error) + { + break; + } + buffer.add_bytes(n, [this](char *line) { + BOOST_LOG_TRIVIAL(debug) << "RECV: " << line; + watchdog_activity(); + dispatch_line(line); + }); + } + + watchdog_timer_.cancel(); + stream_.close(); + + BOOST_LOG_TRIVIAL(debug) << "DISCONNECTED"; + sig_disconnect(); +} diff --git a/connection.hpp b/connection.hpp index c992a02..1ce2b8e 100644 --- a/connection.hpp +++ b/connection.hpp @@ -2,7 +2,9 @@ #include "irc_command.hpp" #include "ircmsg.hpp" +#include "settings.hpp" #include "snote.hpp" +#include "stream.hpp" #include #include @@ -11,10 +13,36 @@ #include #include +template +class Ref { + struct Deleter { auto operator()(auto ptr) { Free(ptr); }}; + std::unique_ptr obj; +public: + Ref() = default; + Ref(T* t) : obj{t} { if (t) UpRef(t); } + auto get() const -> T* { return obj.get(); } +}; + +struct ConnectSettings +{ + bool tls; + std::string host; + std::uint16_t port; + + Ref client_cert; + Ref client_key; + std::string verify; + std::string sni; + + std::string socks_host; + std::uint16_t socks_port; + std::string socks_user; + std::string socks_pass; +}; class Connection : public std::enable_shared_from_this { private: - boost::asio::ip::tcp::socket stream_; + Stream stream_; boost::asio::steady_timer watchdog_timer_; std::list write_strings_; bool write_posted_; @@ -58,8 +86,7 @@ public: auto connect( boost::asio::io_context &io, - std::string host, - std::string port + ConnectSettings settings ) -> boost::asio::awaitable; auto close() -> void; diff --git a/main.cpp b/main.cpp index c3db272..09d4e96 100644 --- a/main.cpp +++ b/main.cpp @@ -2,15 +2,16 @@ #include "settings.hpp" #include +#include #include #include #include #include +#include "bot.hpp" #include "registration_thread.hpp" #include "self_thread.hpp" -#include "bot.hpp" using namespace std::chrono_literals; @@ -31,16 +32,29 @@ auto start(boost::asio::io_context &io, const Settings &settings) -> void }); boost::asio::co_spawn( - io, connection->connect(io, settings.host, settings.service), - [&io, &settings](std::exception_ptr e) { + io, connection->connect(io, ConnectSettings{ + .tls = settings.use_tls, + .host = settings.host, + .port = settings.service, + .verify = settings.tls_hostname, + }), + [&io, &settings, connection](std::exception_ptr e) { + try + { + if (e) + std::rethrow_exception(e); + } + catch (const std::exception &e) + { + BOOST_LOG_TRIVIAL(debug) << "TERMINATED: " << e.what(); + } + auto timer = std::make_shared(io); timer->expires_after(5s); timer->async_wait( - [&io, &settings, timer](auto) { start(io, settings); } - ); - } - ); + [&io, &settings, timer](auto) { start(io, settings); }); + }); } auto get_settings() -> Settings @@ -51,7 +65,7 @@ auto get_settings() -> Settings } else { - std::cerr << "Unable to open config.toml\n"; + BOOST_LOG_TRIVIAL(error) << "Unable to open config.toml"; std::exit(1); } } diff --git a/mysocks5/CMakeLists.txt b/mysocks5/CMakeLists.txt new file mode 100644 index 0000000..da1d303 --- /dev/null +++ b/mysocks5/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(mysocks5 STATIC socks5.cpp) +target_include_directories(mysocks5 PUBLIC include) +target_link_libraries(mysocks5 PUBLIC Boost::asio Boost::endian) diff --git a/mysocks5/include/socks5.hpp b/mysocks5/include/socks5.hpp new file mode 100644 index 0000000..f43bc67 --- /dev/null +++ b/mysocks5/include/socks5.hpp @@ -0,0 +1,404 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace socks5 { + +struct SocksErrCategory : boost::system::error_category +{ + char const* name() const noexcept override; + std::string message(int) const override; +}; + +extern SocksErrCategory const theSocksErrCategory; + +enum class SocksErrc +{ + // Errors from the server + Succeeded = 0, + GeneralFailure = 1, + NotAllowed = 2, + NetworkUnreachable = 3, + HostUnreachable = 4, + ConnectionRefused = 5, + TtlExpired = 6, + CommandNotSupported = 7, + AddressNotSupported = 8, + // Errors from the client + WrongVersion = 256, + NoAcceptableMethods, + AuthenticationFailed, + UnsupportedEndpointAddress, + DomainTooLong, + UsernameTooLong, + PasswordTooLong, +}; + +/// Either a hostname or an address. Hostnames are resolved locally on the proxy server +using Host = std::variant; + +struct NoCredential +{ +}; +struct UsernamePasswordCredential +{ + std::string_view username; + std::string_view password; +}; + +using Auth = std::variant; + +namespace detail { + + template + struct overloaded : Ts... + { + using Ts::operator()...; + }; + template + overloaded(Ts...) -> overloaded; + + auto make_socks_error(SocksErrc const err) -> boost::system::error_code; + + inline auto push_buffer(std::vector& buffer, auto const& thing) -> void + { + buffer.push_back(thing.size()); + buffer.insert(buffer.end(), thing.begin(), thing.end()); + } + + uint8_t const socks_version_tag = 5; + uint8_t const auth_version_tag = 1; + + enum class AuthMethod + { + NoAuth = 0, + Gssapi = 1, + UsernamePassword = 2, + NoAcceptableMethods = 255, + }; + + enum class Command + { + Connect = 1, + Bind = 2, + UdpAssociate = 3, + }; + + enum class AddressType + { + IPv4 = 1, + DomainName = 3, + IPv6 = 4, + }; + + /// @brief Encode the given host into the end of the buffer + /// @param host host to encode + /// @param buffer target to push bytes onto + /// @return true for success and false for failure + auto push_host(Host const& host, std::vector& buffer) -> void; + + template + struct SocksImplementation + { + AsyncStream& socket_; + Host const host_; + boost::endian::big_uint16_t const port_; + Auth const auth_; + + /// buffer used to back async read/write operations + std::vector buffer_; + + // Representations of states in the protocol + struct Start + { + }; + struct HelloRecvd + { + static const std::size_t READ = 2; // version, method + }; + struct AuthRecvd + { + static const std::size_t READ = 2; // subversion, status + }; + struct ReplyRecvd + { + static const std::size_t READ = 4; // version, reply, reserved, address-tag + }; + struct FinishIpv4 + { + static const std::size_t READ = 6; // ipv4 + port = 6 bytes + }; + struct FinishIpv6 + { + static const std::size_t READ = 18; // ipv6 + port = 18 bytes + }; + + /// @brief State when the application needs to receive some bytes + /// @tparam Next State to transistion to after read is successful + template + struct Sent + { + }; + + /// @brief intermediate completion callback + /// @tparam Self type of enclosing intermediate completion handler + /// @tparam State protocol state tag type + /// @param self enclosing intermediate completion handler + /// @param state protocol state tag value + /// @param error error code of read/write operation + /// @param size bytes read or written + /// @param + template + auto operator()( + Self& self, + State state = {}, + boost::system::error_code const error = {}, + std::size_t = 0 + ) -> void + { + if (error) + { + self.complete(error, {}); + } + else + { + step(self, state); + } + } + + /// @brief Write the buffer to the socket and then read N bytes back into the buffer + /// @tparam Next state to resume after read + /// @tparam N number of bytes to read + /// @tparam Self type of enclosing intermediate completion handler + /// @param self enclosing intermediate completion handler + template + auto transact(Self& self) -> void + { + boost::asio::async_write( + socket_, + boost::asio::buffer(buffer_), + [self = std::move(self)](boost::system::error_code const err, std::size_t const n) mutable { + self(Sent{}, err, n); + } + ); + } + + /// @brief Notify the caller of a failure and terminate the protocol + /// @param self intermediate completion handler + /// @param err error code to return to the caller + auto failure(auto& self, SocksErrc const err) -> void + { + self.complete(make_socks_error(err), {}); + } + + /// @brief Read bytes needed by Next state from the socket and then proceed to Next state + /// @tparam Self type of enclosing intermediate completion handler + /// @tparam Next state to transition to after read + /// @param self enclosing intermediate completion handler + /// @param state protocol state tag + template + auto step(Self& self, Sent) -> void + { + buffer_.resize(Next::READ); + boost::asio::async_read( + socket_, + boost::asio::buffer(buffer_), + [self = std::move(self)](boost::system::error_code const err, std::size_t n) mutable { + self(Next{}, err, n); + } + ); + } + + // Send hello and offer authentication methods + template + auto step(Self& self, Start) -> void + { + if (auto const* const host = std::get_if(&host_)) + { + if (host->size() >= 256) + { + return failure(self, SocksErrc::DomainTooLong); + } + } + if (auto const* const plain = std::get_if(&auth_)) + { + if (plain->username.size() >= 256) + { + return failure(self, SocksErrc::UsernameTooLong); + } + if (plain->password.size() >= 256) + { + return failure(self, SocksErrc::PasswordTooLong); + } + } + + buffer_ = {socks_version_tag, 1 /* number of methods */, static_cast(method_wanted())}; + transact(self); + } + + // Send TCP connection request for the domain name and port + template + auto step(Self& self, HelloRecvd) -> void + { + if (socks_version_tag != buffer_[0]) + { + return failure(self, SocksErrc::WrongVersion); + } + + auto const wanted = method_wanted(); + auto const selected = static_cast(buffer_[1]); + + if (AuthMethod::NoAuth == wanted && wanted == selected) + { + send_connect(self); + } + else if (AuthMethod::UsernamePassword == wanted && wanted == selected) + { + send_usernamepassword(self); + } + else + { + failure(self, SocksErrc::NoAcceptableMethods); + } + } + + /// @brief Transmit the username and password to the server + /// @tparam Self type of enclosing intermediate completion handler + /// @param self enclosing intermediate completion handler + template + auto send_usernamepassword(Self& self) -> void + { + buffer_ = { + auth_version_tag, + }; + auto const [username, password] = std::get<1>(auth_); + push_buffer(buffer_, username); + push_buffer(buffer_, password); + transact(self); + } + + template + auto step(Self& self, AuthRecvd) -> void + { + if (auth_version_tag != buffer_[0]) + { + return failure(self, SocksErrc::WrongVersion); + } + + // STATUS zero is success, non-zero is failure + if (0 != buffer_[1]) + { + return failure(self, SocksErrc::AuthenticationFailed); + } + + send_connect(self); + } + + template + auto send_connect(Self& self) -> void + { + buffer_ = { + socks_version_tag, + static_cast(Command::Connect), + 0 /* reserved */, + }; + push_host(host_, buffer_); + buffer_.insert(buffer_.end(), port_.data(), port_.data() + 2); + + transact(self); + } + + // Waiting on the remaining variable-sized address portion of the response + template + auto step(Self& self, ReplyRecvd) -> void + { + if (socks_version_tag != buffer_[0]) + { + return failure(self, SocksErrc::WrongVersion); + } + auto const reply = static_cast(buffer_[1]); + if (SocksErrc::Succeeded != reply) + { + return failure(self, reply); + } + + switch (static_cast(buffer_[3])) + { + case AddressType::IPv4: + return step(self, Sent{}); + case AddressType::IPv6: + return step(self, Sent{}); + default: + return failure(self, SocksErrc::UnsupportedEndpointAddress); + } + } + + // Protocol complete! Return the client's remote endpoint + template + void step(Self& self, FinishIpv4) + { + boost::asio::ip::address_v4::bytes_type bytes; + boost::endian::big_uint16_t port; + std::memcpy(bytes.data(), &buffer_[0], 4); + std::memcpy(port.data(), &buffer_[4], 2); + self.complete({}, {boost::asio::ip::make_address_v4(bytes), port}); + } + + // Protocol complete! Return the client's remote endpoint + template + void step(Self& self, FinishIpv6) + { + boost::asio::ip::address_v6::bytes_type bytes; + boost::endian::big_uint16_t port; + std::memcpy(bytes.data(), &buffer_[0], 16); + std::memcpy(port.data(), &buffer_[16], 2); + self.complete({}, {boost::asio::ip::make_address_v6(bytes), port}); + } + + auto method_wanted() const -> AuthMethod + { + return std::visit( + overloaded{ + [](NoCredential) { return AuthMethod::NoAuth; }, + [](UsernamePasswordCredential) { return AuthMethod::UsernamePassword; }, + }, + auth_ + ); + } + }; + +} // namespace detail + +using Signature = void(boost::system::error_code, boost::asio::ip::tcp::endpoint); + +/// @brief Asynchronous SOCKS5 connection request +/// @tparam AsyncStream Type of socket +/// @tparam CompletionToken Token accepting: error_code, address, port +/// @param socket Established connection to SOCKS5 server +/// @param host Connection target host +/// @param port Connection target port +/// @param token Completion token +/// @return Behavior determined by completion token type +template < + typename AsyncStream, + boost::asio::completion_token_for CompletionToken> +auto async_connect( + AsyncStream& socket, + Host const host, + uint16_t const port, + Auth const auth, + CompletionToken&& token +) +{ + return boost::asio::async_compose(detail::SocksImplementation{socket, host, port, auth, {}}, token, socket); +} + +} // namespace socks5 diff --git a/mysocks5/socks5.cpp b/mysocks5/socks5.cpp new file mode 100644 index 0000000..03e2d37 --- /dev/null +++ b/mysocks5/socks5.cpp @@ -0,0 +1,93 @@ +#include "socks5.hpp" + +#include + +#include +#include +#include + +namespace socks5 { + +SocksErrCategory const theSocksErrCategory; + +char const* SocksErrCategory::name() const noexcept +{ + return "socks5"; +} + +std::string SocksErrCategory::message(int ev) const +{ + switch (static_cast(ev)) + { + case SocksErrc::Succeeded: + return "succeeded"; + case SocksErrc::GeneralFailure: + return "general SOCKS server failure"; + case SocksErrc::NotAllowed: + return "connection not allowed by ruleset"; + case SocksErrc::NetworkUnreachable: + return "network unreachable"; + case SocksErrc::HostUnreachable: + return "host unreachable"; + case SocksErrc::ConnectionRefused: + return "connection refused"; + case SocksErrc::TtlExpired: + return "TTL expired"; + case SocksErrc::CommandNotSupported: + return "command not supported"; + case SocksErrc::AddressNotSupported: + return "address type not supported"; + case SocksErrc::WrongVersion: + return "bad server protocol version"; + case SocksErrc::NoAcceptableMethods: + return "server rejected authentication methods"; + case SocksErrc::AuthenticationFailed: + return "server rejected authentication"; + case SocksErrc::UnsupportedEndpointAddress: + return "server sent unknown endpoint address"; + case SocksErrc::DomainTooLong: + return "domain name too long"; + case SocksErrc::UsernameTooLong: + return "username too long"; + case SocksErrc::PasswordTooLong: + return "password too long"; + default: + return "(unrecognized error)"; + } +} + +namespace detail { + + auto make_socks_error(SocksErrc const err) -> boost::system::error_code + { + return boost::system::error_code{int(err), theSocksErrCategory}; + } + + auto push_host(Host const& host, std::vector& buffer) -> void + { + std::visit(overloaded{[&buffer](std::string_view const hostname) { + buffer.push_back(uint8_t(AddressType::DomainName)); + push_buffer(buffer, hostname); + }, + [&buffer](boost::asio::ip::address const& address) { + if (address.is_v4()) + { + buffer.push_back(uint8_t(AddressType::IPv4)); + push_buffer(buffer, address.to_v4().to_bytes()); + } + else if (address.is_v6()) + { + buffer.push_back(uint8_t(AddressType::IPv6)); + push_buffer(buffer, address.to_v6().to_bytes()); + } + else + { + throw std::logic_error{"unexpected address type"}; + } + }}, + host); + } + +} // namespace detail + +} // namespace socks5 diff --git a/settings.cpp b/settings.cpp index 47e6377..d4ee951 100644 --- a/settings.cpp +++ b/settings.cpp @@ -8,7 +8,7 @@ auto Settings::from_stream(std::istream &in) -> Settings const auto config = toml::parse(in); return Settings{ .host = config["host"].value_or(std::string{}), - .service = config["service"].value_or(std::string{}), + .service = config["port"].value_or(std::uint16_t{6667}), .password = config["password"].value_or(std::string{}), .username = config["username"].value_or(std::string{}), .realname = config["realname"].value_or(std::string{}), @@ -16,6 +16,8 @@ auto Settings::from_stream(std::istream &in) -> Settings .sasl_mechanism = config["sasl_mechanism"].value_or(std::string{}), .sasl_authcid = config["sasl_authcid"].value_or(std::string{}), .sasl_authzid = config["sasl_authzid"].value_or(std::string{}), - .sasl_password = config["sasl_password"].value_or(std::string{}) + .sasl_password = config["sasl_password"].value_or(std::string{}), + .tls_hostname = config["tls_hostname"].value_or(std::string{}), + .use_tls = config["use_tls"].value_or(false), }; } diff --git a/settings.hpp b/settings.hpp index 11159b0..f716920 100644 --- a/settings.hpp +++ b/settings.hpp @@ -6,7 +6,7 @@ struct Settings { std::string host; - std::string service; + std::uint16_t service; std::string password; std::string username; std::string realname; @@ -17,5 +17,8 @@ struct Settings std::string sasl_authzid; std::string sasl_password; + std::string tls_hostname; + bool use_tls; + static auto from_stream(std::istream &in) -> Settings; }; diff --git a/stream.hpp b/stream.hpp new file mode 100644 index 0000000..ea5c317 --- /dev/null +++ b/stream.hpp @@ -0,0 +1,114 @@ +#pragma once + +#include +#include + +#include +#include + +/// @brief Abstraction over plain-text and TLS streams. +class Stream : private + std::variant< + boost::asio::ip::tcp::socket, + boost::asio::ssl::stream> +{ +public: + using tcp_socket = boost::asio::ip::tcp::socket; + using tls_stream = boost::asio::ssl::stream; + + /// @brief The type of the executor associated with the stream. + using executor_type = boost::asio::any_io_executor; + + /// @brief Type of the lowest layer of this stream + using lowest_layer_type = tcp_socket::lowest_layer_type; + +private: + using base_type = std::variant; + auto base() -> base_type& { return *this; } + auto base() const -> base_type const& { return *this; } + +public: + + /// @brief Initialize stream with a plain TCP socket + /// @param ioc IO context of stream + template + Stream(T&& executor) : base_type{std::in_place_type, std::forward(executor)} {} + + /// @brief Reset stream to a plain TCP socket + /// @return Reference to internal socket object + auto reset() -> tcp_socket& + { + return base().emplace(get_executor()); + } + + /// @brief Upgrade a plain TCP socket into a TLS stream. + /// @param ctx TLS context used for handshake + /// @return Reference to internal stream object + auto upgrade(boost::asio::ssl::context& ctx) -> tls_stream& + { + auto socket = std::move(std::get(base())); + return base().emplace(std::move(socket), ctx); + } + + /// @brief Get underlying basic socket + /// @return Reference to underlying socket + auto lowest_layer() -> lowest_layer_type& + { + return std::visit([](auto&& x) -> decltype(auto) { return x.lowest_layer(); }, base()); + } + + /// @brief Get underlying basic socket + /// @return Reference to underlying socket + auto lowest_layer() const -> lowest_layer_type const& + { + return std::visit([](auto&& x) -> decltype(auto) { return x.lowest_layer(); }, base()); + } + + /// @brief Get the executor associated with this stream. + /// @return The executor associated with the stream. + auto get_executor() -> executor_type const& + { + return lowest_layer().get_executor(); + } + + /// @brief Initiates an asynchronous read operation. + /// @tparam MutableBufferSequence Type of the buffer sequence. + /// @tparam Token Type of the completion token. + /// @param buffers The buffer sequence into which data will be read. + /// @param token The completion token for the read operation. + /// @return The result determined by the completion token. + template < + typename MutableBufferSequence, + boost::asio::completion_token_for Token> + auto async_read_some(MutableBufferSequence&& buffers, Token&& token) -> decltype(auto) + { + return std::visit([&buffers, &token](auto&& x) -> decltype(auto) { + return x.async_read_some(std::forward(buffers), std::forward(token)); + }, base()); + } + + /// @brief Initiates an asynchronous write operation. + /// @tparam ConstBufferSequence Type of the buffer sequence. + /// @tparam Token Type of the completion token. + /// @param buffers The buffer sequence from which data will be written. + /// @param token The completion token for the write operation. + /// @return The result determined by the completion token. + template < + typename ConstBufferSequence, + boost::asio::completion_token_for Token> + auto async_write_some(ConstBufferSequence&& buffers, Token&& token) -> decltype(auto) + { + return std::visit([&buffers, &token](auto&& x) -> decltype(auto) { + return x.async_write_some(std::forward(buffers), std::forward(token)); + }, base()); + } + + /// @brief Tear down the network stream + auto close() -> void + { + boost::system::error_code err; + auto& socket = lowest_layer(); + socket.shutdown(socket.shutdown_both, err); + socket.lowest_layer().close(err); + } +};