#include "connection.hpp"

#include "linebuffer.hpp"

#include <mybase64.hpp>

#include <boost/log/trivial.hpp>

namespace {
// Command lookup table. NOTE(review): presumably generated (gperf-style
// `in_word_set`) and the source of the `IrcCommandHash` used by dispatch_line —
// confirm. Included inside the anonymous namespace so everything it defines has
// internal linkage in this translation unit.
#include "irc_commands.inc"

// Socket types used by the set_buffer_size overloads and the TLS helpers below:
// a plain TCP socket and a TLS stream layered over it.
using tcp_type = boost::asio::ip::tcp::socket;
using tls_type = boost::asio::ssl::stream<tcp_type>;
} // namespace

using namespace std::literals;

/// Create an unconnected Connection whose stream and watchdog timer run on `io`.
/// No write is posted and the watchdog is not stalled initially.
Connection::Connection(boost::asio::io_context &io)
    : stream_{io}, watchdog_timer_{io}, write_posted_{false}, stalled_{false}
{
}

auto Connection::write_buffers() -> void
{
    std::vector<boost::asio::const_buffer> buffers;
    buffers.reserve(write_strings_.size());
    for (const auto &elt : write_strings_)
    {
        buffers.push_back(boost::asio::buffer(elt));
    }
    boost::asio::async_write(
        stream_,
        buffers,
        [this, strings = std::move(write_strings_)](const boost::system::error_code &error, std::size_t) {
            if (not error)
            {
                if (write_strings_.empty())
                {
                    write_posted_ = false;
                }
                else
                {
                    write_buffers();
                }
            }
        }
    );
    write_strings_.clear();
}

/// Arm the inactivity watchdog for one watchdog_duration interval.
///
/// On expiry: if the connection was already marked stalled (no watchdog_activity
/// call since the previous probe), the stream is closed; otherwise a PING is sent,
/// the connection is marked stalled and the timer is re-armed for another interval.
/// Waits that complete with an error (e.g. cancellation) do nothing.
auto Connection::watchdog() -> void
{
    watchdog_timer_.expires_after(watchdog_duration);
    watchdog_timer_.async_wait([this](const auto &error) {
        if (error)
        {
            return;
        }

        if (not stalled_)
        {
            // First silent interval: probe the server and wait one more interval.
            send_ping("watchdog");
            stalled_ = true;
            watchdog();
            return;
        }

        BOOST_LOG_TRIVIAL(debug) << "Watchdog timer elapsed, closing stream";
        close();
    });
}

auto Connection::watchdog_activity() -> void
{
    stalled_ = false;
    watchdog_timer_.expires_after(watchdog_duration);
}

/// Parse one received IRC line and route it.
///
/// PING is answered immediately with a PONG. AUTHENTICATE payloads are fed to
/// on_authenticate. NOTICEs that snoteCore matches are emitted on sig_snote. Every
/// other recognized command is emitted on sig_ircmsg. Commands missing from the
/// command table, or arriving with an argument count outside the table's
/// [min_args, max_args] range, are logged as warnings and dropped.
///
/// @param line mutable NUL-terminated line from the read buffer
///             (NOTE(review): presumably parsed in place by parse_irc_message — confirm)
auto Connection::dispatch_line(char *line) -> void
{
    const auto msg = parse_irc_message(line);
    const auto recognized = IrcCommandHash::in_word_set(msg.command.data(), msg.command.size());
    const auto command
        = recognized
            && recognized->min_args <= msg.args.size()
            && recognized->max_args >= msg.args.size()
        ? recognized->command
        : IrcCommand::UNKNOWN;

    switch (command)
    {

    // Respond to pings immediately and do not dispatch them.
    // Indexing args[0] relies on the command table's min_args for PING being >= 1.
    case IrcCommand::PING:
        send_pong(msg.args[0]);
        break;

    // Unknown messages generate warnings but are not dispatched.
    // Messages can be unknown due to a bad command or a bad argument count.
    case IrcCommand::UNKNOWN:
        BOOST_LOG_TRIVIAL(warning) << "Unrecognized command: " << msg.command << " " << msg.args.size();
        break;

    case IrcCommand::AUTHENTICATE:
        on_authenticate(msg.args[0]);
        break;

    // Server notices generate snote events instead of IRC command events;
    // NOTICEs that snoteCore does not match continue to the default handling.
    case IrcCommand::NOTICE:
        if (auto match = snoteCore.match(msg))
        {
            sig_snote(*match);
            break;
        }
        [[fallthrough]];

    // Normal IRC commands
    default:
        sig_ircmsg(command, msg);
        break;
    }
}

auto Connection::write_line(std::string message) -> void
{
    BOOST_LOG_TRIVIAL(debug) << "SEND: " << message;
    message += "\r\n";
    write_strings_.push_back(std::move(message));

    if (not write_posted_)
    {
        write_posted_ = true;
        boost::asio::post(stream_.get_executor(), [weak = weak_from_this()]() {
            if (auto self = weak.lock())
            {
                self->write_buffers();
            }
        });
    }
}

auto Connection::close() -> void
{
    stream_.close();
}

auto Connection::write_irc(std::string message) -> void
{
    write_line(std::move(message));
}

/// Append `last` as a colon-prefixed final parameter to `front` and send the line.
/// @throws std::runtime_error if `last` contains CR, LF or NUL, which would break
///         the CRLF-terminated line produced by write_line.
auto Connection::write_irc(std::string front, std::string_view last) -> void
{
    constexpr auto forbidden = "\r\n\0"sv;
    if (last.find_first_of(forbidden) != std::string_view::npos)
    {
        throw std::runtime_error{"bad irc argument"};
    }

    front.append(" :").append(last);
    write_irc(std::move(front));
}

auto Connection::send_ping(std::string_view txt) -> void
{
    write_irc("PING", txt);
}

auto Connection::send_pong(std::string_view txt) -> void
{
    write_irc("PONG", txt);
}

auto Connection::send_pass(std::string_view password) -> void
{
    write_irc("PASS", password);
}

auto Connection::send_user(std::string_view user, std::string_view real) -> void
{
    write_irc("USER", user, "*", "*", real);
}

auto Connection::send_nick(std::string_view nick) -> void
{
    write_irc("NICK", nick);
}

auto Connection::send_cap_ls() -> void
{
    write_irc("CAP", "LS", "302");
}

auto Connection::send_cap_end() -> void
{
    write_irc("CAP", "END");
}

auto Connection::send_cap_req(std::string_view caps) -> void
{
    write_irc("CAP", "REQ", caps);
}

auto Connection::send_privmsg(std::string_view target, std::string_view message) -> void
{
    write_irc("PRIVMSG", target, message);
}

auto Connection::send_notice(std::string_view target, std::string_view message) -> void
{
    write_irc("NOTICE", target, message);
}

auto Connection::send_wallops(std::string_view message) -> void
{
    write_irc("WALLOPS", message);
}

auto Connection::send_names(std::string_view channel) -> void
{
    write_irc("NAMES", channel);
}

auto Connection::send_map() -> void
{
    write_irc("MAP");
}

auto Connection::send_get_topic(std::string_view channel) -> void
{
    write_irc("TOPIC", channel);
}

auto Connection::send_set_topic(std::string_view channel, std::string_view message) -> void
{
    write_irc("TOPIC", channel, message);
}

auto Connection::send_testline(std::string_view target) -> void
{
    write_irc("TESTLINE", target);
}

auto Connection::send_masktrace_gecos(std::string_view target, std::string_view gecos) -> void
{
    write_irc("MASKTRACE", target, gecos);
}

auto Connection::send_masktrace(std::string_view target) -> void
{
    write_irc("MASKTRACE", target);
}

auto Connection::send_testmask_gecos(std::string_view target, std::string_view gecos) -> void
{
    write_irc("TESTMASK", target, gecos);
}

auto Connection::send_testmask(std::string_view target) -> void
{
    write_irc("TESTMASK", target);
}

auto Connection::send_authenticate(std::string_view message) -> void
{
    write_irc("AUTHENTICATE", message);
}

auto Connection::send_join(std::string_view channel) -> void
{
    write_irc("JOIN", channel);
}

auto Connection::send_challenge(std::string_view message) -> void
{
    write_irc("CHALLENGE", message);
}

auto Connection::send_oper(std::string_view user, std::string_view pass) -> void
{
    write_irc("OPER", user, pass);
}

auto Connection::send_kick(std::string_view channel, std::string_view nick, std::string_view reason) -> void
{
    write_irc("KICK", channel, nick, reason);
}

auto Connection::send_kill(std::string_view nick, std::string_view reason) -> void
{
    write_irc("KILL", nick, reason);
}

auto Connection::send_quit(std::string_view message) -> void
{
    write_irc("QUIT", message);
}

auto Connection::send_whois(std::string_view arg1) -> void
{
    write_irc("WHOIS", arg1);
}

auto Connection::send_whois_remote(std::string_view arg1, std::string_view arg2) -> void
{
    write_irc("WHOIS", arg1, arg2);
}

/// Accumulate one AUTHENTICATE payload chunk from the server. Once the payload is
/// complete, base64-decode it and emit the result on sig_authenticate.
///
/// A chunk of exactly 400 bytes means more chunks follow. Any other size completes
/// the payload, including the "+" placeholder, which contributes no data. Invalid
/// base64 aborts SASL by sending "AUTHENTICATE *". So does exceeding 1024 buffered
/// bytes while still accumulating.
auto Connection::on_authenticate(const std::string_view chunk) -> void
{
    const bool is_placeholder = chunk == "+"sv;
    if (not is_placeholder)
    {
        authenticate_buffer_ += chunk;
    }

    const bool more_follows = chunk.size() == 400;
    if (more_follows)
    {
        if (authenticate_buffer_.size() > 1024)
        {
            BOOST_LOG_TRIVIAL(debug) << "AUTHENTICATE buffer overflow"sv;
            authenticate_buffer_.clear();
            send_authenticate("*"sv); // abort SASL
        }
        return;
    }

    std::string decoded(mybase64::decoded_size(authenticate_buffer_.size()), '\0');
    std::size_t len = 0;
    if (not mybase64::decode(authenticate_buffer_, decoded.data(), &len))
    {
        BOOST_LOG_TRIVIAL(debug) << "Invalid AUTHENTICATE base64"sv;
        send_authenticate("*"sv); // abort SASL
    }
    else
    {
        decoded.resize(len);
        sig_authenticate(decoded);
    }

    authenticate_buffer_.clear();
}

auto Connection::send_authenticate_abort() -> void
{
    send_authenticate("*");
}

/// Base64-encode `body` and send it as AUTHENTICATE chunks of at most 400 bytes.
/// If the encoded length is a multiple of 400 (including an empty body), a final
/// "+" is sent so the receiver knows the payload ended.
auto Connection::send_authenticate_encoded(std::string_view body) -> void
{
    constexpr std::size_t chunk_limit = 400;

    std::string encoded(mybase64::encoded_size(body.size()), 0);
    mybase64::encode(body, encoded.data());

    const std::string_view payload{encoded};
    for (std::size_t offset = 0; offset < payload.size(); offset += chunk_limit)
    {
        send_authenticate(payload.substr(offset, chunk_limit));
    }

    if (payload.size() % chunk_limit == 0)
    {
        send_authenticate("+"sv);
    }
}

/// Request an `n`-byte buffer on both the read and write BIOs of a TLS stream's
/// SSL object. The BIO_set_buffer_size results are ignored.
/// NOTE(review): BIO_set_buffer_size targets buffering (BIO_f_buffer) BIOs. If the
/// stream's BIOs are of another kind (asio appears to use a BIO pair), this may be
/// a no-op — confirm it has the intended effect.
static
auto set_buffer_size(tls_type& stream, std::size_t const n) -> void
{
    SSL* const handle = stream.native_handle();
    BIO* const read_bio = SSL_get_rbio(handle);
    BIO* const write_bio = SSL_get_wbio(handle);
    BIO_set_buffer_size(read_bio, n);
    BIO_set_buffer_size(write_bio, n);
}

/// Set the kernel send and receive buffer sizes (SO_SNDBUF / SO_RCVBUF) of a TCP
/// socket to `n` bytes. Throws whatever set_option throws on failure.
static
auto set_buffer_size(tcp_type& socket, std::size_t const n) -> void
{
    auto const bytes = static_cast<int>(n);
    socket.set_option(tcp_type::send_buffer_size{bytes});
    socket.set_option(tcp_type::receive_buffer_size{bytes});
}

/// Add FD_CLOEXEC to a file descriptor's flags so it is closed across exec.
/// @throws std::system_error carrying errno if either fcntl call fails.
static
auto set_cloexec(int const fd) -> void
{
    auto const fail = [](char const* what) {
        throw std::system_error{errno, std::generic_category(), what};
    };

    int const current = fcntl(fd, F_GETFD);
    if (current == -1)
    {
        fail("failed to get file descriptor flags");
    }

    if (fcntl(fd, F_SETFD, current | FD_CLOEXEC) == -1)
    {
        fail("failed to set file descriptor flags");
    }
}

/// Compile-time sum of a pack of sizes (0 for an empty pack).
template <std::size_t... Ns>
static
auto constexpr sum() -> std::size_t
{
    return (Ns + ... + std::size_t{0});
}

/**
 * @brief Builds the wire format required by the ALPN extension.
 *
 * Each protocol name is written as a one-byte length prefix followed by its
 * characters (without the null terminator). Names are concatenated in argument
 * order. A name with array size N contributes 1 + (N - 1) = N bytes, so the
 * result is exactly sum<Ns...>() bytes long.
 *
 * @tparam Ns array sizes of each protocol name, including the null terminator
 * @param protocols string literals naming the supported protocols
 * @return encoded protocol list
 */
template <std::size_t... Ns>
static
auto constexpr alpn_encode(char const (&... protocols)[Ns]) -> std::array<unsigned char, sum<Ns...>()>
{
    std::array<unsigned char, sum<Ns...>()> wire{};
    std::size_t pos = 0;

    auto const append = [&wire, &pos]<std::size_t N>(char const (&name)[N]) {
        static_assert(N > 0, "Protocol name must be null-terminated");
        static_assert(N < 256, "Protocol name too long");
        if (name[N - 1] != '\0')
            throw "Protocol name not null-terminated";

        // Length prefix excludes the null terminator.
        wire[pos++] = static_cast<unsigned char>(N - 1);
        for (std::size_t i = 0; i + 1 < N; ++i)
        {
            wire[pos++] = static_cast<unsigned char>(name[i]);
        }
    };

    (append(protocols), ...);
    return wire;
}

/**
 * @brief Configure the TLS stream to request the IRC protocol ("irc") via ALPN.
 *
 * ALPN is treated as best-effort. If OpenSSL rejects the protocol list, a warning
 * is logged and the function returns normally without ALPN configured.
 *
 * @param stream TLS stream
 */
static
auto set_alpn(tls_type& stream) -> void
{
    auto constexpr protos = alpn_encode("irc");
    // SSL_set_alpn_protos uses an inverted convention: 0 means success.
    if (0 != SSL_set_alpn_protos(stream.native_handle(), protos.data(), protos.size()))
    {
        BOOST_LOG_TRIVIAL(warning) << "Failed to configure ALPN";
    }
}

/// Create a TLS client context that uses the default CA verification paths. When
/// provided, the given client certificate and private key are installed.
/// @param client_cert optional client certificate (nullptr to skip)
/// @param client_key optional client private key (nullptr to skip)
/// @throws std::runtime_error if OpenSSL rejects the certificate or the key.
static
auto build_ssl_context(
    X509* client_cert,
    EVP_PKEY* client_key
) -> boost::asio::ssl::context
{
    boost::asio::ssl::context ssl_context{boost::asio::ssl::context::method::tls_client};
    ssl_context.set_default_verify_paths();

    auto* const native = ssl_context.native_handle();

    if (client_cert != nullptr && SSL_CTX_use_certificate(native, client_cert) != 1)
    {
        throw std::runtime_error{"certificate file"};
    }

    if (client_key != nullptr && SSL_CTX_use_PrivateKey(native, client_key) != 1)
    {
        throw std::runtime_error{"private key"};
    }

    return ssl_context;
}

/// Connect to the configured IRC server and service the connection until it ends.
///
/// The coroutine runs these steps in order:
/// 1. Resolve settings.host / settings.port and connect.
/// 2. Tune the socket: TCP_NODELAY, buffer sizes, close-on-exec.
/// 3. When settings.tls is set, upgrade the stream to TLS (ALPN "irc", optional peer
///    host name verification and SNI) and perform the client handshake.
/// 4. Emit sig_connect and start the watchdog.
/// 5. Feed every received line to dispatch_line until a read fails.
///
/// Teardown (cancelling the watchdog timer and closing the stream) runs on every
/// exit path. That includes an exception ending the session: a resolve, connect or
/// handshake failure, or an exception escaping a signal slot during dispatch. Such
/// an exception is rethrown after teardown so the caller still observes it.
auto Connection::connect(
    Settings settings
) -> boost::asio::awaitable<void>
{
    // keep connection alive while coroutine is active
    const auto self = shared_from_this();
    const size_t irc_buffer_size = 32'768;

    // Captured so that teardown below runs before the error propagates; otherwise the
    // socket would stay open and the watchdog armed after the session was reported over.
    std::exception_ptr failure;

    try
    {
        {
            // Name resolution
            auto resolver = boost::asio::ip::tcp::resolver{stream_.get_executor()};
            const auto endpoints = co_await resolver.async_resolve(settings.host, std::to_string(settings.port), boost::asio::use_awaitable);

            // Connect to the IRC server
            auto& socket = stream_.reset();
            const auto endpoint = co_await boost::asio::async_connect(socket, endpoints, boost::asio::use_awaitable);
            BOOST_LOG_TRIVIAL(debug) << "CONNECTED: " << endpoint;

            // Set socket options
            socket.set_option(boost::asio::ip::tcp::no_delay(true));
            set_buffer_size(socket, irc_buffer_size);
            set_cloexec(socket.native_handle());
        }

        if (settings.tls)
        {
            auto cxt = build_ssl_context(settings.client_cert.get(), settings.client_key.get());

            // Upgrade stream_ to use TLS and invalidate socket
            auto& stream = stream_.upgrade(cxt);

            set_buffer_size(stream, irc_buffer_size);
            set_alpn(stream);

            if (not settings.verify.empty())
            {
                stream.set_verify_mode(boost::asio::ssl::verify_peer);
                stream.set_verify_callback(boost::asio::ssl::host_name_verification(settings.verify));
            }

            if (not settings.sni.empty())
            {
                SSL_set_tlsext_host_name(stream.native_handle(), settings.sni.c_str());
            }

            co_await stream.async_handshake(stream.client, boost::asio::use_awaitable);
        }

        sig_connect();

        watchdog();

        for (LineBuffer buffer{irc_buffer_size};;)
        {
            boost::system::error_code error;
            const auto n = co_await stream_.async_read_some(buffer.get_buffer(), boost::asio::redirect_error(boost::asio::use_awaitable, error));
            if (error)
            {
                break;
            }
            buffer.add_bytes(n, [this](char *line) {
                BOOST_LOG_TRIVIAL(debug) << "RECV: " << line;
                watchdog_activity();
                dispatch_line(line);
            });
        }
    }
    catch (...)
    {
        failure = std::current_exception();
    }

    watchdog_timer_.cancel();
    stream_.close();

    if (failure)
    {
        std::rethrow_exception(failure);
    }
}

/// Spawn the connect() coroutine on the stream's executor.
///
/// When the coroutine finishes, normally or with any exception, the outcome is
/// logged and sig_disconnect is emitted. All signals then have their slots
/// disconnected to break reference cycles between slots and this Connection.
auto Connection::start(Settings settings) -> void
{
    boost::asio::co_spawn(
        stream_.get_executor(), connect(std::move(settings)),
        [self = shared_from_this()](std::exception_ptr e) {
            try
            {
                if (e)
                    std::rethrow_exception(e);

                BOOST_LOG_TRIVIAL(debug) << "DISCONNECTED";
            }
            catch (const std::exception &ex)
            {
                BOOST_LOG_TRIVIAL(debug) << "TERMINATED: " << ex.what();
            }
            catch (...)
            {
                // A non-std exception must not escape this completion handler: it
                // would propagate out of io_context::run and skip the cleanup below.
                BOOST_LOG_TRIVIAL(debug) << "TERMINATED: unknown exception";
            }

            self->sig_disconnect();

            // Disconnect all slots to avoid circular references
            self->sig_connect.disconnect_all_slots();
            self->sig_ircmsg.disconnect_all_slots();
            self->sig_disconnect.disconnect_all_slots();
            self->sig_snote.disconnect_all_slots();
            self->sig_authenticate.disconnect_all_slots();
        });
}