// Program Listing for File socket.hpp

// Return to documentation for file (sockets/include/sockets/detail/socket.hpp)

#pragma once

#include <atomic>
#include <cassert>
#include <chrono>
#include <concepts>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <future>
#include <lib2k/non_null_owner.hpp>
#include <lib2k/synchronized.hpp>
#include <lib2k/unique_value.hpp>
#include <memory>
#include <mutex>
#include <optional>
#include <span>
#include <stdexcept>
#include <stop_token>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>
#include "address_family.hpp"
#include "address_info.hpp"
#include "message_buffer.hpp"

namespace c2k {
    /// Exception type for timeouts; always carries the fixed message "operation timed out".
    class TimeoutError final : public std::runtime_error {
    public:
        TimeoutError() : runtime_error("operation timed out") {}
    };

    /// Exception type for failures while reading from a socket.
    class ReadError final : public std::runtime_error {
    public:
        /// Generic variant carrying the fixed message "error reading from socket".
        ReadError() : runtime_error("error reading from socket") {}

        /// All std::runtime_error constructors are available, so a specific message can be supplied.
        using std::runtime_error::runtime_error;
    };

    /// Exception type for failures while sending data over a socket.
    ///
    /// Mirrors ReadError: it can be constructed with a custom message (inherited std::runtime_error
    /// constructors) or without arguments, in which case it carries a generic message.
    class SendError final : public std::runtime_error {
    public:
        // Explicitly public: inherited constructors keep the base's access anyway, but placing the
        // using-declaration in the implicit private section was misleading.
        using std::runtime_error::runtime_error;

        SendError()
            : std::runtime_error{ "error sending to socket" } {}
    };

    class ClientSocket;

    /// Common base class of ServerSocket and ClientSocket.
    ///
    /// Holds the platform socket handle in a UniqueValue parameterised with a `void (*)(OsSocketHandle)`
    /// callback (presumably the function that closes the handle — confirm in lib2k), together with the local
    /// and remote address information of the socket. Only derived classes can construct it.
    class AbstractSocket {
    public:
        /// Native socket handle type of the target platform.
#ifdef _WIN32
        using OsSocketHandle = std::uintptr_t;
#else
        using OsSocketHandle = int;
#endif

    private:
        AddressInfo m_local_address_info;
        AddressInfo m_remote_address_info;

    protected:
        UniqueValue<OsSocketHandle, void (*)(OsSocketHandle handle)> m_socket_descriptor;

        explicit AbstractSocket(OsSocketHandle os_socket_handle);

    public:
        /// Returns the native handle, or std::nullopt when this object no longer holds one.
        [[nodiscard]] std::optional<OsSocketHandle> os_socket_handle() const {
            if (m_socket_descriptor.has_value()) {
                return m_socket_descriptor.value();
            }
            return std::nullopt;
        }

        /// Address information of the local endpoint.
        [[nodiscard]] AddressInfo const& local_address() const& {
            return m_local_address_info;
        }

        // Calling this on a (non-const) temporary socket would hand out a dangling reference, so it is forbidden.
        AddressInfo const& local_address() && = delete;

        /// Address information of the remote endpoint.
        [[nodiscard]] AddressInfo const& remote_address() const& {
            return m_remote_address_info;
        }

        // Same reasoning as for local_address(): no references into temporaries.
        AddressInfo const& remote_address() && = delete;
    };

    namespace detail {
        [[nodiscard]] AbstractSocket::OsSocketHandle initialize_server_socket(
            AddressFamily address_family,
            std::uint16_t port
        );
    }

    class ServerSocket final : public AbstractSocket {
        friend class Sockets;

    private:
        std::jthread m_listen_thread;

        ServerSocket(AddressFamily address_family, std::uint16_t port, std::function<void(ClientSocket)> on_connect);

    public:
        ServerSocket(ServerSocket&& other) noexcept = default;
        ServerSocket& operator=(ServerSocket&& other) noexcept = default;

        ~ServerSocket();

        void stop();
    };

    class ClientSocket final : public AbstractSocket {
        friend class Sockets;
        friend void server_listen(
            std::stop_token const& stop_token,
            OsSocketHandle listen_socket,
            std::function<void(ClientSocket)> const& on_connect
        );

    private:
        struct SendTask {
            std::promise<std::size_t> promise;
            std::vector<std::byte> data;

            SendTask(std::promise<std::size_t> promise, std::vector<std::byte> data)
                : promise{ std::move(promise) }, data{ std::move(data) } {}
        };

        struct ReceiveTask {
            enum class Kind {
                Exact,
                MaxBytes,
            };

            std::promise<std::vector<std::byte>> promise;
            std::size_t max_num_bytes;
            Kind kind;
            std::chrono::steady_clock::time_point end_time;

            ReceiveTask(
                std::promise<std::vector<std::byte>> promise,
                std::size_t const max_num_bytes,
                Kind const kind,
                std::chrono::steady_clock::time_point const end_time
            )
                : promise{ std::move(promise) }, max_num_bytes{ max_num_bytes }, kind{ kind }, end_time{ end_time } {}
        };

        class State {
        private:
            NonNullOwner<std::atomic_bool> running{ make_non_null_owner<std::atomic_bool>(true) };

        public:
            Synchronized<std::deque<SendTask>> send_tasks{ std::deque<SendTask>{} };
            Synchronized<std::deque<ReceiveTask>> receive_tasks{ std::deque<ReceiveTask>{} };
            std::condition_variable_any data_received_condition_variable;
            std::condition_variable_any data_sent_condition_variable;

            [[nodiscard]] bool is_running() const {
                return *running;
            }

            void stop_running() {
                *running = false;
                // we have to ensure that running is set to false while holding the lock, otherwise we have a
                // race condition regarding the condition variables which can lead to the threads blocking indefinitely
                receive_tasks.apply([this](auto const&) { *running = false; });
                send_tasks.apply([this](auto const&) { *running = false; });
                data_received_condition_variable.notify_one();
                data_sent_condition_variable.notify_one();
            }

            void clear_queues();
        };

        static constexpr auto default_timeout =
            static_cast<std::chrono::steady_clock::duration>(std::chrono::seconds{ 1 });

        std::unique_ptr<State> m_shared_state{ std::make_unique<State>() };
        std::jthread m_send_thread;
        std::jthread m_receive_thread;

        explicit ClientSocket(OsSocketHandle os_socket_handle);
        ClientSocket(AddressFamily address_family, std::string const& host, std::uint16_t port);

        static void keep_sending(State& state, OsSocketHandle socket);
        static void keep_receiving(State& state, OsSocketHandle socket);

        using Timeout = std::chrono::steady_clock::duration;

    public:
        ClientSocket(ClientSocket&& other) noexcept = default;
        ClientSocket& operator=(ClientSocket&& other) noexcept = default;

        ~ClientSocket();

        [[nodiscard]] bool is_connected() const {
            return m_shared_state->is_running();
        }

        // clang-format off
        [[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
        std::future<std::size_t> send(std::vector<std::byte> data);

        [[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
        std::future<std::size_t> send(std::integral auto... values) {
            auto package = MessageBuffer{};
            (package << ... << values);
            return send(std::move(package));
        }

        [[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
        std::future<std::size_t> send(MessageBuffer const& package) {
            return send(package.data());
        }

        [[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
        std::future<std::size_t> send(MessageBuffer&& package) {
            return send(std::move(package).data());
        }

        // clang-format on

        [[nodiscard]] std::future<std::vector<std::byte>> receive(std::size_t max_num_bytes);

        [[nodiscard]] std::future<std::vector<std::byte>> receive(std::size_t max_num_bytes, Timeout timeout);

        [[nodiscard]] std::future<std::vector<std::byte>> receive_exact(std::size_t num_bytes);

        [[nodiscard]] std::future<std::vector<std::byte>> receive_exact(std::size_t num_bytes, Timeout timeout);

        template<std::integral... Ts>
        [[nodiscard]] auto receive(Timeout const timeout = default_timeout) {
            static constexpr auto total_size = detail::summed_sizeof<Ts...>();
            auto future = receive_exact(total_size, timeout);
            return std::async(std::launch::deferred, [future = std::move(future)]() mutable {
                auto message_buffer = MessageBuffer{ future.get() };
                assert(message_buffer.size() == total_size);
                return message_buffer.try_extract<Ts...>().value();  // should never fail since we have enough data
            });
        }

        void close();

    private:
        // clang-format off
        [[nodiscard]] std::future<std::vector<std::byte>> receive_implementation(
            std::size_t max_num_bytes,
            ReceiveTask::Kind kind,
            std::optional<std::chrono::steady_clock::time_point> end_time
        );
        // clang-format on
        [[nodiscard]] static bool process_receive_task(OsSocketHandle socket, ReceiveTask task);
        [[nodiscard]] static bool process_send_task(OsSocketHandle socket, SendTask task);
    };
}  // namespace c2k