Program Listing for File socket.hpp
↰ Return to documentation for file (sockets/include/sockets/detail/socket.hpp)
#pragma once
#include <atomic>
#include <cassert>
#include <chrono>
#include <concepts>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <future>
#include <lib2k/non_null_owner.hpp>
#include <lib2k/synchronized.hpp>
#include <lib2k/unique_value.hpp>
#include <memory>
#include <mutex>
#include <optional>
#include <span>
#include <stdexcept>
#include <stop_token>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>
#include "address_family.hpp"
#include "address_info.hpp"
#include "message_buffer.hpp"
namespace c2k {
// Error type signalling that an operation timed out. It carries a fixed message; it can only be
// default-constructed.
class TimeoutError final : public std::runtime_error {
public:
    // what() returns "operation timed out".
    TimeoutError() : std::runtime_error("operation timed out") {}
};
// Error type for failed reads from a socket. It can be created either with a generic default
// message or with a custom one.
class ReadError final : public std::runtime_error {
public:
    // Generic variant: what() returns "error reading from socket".
    ReadError() : std::runtime_error("error reading from socket") {}

    // Custom-message variants: every std::runtime_error constructor is inherited.
    using std::runtime_error::runtime_error;
};
// Error type for failed sends on a socket. It has no default constructor; a message must always be
// supplied through one of the inherited std::runtime_error constructors.
class SendError final : public std::runtime_error {
public:
    // Inherited constructors keep the base's (public) access regardless of where this declaration
    // sits; spelling out `public:` simply makes that visible.
    using std::runtime_error::runtime_error;
};
class ClientSocket;
// Common base of ServerSocket and ClientSocket. It owns the OS-level socket handle and stores the
// local and remote AddressInfo associated with it. Its constructor is protected, so it is only
// instantiated as a base subobject.
class AbstractSocket {
public:
#ifdef _WIN32
    using OsSocketHandle = std::uintptr_t;
#else
    using OsSocketHandle = int;
#endif

private:
    AddressInfo m_local_address_info;
    AddressInfo m_remote_address_info;

protected:
    // Owning wrapper around the OS handle. The second template argument is the type of the function
    // used to release it (presumably a close routine supplied by the constructor, which is defined
    // elsewhere — TODO confirm).
    UniqueValue<OsSocketHandle, void (*)(OsSocketHandle handle)> m_socket_descriptor;

    explicit AbstractSocket(OsSocketHandle os_socket_handle);

public:
    // Returns the underlying OS handle. Returns std::nullopt if m_socket_descriptor holds no value
    // (e.g. presumably after this object has been moved from).
    [[nodiscard]] std::optional<OsSocketHandle> os_socket_handle() const {
        if (not m_socket_descriptor.has_value()) {
            return std::nullopt;
        }
        return m_socket_descriptor.value();
    }

    // The address accessors return references into this object. Calling them on an rvalue would
    // yield a dangling reference, so the rvalue overloads are deleted. They are `const&&` so that
    // const rvalues (which cannot bind to a non-const `&&` overload and would otherwise fall back
    // to `const&`) are rejected as well.
    AddressInfo const& local_address() const&& = delete;

    [[nodiscard]] AddressInfo const& local_address() const& {
        return m_local_address_info;
    }

    AddressInfo const& remote_address() const&& = delete;

    [[nodiscard]] AddressInfo const& remote_address() const& {
        return m_remote_address_info;
    }
};
namespace detail {
[[nodiscard]] AbstractSocket::OsSocketHandle initialize_server_socket(
AddressFamily address_family,
std::uint16_t port
);
}
class ServerSocket final : public AbstractSocket {
friend class Sockets;
private:
std::jthread m_listen_thread;
ServerSocket(AddressFamily address_family, std::uint16_t port, std::function<void(ClientSocket)> on_connect);
public:
ServerSocket(ServerSocket&& other) noexcept = default;
ServerSocket& operator=(ServerSocket&& other) noexcept = default;
~ServerSocket();
void stop();
};
class ClientSocket final : public AbstractSocket {
friend class Sockets;
friend void server_listen(
std::stop_token const& stop_token,
OsSocketHandle listen_socket,
std::function<void(ClientSocket)> const& on_connect
);
private:
struct SendTask {
std::promise<std::size_t> promise;
std::vector<std::byte> data;
SendTask(std::promise<std::size_t> promise, std::vector<std::byte> data)
: promise{ std::move(promise) }, data{ std::move(data) } {}
};
struct ReceiveTask {
enum class Kind {
Exact,
MaxBytes,
};
std::promise<std::vector<std::byte>> promise;
std::size_t max_num_bytes;
Kind kind;
std::chrono::steady_clock::time_point end_time;
ReceiveTask(
std::promise<std::vector<std::byte>> promise,
std::size_t const max_num_bytes,
Kind const kind,
std::chrono::steady_clock::time_point const end_time
)
: promise{ std::move(promise) }, max_num_bytes{ max_num_bytes }, kind{ kind }, end_time{ end_time } {}
};
class State {
private:
NonNullOwner<std::atomic_bool> running{ make_non_null_owner<std::atomic_bool>(true) };
public:
Synchronized<std::deque<SendTask>> send_tasks{ std::deque<SendTask>{} };
Synchronized<std::deque<ReceiveTask>> receive_tasks{ std::deque<ReceiveTask>{} };
std::condition_variable_any data_received_condition_variable;
std::condition_variable_any data_sent_condition_variable;
[[nodiscard]] bool is_running() const {
return *running;
}
void stop_running() {
*running = false;
// we have to ensure that running is set to false while holding the lock, otherwise we have a
// race condition regarding the condition variables which can lead to the threads blocking indefinitely
receive_tasks.apply([this](auto const&) { *running = false; });
send_tasks.apply([this](auto const&) { *running = false; });
data_received_condition_variable.notify_one();
data_sent_condition_variable.notify_one();
}
void clear_queues();
};
static constexpr auto default_timeout =
static_cast<std::chrono::steady_clock::duration>(std::chrono::seconds{ 1 });
std::unique_ptr<State> m_shared_state{ std::make_unique<State>() };
std::jthread m_send_thread;
std::jthread m_receive_thread;
explicit ClientSocket(OsSocketHandle os_socket_handle);
ClientSocket(AddressFamily address_family, std::string const& host, std::uint16_t port);
static void keep_sending(State& state, OsSocketHandle socket);
static void keep_receiving(State& state, OsSocketHandle socket);
using Timeout = std::chrono::steady_clock::duration;
public:
ClientSocket(ClientSocket&& other) noexcept = default;
ClientSocket& operator=(ClientSocket&& other) noexcept = default;
~ClientSocket();
[[nodiscard]] bool is_connected() const {
return m_shared_state->is_running();
}
// clang-format off
[[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
std::future<std::size_t> send(std::vector<std::byte> data);
[[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
std::future<std::size_t> send(std::integral auto... values) {
auto package = MessageBuffer{};
(package << ... << values);
return send(std::move(package));
}
[[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
std::future<std::size_t> send(MessageBuffer const& package) {
return send(package.data());
}
[[nodiscard("discarding the return value may lead to the data to never be transmitted")]]
std::future<std::size_t> send(MessageBuffer&& package) {
return send(std::move(package).data());
}
// clang-format on
[[nodiscard]] std::future<std::vector<std::byte>> receive(std::size_t max_num_bytes);
[[nodiscard]] std::future<std::vector<std::byte>> receive(std::size_t max_num_bytes, Timeout timeout);
[[nodiscard]] std::future<std::vector<std::byte>> receive_exact(std::size_t num_bytes);
[[nodiscard]] std::future<std::vector<std::byte>> receive_exact(std::size_t num_bytes, Timeout timeout);
template<std::integral... Ts>
[[nodiscard]] auto receive(Timeout const timeout = default_timeout) {
static constexpr auto total_size = detail::summed_sizeof<Ts...>();
auto future = receive_exact(total_size, timeout);
return std::async(std::launch::deferred, [future = std::move(future)]() mutable {
auto message_buffer = MessageBuffer{ future.get() };
assert(message_buffer.size() == total_size);
return message_buffer.try_extract<Ts...>().value(); // should never fail since we have enough data
});
}
void close();
private:
// clang-format off
[[nodiscard]] std::future<std::vector<std::byte>> receive_implementation(
std::size_t max_num_bytes,
ReceiveTask::Kind kind,
std::optional<std::chrono::steady_clock::time_point> end_time
);
// clang-format on
[[nodiscard]] static bool process_receive_task(OsSocketHandle socket, ReceiveTask task);
[[nodiscard]] static bool process_send_task(OsSocketHandle socket, SendTask task);
};
} // namespace c2k