#include "myirc/connection.hpp"

#include "myirc/linebuffer.hpp"

// NOTE(review): this file reached review with everything between '<' and '>'
// stripped (header names and template argument lists). The system/third-party
// includes below and the template arguments throughout the file were
// reconstructed from the symbols that are used here -- confirm them against
// the build before merging (in particular the socks5 header name).
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include <boost/log/trivial.hpp>
#include <socks5.hpp>

#include <openssl/ssl.h>
#include <openssl/x509.h>

#include <fcntl.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <exception>
#include <iomanip>
#include <iterator>
#include <list>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>

namespace myirc {

#include "irc_commands.inc"

using tcp_type = boost::asio::ip::tcp::socket;
using tls_type = boost::asio::ssl::stream<tcp_type>;

using namespace std::literals;

/// Construct an idle connection whose stream and watchdog timer run on @p io.
Connection::Connection(boost::asio::io_context &io)
    : stream_{io}
    , watchdog_timer_{io}
    , write_posted_{false}
    , stalled_{false}
{
}

/// Flush queued outgoing lines, honouring the optional rate limiter.
///
/// The limiter is asked how many of the queued lines may be sent and after
/// what delay. A positive delay defers the write with a one-shot timer;
/// otherwise the write is started immediately.
auto Connection::write_buffers() -> void
{
    const auto available = write_strings_.size();
    const auto [delay, count] = rate_limit ? rate_limit->query(available) : std::pair{0ms, available};
    if (delay > 0ms)
    {
        // The timer is kept alive by its own completion handler; the connection
        // is only referenced weakly so a pending delay does not extend its life.
        auto timer = std::make_shared<boost::asio::steady_timer>(stream_.get_executor(), delay);
        timer->async_wait([timer, count, self = weak_from_this()](auto) {
            if (auto lock = self.lock())
            {
                lock->write_buffers(count);
            }
        });
    }
    else
    {
        write_buffers(count);
    }
}

/// Start an asynchronous write of the first @p n queued lines.
///
/// The lines are moved out of the queue into the completion handler so their
/// storage outlives the write. When the write completes successfully the next
/// batch is scheduled, or `write_posted_` is cleared if the queue is empty.
auto Connection::write_buffers(size_t n) -> void
{
    std::list<std::string> strings;
    std::vector<boost::asio::const_buffer> buffers;

    if (n == write_strings_.size())
    {
        strings = std::move(write_strings_);
        write_strings_.clear(); // leave the moved-from queue in a known empty state
    }
    else
    {
        strings.splice(
            strings.begin(),                     // insert at
            write_strings_,                      // remove from
            write_strings_.begin(),              // start removing at
            std::next(write_strings_.begin(), n) // stop removing at
        );
    }

    buffers.reserve(n);
    for (const auto &elt : strings)
    {
        buffers.push_back(boost::asio::buffer(elt));
    }

    // Moving a std::list transfers its nodes, so the buffers built above keep
    // pointing at valid string storage after `strings` moves into the handler.
    // The connection is captured weakly (like the other callbacks in this file)
    // so a completion that runs after the connection is gone cannot touch it.
    boost::asio::async_write(
        stream_,
        buffers,
        [weak = weak_from_this(), strings = std::move(strings)](const boost::system::error_code &error, std::size_t) {
            if (error)
            {
                return;
            }
            if (const auto self = weak.lock())
            {
                if (self->write_strings_.empty())
                {
                    self->write_posted_ = false;
                }
                else
                {
                    self->write_buffers();
                }
            }
        }
    );
}

/// Arm the watchdog timer.
///
/// When the timer elapses without having been re-armed: if the connection was
/// already marked stalled the stream is closed; otherwise a PING is sent, the
/// connection is marked stalled and the timer is armed again.
auto Connection::watchdog() -> void
{
    // Setting a new expiry cancels any wait that is still pending; that wait's
    // handler then runs with an error and is ignored below.
    watchdog_timer_.expires_after(watchdog_duration);
    watchdog_timer_.async_wait([weak = weak_from_this()](const auto &error) {
        if (error)
        {
            return;
        }
        const auto self = weak.lock();
        if (not self)
        {
            return;
        }
        if (self->stalled_)
        {
            BOOST_LOG_TRIVIAL(debug) << "Watchdog timer elapsed, closing stream";
            self->close();
        }
        else
        {
            self->write_irc("PING", "watchdog");
            self->stalled_ = true;
            self->watchdog();
        }
    });
}

/// Record that data arrived from the peer and restart the watchdog countdown.
auto Connection::watchdog_activity() -> void
{
    stalled_ = false;
    // Changing the expiry alone cancels the outstanding wait without starting
    // a new one, which would silently disable the watchdog after the first
    // received line. Re-arm it instead.
    watchdog();
}

/// Parse IRC message line and dispatch it to the ircmsg slot.
///
/// @param line  mutable line buffer handed to parse_irc_message
/// @param flush forwarded to the slot; the read loop passes true for the last
///              line of a batch
auto Connection::dispatch_line(char *line, bool flush) -> void
{
    const auto msg = parse_irc_message(line);
    const auto recognized = IrcCommandHash::in_word_set(msg.command.data(), msg.command.size());

    // A command only counts as recognized when its argument count is in range.
    const auto command
        = recognized
          && recognized->min_args <= msg.args.size()
          && recognized->max_args >= msg.args.size()
        ? recognized->command
        : IrcCommand::UNKNOWN;

    switch (command)
    {
    // Respond to pings immediately and discard
    case IrcCommand::PING:
        write_irc("PONG", msg.args[0]);
        break;

    // Unknown messages generate warnings but are not dispatched.
    // Messages can be unknown due to bad command or bad argument count
    case IrcCommand::UNKNOWN:
        BOOST_LOG_TRIVIAL(warning) << "Unrecognized command: " << msg.command << " " << msg.args.size();
        break;

    // Normal IRC commands
    default:
        sig_ircmsg(command, msg, flush);
        break;
    }
}

/// Queue a raw line (CRLF is appended here) and make sure a flush is posted.
auto Connection::write_line(std::string message) -> void
{
    BOOST_LOG_TRIVIAL(debug) << "SEND: " << message;
    message += "\r\n";
    write_strings_.push_back(std::move(message));

    // Only one flush is posted at a time; lines queued before it runs are
    // picked up by that same flush.
    if (not write_posted_)
    {
        write_posted_ = true;
        boost::asio::post(stream_.get_executor(), [weak = weak_from_this()]() {
            if (auto self = weak.lock())
            {
                self->write_buffers();
            }
        });
    }
}

/// Close the underlying stream.
auto Connection::close() -> void
{
    stream_.close();
}

/// Send a pre-formatted IRC message.
auto Connection::write_irc(std::string message) -> void
{
    write_line(std::move(message));
}

/// Send an IRC message made of an already formatted prefix and a final argument.
///
/// @param front command and leading arguments
/// @param last  final argument; sent in trailing (":"-prefixed) form when it
///              is empty, starts with ':' or contains a space
/// @throws std::runtime_error if @p last contains CR, LF or NUL
auto Connection::write_irc(std::string front, std::string_view last) -> void
{
    // An empty final argument can only be represented in trailing form;
    // without the ':' it would vanish from the message.
    bool colon = last.empty() || last.starts_with(":");
    for (const auto c : last)
    {
        switch (c)
        {
        case '\r':
        case '\n':
        case '\0':
            throw std::runtime_error{"bad irc argument"};
        case ' ':
            colon = true;
            break;
        default:
            break;
        }
    }
    front += colon ? " :" : " ";
    front += last;
    write_line(std::move(front));
}

/// Set the buffer size on both BIOs of a TLS stream.
static auto set_buffer_size(tls_type& stream, std::size_t const n) -> void
{
    auto const ssl = stream.native_handle();
    BIO_set_buffer_size(SSL_get_rbio(ssl), n);
    BIO_set_buffer_size(SSL_get_wbio(ssl), n);
}

/// Set the send and receive buffer sizes of a TCP socket.
static auto set_buffer_size(tcp_type& socket, std::size_t const n) -> void
{
    socket.set_option(tcp_type::send_buffer_size{static_cast<int>(n)});
    socket.set_option(tcp_type::receive_buffer_size{static_cast<int>(n)});
}

/// Set FD_CLOEXEC on a file descriptor.
///
/// @throws std::system_error if the descriptor flags cannot be read or written
static auto set_cloexec(int const fd) -> void
{
    auto const flags = fcntl(fd, F_GETFD);
    if (-1 == flags)
    {
        throw std::system_error{errno, std::generic_category(), "failed to get file descriptor flags"};
    }
    if (-1 == fcntl(fd, F_SETFD, flags | FD_CLOEXEC))
    {
        throw std::system_error{errno, std::generic_category(), "failed to set file descriptor flags"};
    }
}

/// Sum of a pack of sizes.
template <std::size_t... Ns>
static auto constexpr sum() -> std::size_t
{
    return (0 + ... + Ns);
}

/**
 * @brief Builds the string format required for the ALPN extension
 *
 * Each protocol is written as a one-byte length followed by the name without
 * its null terminator, so a literal of array size N occupies exactly N bytes.
 *
 * @tparam Ns sizes of each protocol name (including the null terminator)
 * @param protocols the names of the supported protocols
 * @return encoded protocol names
 */
template <std::size_t... Ns>
static auto constexpr alpn_encode(char const (&... protocols)[Ns]) -> std::array<unsigned char, sum<Ns...>()>
{
    auto result = std::array<unsigned char, sum<Ns...>()>{};
    auto cursor = std::begin(result);
    auto const encode = [&cursor]<std::size_t N>(char const(&protocol)[N]) {
        static_assert(N > 0, "Protocol name must be null-terminated");
        static_assert(N < 256, "Protocol name too long");
        if (protocol[N - 1] != '\0')
            throw "Protocol name not null-terminated";

        // Prefixed length byte
        *cursor++ = static_cast<unsigned char>(N - 1);

        // Add string skipping null terminator
        cursor = std::copy(std::begin(protocol), std::end(protocol) - 1, cursor);
    };
    (encode(protocols), ...);
    return result;
}

/**
 * @brief Configure the TLS stream to request the IRC protocol.
 *
 * @param stream TLS stream
 */
static auto set_alpn(tls_type& stream) -> void
{
    auto constexpr protos = alpn_encode("irc");
    SSL_set_alpn_protos(stream.native_handle(), protos.data(), protos.size());
}

/**
 * @brief Build a TLS client context using the default verify paths.
 *
 * @param client_cert optional client certificate (may be null)
 * @param client_key optional client private key (may be null)
 * @return configured context
 * @throws std::runtime_error if the certificate or key is rejected
 */
static auto build_ssl_context(
    X509* client_cert,
    EVP_PKEY* client_key
) -> boost::asio::ssl::context
{
    boost::asio::ssl::context ssl_context{boost::asio::ssl::context::method::tls_client};
    ssl_context.set_default_verify_paths();

    if (nullptr != client_cert)
    {
        if (1 != SSL_CTX_use_certificate(ssl_context.native_handle(), client_cert))
        {
            throw std::runtime_error{"certificate file"};
        }
    }
    if (nullptr != client_key)
    {
        if (1 != SSL_CTX_use_PrivateKey(ssl_context.native_handle(), client_key))
        {
            throw std::runtime_error{"private key"};
        }
    }
    return ssl_context;
}

/**
 * @brief Render a certificate fingerprint as "<digest name>:<lowercase hex>".
 *
 * @param cer peer certificate
 * @return the fingerprint, or an empty string if no digest could be computed
 */
static auto peer_fingerprint(X509 *cer) -> std::string
{
    std::ostringstream os;
    EVP_MD *md_used = nullptr;
    if (auto digest = X509_digest_sig(cer, &md_used, nullptr))
    {
        os << EVP_MD_name(md_used) << ":" << std::hex << std::setfill('0');
        EVP_MD_free(md_used);

        for (int i = 0; i < digest->length; ++i)
        {
            // Widen each byte so it is printed as a number, not a character
            os << std::setw(2) << static_cast<int>(digest->data[i]);
        }
        ASN1_OCTET_STRING_free(digest);
    }
    return os.str();
}

/// Establish the connection and run the read loop until the stream ends.
///
/// Steps: resolve and connect (to the SOCKS proxy when one is configured,
/// otherwise to the IRC server), optionally negotiate SOCKS, optionally
/// upgrade to TLS, emit sig_connect, start the watchdog, then read and
/// dispatch lines until a read fails or the line buffer has no room left.
///
/// @param settings connection parameters; taken by value because host/port
///                 are swapped with the SOCKS host/port when SOCKS is in use
auto Connection::connect(
    Settings settings
) -> boost::asio::awaitable<void>
{
    // keep connection alive while coroutine is active
    const auto self = shared_from_this();
    const size_t irc_buffer_size = 32'768;

    boost::asio::ip::tcp::endpoint socket_endpoint;
    // NOTE(review): element type reconstructed; confirm it matches what
    // socks5::async_connect yields and what sig_connect expects.
    std::optional<boost::asio::ip::tcp::endpoint> socks_endpoint;
    std::string fingerprint;

    {
        // If we're going to use SOCKS then the TCP connection host is actually the socks
        // server and then the IRC server gets passed over the SOCKS protocol.
        // This swap must happen before name resolution: otherwise the IRC
        // host is resolved locally and connected to directly, and the SOCKS
        // handshake is spoken to the wrong peer.
        auto const use_socks = not settings.socks_host.empty() && settings.socks_port != 0;
        if (use_socks)
        {
            std::swap(settings.host, settings.socks_host);
            std::swap(settings.port, settings.socks_port);
        }

        // Name resolution
        auto resolver = boost::asio::ip::tcp::resolver{stream_.get_executor()};
        const auto endpoints = co_await resolver.async_resolve(settings.host, std::to_string(settings.port), boost::asio::use_awaitable);

        for (const auto &e : endpoints)
        {
            BOOST_LOG_TRIVIAL(debug) << "DNS: " << e.endpoint();
        }

        // Connect to the IRC server, or to the SOCKS proxy in front of it
        auto& socket = stream_.reset();

        socket_endpoint = co_await boost::asio::async_connect(socket, endpoints, boost::asio::use_awaitable);
        BOOST_LOG_TRIVIAL(debug) << "CONNECTED: " << socket_endpoint;

        // Set socket options
        socket.set_option(boost::asio::ip::tcp::no_delay(true));
        set_buffer_size(socket, irc_buffer_size);
        set_cloexec(socket.native_handle());

        // Optionally negotiate SOCKS connection
        if (use_socks)
        {
            auto auth = not settings.socks_user.empty() || not settings.socks_pass.empty()
                      ? socks5::Auth{socks5::UsernamePasswordCredential{settings.socks_user, settings.socks_pass}}
                      : socks5::Auth{socks5::NoCredential{}};
            // After the swap above socks_host/socks_port hold the IRC server
            socks_endpoint = co_await socks5::async_connect(
                socket,
                settings.socks_host, settings.socks_port, std::move(auth),
                boost::asio::use_awaitable
            );
        }
    }

    if (settings.tls)
    {
        auto cxt = build_ssl_context(settings.client_cert.get(), settings.client_key.get());

        // Upgrade stream_ to use TLS and invalidate socket
        auto& stream = stream_.upgrade(cxt);

        set_buffer_size(stream, irc_buffer_size);
        set_alpn(stream);

        if (not settings.verify.empty())
        {
            stream.set_verify_mode(boost::asio::ssl::verify_peer);
            stream.set_verify_callback(boost::asio::ssl::host_name_verification(settings.verify));
        }

        if (not settings.sni.empty())
        {
            SSL_set_tlsext_host_name(stream.native_handle(), settings.sni.c_str());
        }

        co_await stream.async_handshake(stream.client, boost::asio::use_awaitable);

        const auto cer = SSL_get0_peer_certificate(stream.native_handle());
        fingerprint = peer_fingerprint(cer);
    }

    sig_connect(socket_endpoint, socks_endpoint, std::move(fingerprint));

    watchdog();

    for (LineBuffer buffer{irc_buffer_size};;)
    {
        boost::system::error_code error;
        auto const chunk = buffer.prepare();
        // No room left in the buffer: stop reading
        if (chunk.size() == 0)
        {
            break;
        }
        const auto n = co_await stream_.async_read_some(chunk, boost::asio::redirect_error(boost::asio::use_awaitable, error));
        if (error)
        {
            break;
        }
        buffer.commit(n);

        auto line = buffer.next_nonempty_line();
        if (line)
        {
            watchdog_activity();
            do
            {
                BOOST_LOG_TRIVIAL(debug) << "RECV: " << line;
                // Look ahead so the final line of the batch is dispatched with flush set
                const auto next_line = buffer.next_nonempty_line();
                dispatch_line(line, next_line == nullptr);
                line = next_line;
            } while (line);
        }

        buffer.shift();
    }

    watchdog_timer_.cancel();
    stream_.close();
}

/// Spawn the connect coroutine and report its outcome through sig_disconnect.
///
/// The completion handler holds a strong reference to the connection, logs how
/// the coroutine ended, and disconnects every slot.
auto Connection::start(Settings settings) -> void
{
    boost::asio::co_spawn(
        stream_.get_executor(), connect(std::move(settings)),
        [self = shared_from_this()](std::exception_ptr e) {
            try
            {
                if (e)
                {
                    std::rethrow_exception(e);
                }

                BOOST_LOG_TRIVIAL(debug) << "DISCONNECTED";
            }
            catch (const std::exception &ex)
            {
                BOOST_LOG_TRIVIAL(debug) << "TERMINATED: " << ex.what();
            }
            catch (...)
            {
                // Keep going so the slots below are still notified and released
                BOOST_LOG_TRIVIAL(debug) << "TERMINATED: unknown exception";
            }

            // Disconnect all slots to avoid circular references
            self->sig_connect.disconnect_all_slots();
            self->sig_ircmsg.disconnect_all_slots();

            self->sig_disconnect(e);
            self->sig_disconnect.disconnect_all_slots();
        });
}

} // namespace myirc