diff --git a/.github/workflows/compiler-support.yml b/.github/workflows/compiler-support.yml index 50dafd3..5c94b46 100644 --- a/.github/workflows/compiler-support.yml +++ b/.github/workflows/compiler-support.yml @@ -21,8 +21,8 @@ jobs: - { tag: "ubuntu-2004_clang-11", name: "Ubuntu 20.04 Clang 11", cxx: "/usr/bin/clang++-11", cc: "/usr/bin/clang-11", runs-on: "ubuntu-20.04" } - { tag: "ubuntu-2004_clang-10", name: "Ubuntu 20.04 Clang 10", cxx: "/usr/bin/clang++-10", cc: "/usr/bin/clang-10", runs-on: "ubuntu-20.04" } - { tag: "ubuntu-2004_gcc-10", name: "Ubuntu 20.04 G++ 10", cxx: "/usr/bin/g++-10", cc: "/usr/bin/gcc-10", runs-on: "ubuntu-20.04" } - #- { tag: "windows-2022_msvc17", name: "Windows Server 2022 MSVC 17", cxx: "", cc: "", runs-on: "windows-2022" } - #- { tag: "windows-2019_msvc16", name: "Windows Server 2019 MSVC 16", cxx: "", cc: "", runs-on: "windows-2019" } + - { tag: "windows-2022_msvc17", name: "Windows Server 2022 MSVC 17", cxx: "", cc: "", runs-on: "windows-2022" } + - { tag: "windows-2019_msvc16", name: "Windows Server 2019 MSVC 16", cxx: "", cc: "", runs-on: "windows-2019" } - { tag: "macos-12_gcc-12", name: "MacOS 12 G++ 12", cxx: "g++-12", cc: "gcc-12", runs-on: "macos-12" } #- { tag: "macos-12_gcc-13", name: "MacOS 12 G++ 13", cxx: "g++-13", cc: "gcc-13", runs-on: "macos-12" } - { tag: "macos-12_gcc-14", name: "MacOS 12 G++ 14", cxx: "g++-14", cc: "gcc-14", runs-on: "macos-12" } diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b5a12d..e6a04e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,13 +34,18 @@ add_library( ${CMAKE_CURRENT_SOURCE_DIR}/src/address.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/dns.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/file.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine_generic_unix.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine_select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine_uring.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine_iocp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/io_engine.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/io_service.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/socket.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tls.cpp) target_link_libraries(asyncpp_io PUBLIC asyncpp OpenSSL::SSL Threads::Threads) +if(WIN32) + target_link_libraries(asyncpp_io PUBLIC wsock32 ws2_32 ntdll) +endif() target_include_directories(asyncpp_io PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_compile_features(asyncpp_io PUBLIC cxx_std_20) diff --git a/include/asyncpp/io/address.h b/include/asyncpp/io/address.h index 91af18f..8a5aeda 100644 --- a/include/asyncpp/io/address.h +++ b/include/asyncpp/io/address.h @@ -94,7 +94,7 @@ namespace asyncpp::io { constexpr auto parse_part = [](std::string_view::const_iterator& it, std::string_view::const_iterator end) { if (it == end || (*it < '0' && *it > '9')) return -1; int32_t result = 0; - while (*it >= '0' && *it <= '9') { + while (it != end && *it >= '0' && *it <= '9') { result = (result * 10) + (*it - '0'); it++; } @@ -157,7 +157,7 @@ namespace asyncpp::io { constexpr std::span data() const noexcept { return m_data; } constexpr std::span ipv4_data() const noexcept { - return std::span{&m_data[12], &m_data[16]}; + return std::span{&m_data[12], &m_data[12] + 4}; } constexpr uint64_t subnet_prefix() const noexcept { @@ -189,7 +189,7 @@ namespace asyncpp::io { } constexpr ipv4_address mapped_ipv4() const noexcept { if (!is_ipv4_mapped()) return ipv4_address(); - return ipv4_address(std::span(&m_data[12], &m_data[16])); + return ipv4_address(std::span(&m_data[12], &m_data[12] + 4)); } std::string to_string(bool full = false) const { @@ -249,7 +249,7 @@ namespace asyncpp::io { auto it = str.begin(); auto part_start = it; bool is_v4_interop = false; - if (*it == ':') { + if (it != str.end() && *it == ':') { dcidx = idx++; it++; if (it == str.end() || *it != ':') return std::nullopt; @@ -508,7 +508,7 @@ namespace std { template<> struct hash { size_t operator()(const asyncpp::io::uds_address& x) const noexcept { - size_t res = 0; + size_t res{}; for (auto e : x.data()) res = res ^ (e + 0x9e3779b99e3779b9ull + (res << 6) + (res >> 2)); return res; @@ -518,7 +518,7 @@ namespace std { template<> struct hash { size_t operator()(const asyncpp::io::address& x) const noexcept { - size_t res; + size_t res{}; switch (x.type()) { case asyncpp::io::address_type::ipv4: res = std::hash{}(x.ipv4()); break; case asyncpp::io::address_type::ipv6: res = std::hash{}(x.ipv6()); break; diff --git a/include/asyncpp/io/detail/cancel_awaitable.h b/include/asyncpp/io/detail/cancel_awaitable.h index 048380d..d506a4d 100644 --- a/include/asyncpp/io/detail/cancel_awaitable.h +++ b/include/asyncpp/io/detail/cancel_awaitable.h @@ -27,7 +27,7 @@ namespace asyncpp::io::detail { bool await_ready() const noexcept { return m_child.await_ready(); } bool await_suspend(coroutine_handle<> hdl) { if (m_stop_token.stop_requested()) { - m_child.m_completion.result = -ECANCELED; + m_child.m_completion.result = std::make_error_code(std::errc::operation_canceled); return false; } auto res = m_child.await_suspend(hdl); diff --git a/include/asyncpp/io/detail/io_engine.h b/include/asyncpp/io/detail/io_engine.h index 5c7e9a3..eec102c 100644 --- a/include/asyncpp/io/detail/io_engine.h +++ b/include/asyncpp/io/detail/io_engine.h @@ -1,27 +1,62 @@ #pragma once #include +#include +#include #include +#include namespace asyncpp::io::detail { class io_engine { public: +#ifndef _WIN32 using file_handle_t = int; constexpr static file_handle_t invalid_file_handle = -1; using socket_handle_t = int; constexpr static socket_handle_t invalid_socket_handle = -1; +#else + using file_handle_t = void*; + constexpr static file_handle_t invalid_file_handle = reinterpret_cast(static_cast(-1)); + using socket_handle_t = unsigned long long; + constexpr static socket_handle_t invalid_socket_handle = ~static_cast(0); + +#endif enum class fsync_flags { none, datasync }; + enum class socket_type { stream, dgram, seqpacket }; struct completion_data { + completion_data(void (*cb)(void*) = nullptr, void* udata = nullptr) noexcept + : callback(cb), userdata(udata) {} + + // Private data the engine can use to associate state + alignas(std::max_align_t) std::array engine_state{}; + // Info provided by caller - void (*callback)(void*); - void* userdata; + void (*callback)(void*){}; + void* userdata{}; // Filled by io_engine - int result; + std::error_code result{}; + union { + socket_handle_t result_handle{}; + size_t result_size; + }; - // Private data the engine can use to associate state - void* engine_state{}; + template + T* es_init() noexcept { + static_assert(std::is_standard_layout_v && std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + static_assert(sizeof(T) <= std::tuple_size_v); + engine_state.fill(std::byte{}); + return new (engine_state.data()) T(); + } + template + T* es_get() noexcept { + static_assert(std::is_standard_layout_v && std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + static_assert(sizeof(T) <= std::tuple_size_v); + return reinterpret_cast(engine_state.data()); + } }; public: @@ -33,6 +68,18 @@ namespace asyncpp::io::detail { virtual void wake() = 0; // Networking api + virtual socket_handle_t socket_create(address_type domain, socket_type type) = 0; + virtual std::pair socket_create_connected_pair(address_type domain, + socket_type type) = 0; + virtual void socket_register(socket_handle_t socket) = 0; + virtual void socket_release(socket_handle_t socket) = 0; + virtual void socket_close(socket_handle_t socket) = 0; + virtual void socket_bind(socket_handle_t socket, endpoint ep) = 0; + virtual void socket_listen(socket_handle_t socket, size_t backlog) = 0; + virtual endpoint socket_local_endpoint(socket_handle_t socket) = 0; + virtual endpoint socket_remote_endpoint(socket_handle_t socket) = 0; + virtual void socket_enable_broadcast(socket_handle_t socket, bool enable) = 0; + virtual void socket_shutdown(socket_handle_t socket, bool receive, bool send) = 0; virtual bool enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) = 0; virtual bool enqueue_accept(socket_handle_t socket, completion_data* cd) = 0; virtual bool enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) = 0; @@ -43,8 +90,13 @@ namespace asyncpp::io::detail { completion_data* cd) = 0; // Filesystem IO - virtual bool enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) = 0; - virtual bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, + virtual file_handle_t file_open(const char* filename, std::ios_base::openmode mode) = 0; + virtual void file_register(file_handle_t fd) = 0; + virtual void file_release(file_handle_t fd) = 0; + virtual void file_close(file_handle_t fd) = 0; + virtual uint64_t file_size(file_handle_t fd) = 0; + virtual bool enqueue_readv(file_handle_t fd, void* buf, size_t len, uint64_t offset, completion_data* cd) = 0; + virtual bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, uint64_t offset, completion_data* cd) = 0; virtual bool enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) = 0; diff --git a/include/asyncpp/io/endpoint.h b/include/asyncpp/io/endpoint.h index 70dea3c..7887b67 100644 --- a/include/asyncpp/io/endpoint.h +++ b/include/asyncpp/io/endpoint.h @@ -108,10 +108,12 @@ namespace asyncpp::io { m_ipv6 = {addr.ipv6(), port}; m_type = address_type::ipv6; break; +#ifndef _WIN32 case address_type::uds: m_uds = addr.uds(); m_type = address_type::uds; break; +#endif } } explicit constexpr endpoint(ipv4_endpoint ep) noexcept : m_ipv4(ep), m_type(address_type::ipv4) {} @@ -128,15 +130,19 @@ namespace asyncpp::io { constexpr ipv4_endpoint ipv4() const noexcept { switch (m_type) { case address_type::ipv4: return m_ipv4; - case address_type::ipv6: + case address_type::ipv6: return {}; +#ifndef _WIN32 case address_type::uds: return {}; +#endif } } constexpr ipv6_endpoint ipv6() const noexcept { switch (m_type) { case address_type::ipv4: return {}; case address_type::ipv6: return m_ipv6; +#ifndef _WIN32 case address_type::uds: return {}; +#endif } } #ifndef _WIN32 @@ -213,7 +219,7 @@ namespace std { template<> struct hash { size_t operator()(const asyncpp::io::endpoint& x) const noexcept { - size_t res; + size_t res{}; switch (x.type()) { case asyncpp::io::address_type::ipv4: res = std::hash{}(x.ipv4()); break; case asyncpp::io::address_type::ipv6: res = std::hash{}(x.ipv6()); break; diff --git a/include/asyncpp/io/file.h b/include/asyncpp/io/file.h index 91117a8..2edccf7 100644 --- a/include/asyncpp/io/file.h +++ b/include/asyncpp/io/file.h @@ -102,8 +102,8 @@ namespace asyncpp::io { detail::io_engine::completion_data m_completion; public: - constexpr file_read_awaitable(io_engine* engine, io_engine::file_handle_t fd, void* buf, size_t len, - uint64_t offset, std::error_code* ec) noexcept + file_read_awaitable(io_engine* engine, io_engine::file_handle_t fd, void* buf, size_t len, uint64_t offset, + std::error_code* ec) noexcept : m_engine(engine), m_fd(fd), m_buf(buf), m_len(len), m_offset(offset), m_ec(ec), m_completion{} {} bool await_ready() const noexcept { return false; } bool await_suspend(coroutine_handle<> hdl) { @@ -112,10 +112,9 @@ namespace asyncpp::io { return !m_engine->enqueue_readv(m_fd, m_buf, m_len, m_offset, &m_completion); } size_t await_resume() { - if (m_completion.result >= 0) return static_cast(m_completion.result); - if (m_ec == nullptr) - throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return m_completion.result_size; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; return 0; } }; @@ -140,8 +139,8 @@ namespace asyncpp::io { detail::io_engine::completion_data m_completion; public: - constexpr file_write_awaitable(io_engine* engine, io_engine::file_handle_t fd, const void* buf, size_t len, - uint64_t offset, std::error_code* ec) noexcept + file_write_awaitable(io_engine* engine, io_engine::file_handle_t fd, const void* buf, size_t len, + uint64_t offset, std::error_code* ec) noexcept : m_engine(engine), m_fd(fd), m_buf(buf), m_len(len), m_offset(offset), m_ec(ec), m_completion{} {} bool await_ready() const noexcept { return false; } bool await_suspend(coroutine_handle<> hdl) { @@ -150,10 +149,9 @@ namespace asyncpp::io { return !m_engine->enqueue_writev(m_fd, m_buf, m_len, m_offset, &m_completion); } size_t await_resume() { - if (m_completion.result >= 0) return static_cast(m_completion.result); - if (m_ec == nullptr) - throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return m_completion.result_size; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; return 0; } }; @@ -175,7 +173,7 @@ namespace asyncpp::io { detail::io_engine::completion_data m_completion; public: - constexpr file_fsync_awaitable(io_engine* engine, io_engine::file_handle_t fd, std::error_code* ec) noexcept + file_fsync_awaitable(io_engine* engine, io_engine::file_handle_t fd, std::error_code* ec) noexcept : m_engine(engine), m_fd(fd), m_ec(ec), m_completion{} {} bool await_ready() const noexcept { return false; } bool await_suspend(coroutine_handle<> hdl) { @@ -184,10 +182,9 @@ namespace asyncpp::io { return !m_engine->enqueue_fsync(m_fd, io_engine::fsync_flags::none, &m_completion); } void await_resume() { - if (m_completion.result >= 0) return; - if (m_ec == nullptr) - throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; } }; } // namespace detail @@ -269,7 +266,7 @@ namespace asyncpp::io { file(const file&) = delete; file(file&&) noexcept; file& operator=(const file&) = delete; - file& operator=(file&&); + file& operator=(file&&) noexcept; ~file(); [[nodiscard]] io_service& service() const noexcept { return *m_io; } diff --git a/include/asyncpp/io/socket.h b/include/asyncpp/io/socket.h index 45079e8..a59c4ec 100644 --- a/include/asyncpp/io/socket.h +++ b/include/asyncpp/io/socket.h @@ -28,7 +28,7 @@ namespace asyncpp::io { detail::io_engine::completion_data m_completion; public: - constexpr socket_awaitable_base(socket& sock) noexcept : m_socket{sock}, m_completion{} {} + socket_awaitable_base(socket& sock) noexcept : m_socket{sock}, m_completion{} {} bool await_ready() const noexcept { return false; } }; } // namespace detail @@ -76,10 +76,8 @@ namespace asyncpp::io { [[nodiscard]] static socket create_and_bind_tcp(io_service& io, const endpoint& ep); [[nodiscard]] static socket create_and_bind_udp(io_service& io, const endpoint& ep); [[nodiscard]] static socket from_fd(io_service& io, detail::io_engine::socket_handle_t fd); -#ifndef __WIN32 [[nodiscard]] static std::pair connected_pair_tcp(io_service& io, address_type addrtype); [[nodiscard]] static std::pair connected_pair_udp(io_service& io, address_type addrtype); -#endif constexpr socket() noexcept = default; socket(socket&& other) noexcept; @@ -102,32 +100,32 @@ namespace asyncpp::io { [[nodiscard]] detail::io_engine::socket_handle_t native_handle() const noexcept { return m_fd; } [[nodiscard]] detail::io_engine::socket_handle_t release() noexcept { + if (m_io != nullptr && m_fd != detail::io_engine::invalid_socket_handle) + m_io->engine()->socket_release(m_fd); m_io = nullptr; m_remote_ep = {}; m_local_ep = {}; - return std::exchange(m_fd, -1); + return std::exchange(m_fd, detail::io_engine::invalid_socket_handle); } - [[nodiscard]] constexpr socket_connect_awaitable connect(const endpoint& ep) noexcept; - [[nodiscard]] constexpr socket_connect_awaitable connect(const endpoint& ep, std::error_code& ec) noexcept; - [[nodiscard]] constexpr socket_accept_awaitable accept() noexcept; - [[nodiscard]] constexpr socket_accept_error_code_awaitable accept(std::error_code& ec) noexcept; - [[nodiscard]] constexpr socket_send_awaitable send(const void* buffer, std::size_t size) noexcept; - [[nodiscard]] constexpr socket_send_awaitable send(const void* buffer, std::size_t size, - std::error_code& ec) noexcept; - [[nodiscard]] constexpr socket_recv_awaitable recv(void* buffer, std::size_t size) noexcept; - [[nodiscard]] constexpr socket_recv_awaitable recv(void* buffer, std::size_t size, + [[nodiscard]] socket_connect_awaitable connect(const endpoint& ep) noexcept; + [[nodiscard]] socket_connect_awaitable connect(const endpoint& ep, std::error_code& ec) noexcept; + [[nodiscard]] socket_accept_awaitable accept() noexcept; + [[nodiscard]] socket_accept_error_code_awaitable accept(std::error_code& ec) noexcept; + [[nodiscard]] socket_send_awaitable send(const void* buffer, std::size_t size) noexcept; + [[nodiscard]] socket_send_awaitable send(const void* buffer, std::size_t size, std::error_code& ec) noexcept; + [[nodiscard]] socket_recv_awaitable recv(void* buffer, std::size_t size) noexcept; + [[nodiscard]] socket_recv_awaitable recv(void* buffer, std::size_t size, std::error_code& ec) noexcept; + [[nodiscard]] socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size) noexcept; + [[nodiscard]] socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size, + std::error_code& ec) noexcept; + [[nodiscard]] socket_send_to_awaitable send_to(const void* buffer, std::size_t size, + const endpoint& dst_ep) noexcept; + [[nodiscard]] socket_send_to_awaitable send_to(const void* buffer, std::size_t size, const endpoint& dst_ep, + std::error_code& ec) noexcept; + [[nodiscard]] socket_recv_from_awaitable recv_from(void* buffer, std::size_t size) noexcept; + [[nodiscard]] socket_recv_from_awaitable recv_from(void* buffer, std::size_t size, std::error_code& ec) noexcept; - [[nodiscard]] constexpr socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size) noexcept; - [[nodiscard]] constexpr socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size, - std::error_code& ec) noexcept; - [[nodiscard]] constexpr socket_send_to_awaitable send_to(const void* buffer, std::size_t size, - const endpoint& dst_ep) noexcept; - [[nodiscard]] constexpr socket_send_to_awaitable send_to(const void* buffer, std::size_t size, - const endpoint& dst_ep, std::error_code& ec) noexcept; - [[nodiscard]] constexpr socket_recv_from_awaitable recv_from(void* buffer, std::size_t size) noexcept; - [[nodiscard]] constexpr socket_recv_from_awaitable recv_from(void* buffer, std::size_t size, - std::error_code& ec) noexcept; [[nodiscard]] socket_connect_cancellable_awaitable connect(const endpoint& ep, asyncpp::stop_token st) noexcept; [[nodiscard]] socket_connect_cancellable_awaitable connect(const endpoint& ep, asyncpp::stop_token st, @@ -204,7 +202,7 @@ namespace asyncpp::io { std::error_code* const m_ec; public: - constexpr socket_connect_awaitable(socket& sock, endpoint ep, std::error_code* ec = nullptr) noexcept + socket_connect_awaitable(socket& sock, endpoint ep, std::error_code* ec = nullptr) noexcept : socket_awaitable_base{sock}, m_ep{ep}, m_ec{ec} {} bool await_suspend(coroutine_handle<> hdl); void await_resume(); @@ -246,8 +244,8 @@ namespace asyncpp::io { std::error_code* const m_ec; public: - constexpr socket_send_awaitable(socket& sock, const void* buffer, std::size_t size, - std::error_code* ec = nullptr) noexcept + socket_send_awaitable(socket& sock, const void* buffer, std::size_t size, + std::error_code* ec = nullptr) noexcept : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_ec{ec} {} bool await_suspend(coroutine_handle<> hdl); void await_resume(); @@ -259,8 +257,7 @@ namespace asyncpp::io { std::error_code* const m_ec; public: - constexpr socket_recv_awaitable(socket& sock, void* buffer, std::size_t size, - std::error_code* ec = nullptr) noexcept + socket_recv_awaitable(socket& sock, void* buffer, std::size_t size, std::error_code* ec = nullptr) noexcept : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_ec{ec} {} bool await_suspend(coroutine_handle<> hdl); size_t await_resume(); @@ -274,8 +271,8 @@ namespace asyncpp::io { std::error_code* const m_ec; public: - constexpr socket_recv_exact_awaitable(asyncpp::io::socket& sock, void* buffer, std::size_t size, - std::error_code* ec = nullptr) noexcept + socket_recv_exact_awaitable(asyncpp::io::socket& sock, void* buffer, std::size_t size, + std::error_code* ec = nullptr) noexcept : socket_awaitable_base{sock}, m_buffer{static_cast(buffer)}, m_size{size}, m_remaining{size}, m_ec{ec} {} bool await_suspend(asyncpp::coroutine_handle<> hdl); @@ -284,7 +281,7 @@ namespace asyncpp::io { class socket_accept_awaitable : public detail::socket_awaitable_base { public: - constexpr socket_accept_awaitable(socket& sock) noexcept : socket_awaitable_base{sock} {} + socket_accept_awaitable(socket& sock) noexcept : socket_awaitable_base{sock} {} bool await_suspend(coroutine_handle<> hdl); socket await_resume(); }; @@ -293,7 +290,7 @@ namespace asyncpp::io { std::error_code& m_ec; public: - constexpr socket_accept_error_code_awaitable(socket& sock, std::error_code& ec) noexcept + socket_accept_error_code_awaitable(socket& sock, std::error_code& ec) noexcept : socket_awaitable_base{sock}, m_ec{ec} {} bool await_suspend(coroutine_handle<> hdl); std::optional await_resume(); @@ -306,8 +303,8 @@ namespace asyncpp::io { std::error_code* const m_ec; public: - constexpr socket_send_to_awaitable(socket& sock, const void* buffer, std::size_t size, endpoint dst, - std::error_code* ec = nullptr) noexcept + socket_send_to_awaitable(socket& sock, const void* buffer, std::size_t size, endpoint dst, + std::error_code* ec = nullptr) noexcept : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_destination{dst}, m_ec{ec} {} bool await_suspend(coroutine_handle<> hdl); size_t await_resume(); @@ -320,75 +317,69 @@ namespace asyncpp::io { std::error_code* const m_ec; public: - constexpr socket_recv_from_awaitable(socket& sock, void* buffer, std::size_t size, - std::error_code* ec = nullptr) noexcept + socket_recv_from_awaitable(socket& sock, void* buffer, std::size_t size, std::error_code* ec = nullptr) noexcept : socket_awaitable_base{sock}, m_buffer{buffer}, m_size{size}, m_ec{ec} {} bool await_suspend(coroutine_handle<> hdl); std::pair await_resume(); }; - [[nodiscard]] inline constexpr socket_connect_awaitable socket::connect(const endpoint& ep) noexcept { + [[nodiscard]] inline socket_connect_awaitable socket::connect(const endpoint& ep) noexcept { return socket_connect_awaitable(*this, ep); } - [[nodiscard]] inline constexpr socket_connect_awaitable socket::connect(const endpoint& ep, - std::error_code& ec) noexcept { + [[nodiscard]] inline socket_connect_awaitable socket::connect(const endpoint& ep, std::error_code& ec) noexcept { return socket_connect_awaitable(*this, ep, &ec); } - [[nodiscard]] inline constexpr socket_accept_awaitable socket::accept() noexcept { - return socket_accept_awaitable(*this); - } + [[nodiscard]] inline socket_accept_awaitable socket::accept() noexcept { return socket_accept_awaitable(*this); } - [[nodiscard]] inline constexpr socket_accept_error_code_awaitable socket::accept(std::error_code& ec) noexcept { + [[nodiscard]] inline socket_accept_error_code_awaitable socket::accept(std::error_code& ec) noexcept { return socket_accept_error_code_awaitable(*this, ec); } - [[nodiscard]] inline constexpr socket_send_awaitable socket::send(const void* buffer, std::size_t size) noexcept { + [[nodiscard]] inline socket_send_awaitable socket::send(const void* buffer, std::size_t size) noexcept { return socket_send_awaitable(*this, buffer, size); } - [[nodiscard]] inline constexpr socket_send_awaitable socket::send(const void* buffer, std::size_t size, - std::error_code& ec) noexcept { + [[nodiscard]] inline socket_send_awaitable socket::send(const void* buffer, std::size_t size, + std::error_code& ec) noexcept { return socket_send_awaitable(*this, buffer, size, &ec); } - [[nodiscard]] inline constexpr socket_recv_awaitable socket::recv(void* buffer, std::size_t size) noexcept { + [[nodiscard]] inline socket_recv_awaitable socket::recv(void* buffer, std::size_t size) noexcept { return socket_recv_awaitable(*this, buffer, size); } - [[nodiscard]] inline constexpr socket_recv_awaitable socket::recv(void* buffer, std::size_t size, - std::error_code& ec) noexcept { + [[nodiscard]] inline socket_recv_awaitable socket::recv(void* buffer, std::size_t size, + std::error_code& ec) noexcept { return socket_recv_awaitable(*this, buffer, size, &ec); } - [[nodiscard]] inline constexpr socket_recv_exact_awaitable socket::recv_exact(void* buffer, - std::size_t size) noexcept { + [[nodiscard]] inline socket_recv_exact_awaitable socket::recv_exact(void* buffer, std::size_t size) noexcept { return socket_recv_exact_awaitable(*this, buffer, size); } - [[nodiscard]] inline constexpr socket_recv_exact_awaitable socket::recv_exact(void* buffer, std::size_t size, - std::error_code& ec) noexcept { + [[nodiscard]] inline socket_recv_exact_awaitable socket::recv_exact(void* buffer, std::size_t size, + std::error_code& ec) noexcept { return socket_recv_exact_awaitable(*this, buffer, size, &ec); } - [[nodiscard]] inline constexpr socket_send_to_awaitable socket::send_to(const void* buffer, std::size_t size, - const endpoint& dst_ep) noexcept { + [[nodiscard]] inline socket_send_to_awaitable socket::send_to(const void* buffer, std::size_t size, + const endpoint& dst_ep) noexcept { return socket_send_to_awaitable(*this, buffer, size, dst_ep); } - [[nodiscard]] inline constexpr socket_send_to_awaitable + [[nodiscard]] inline socket_send_to_awaitable socket::send_to(const void* buffer, std::size_t size, const endpoint& dst_ep, std::error_code& ec) noexcept { return socket_send_to_awaitable(*this, buffer, size, dst_ep, &ec); } - [[nodiscard]] inline constexpr socket_recv_from_awaitable socket::recv_from(void* buffer, - std::size_t size) noexcept { + [[nodiscard]] inline socket_recv_from_awaitable socket::recv_from(void* buffer, std::size_t size) noexcept { return socket_recv_from_awaitable(*this, buffer, size); } - [[nodiscard]] inline constexpr socket_recv_from_awaitable socket::recv_from(void* buffer, std::size_t size, - std::error_code& ec) noexcept { + [[nodiscard]] inline socket_recv_from_awaitable socket::recv_from(void* buffer, std::size_t size, + std::error_code& ec) noexcept { return socket_recv_from_awaitable(*this, buffer, size, &ec); } @@ -470,9 +461,9 @@ namespace asyncpp::io { } inline void socket_connect_awaitable::await_resume() { - if (m_completion.result >= 0) return; - if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; } inline bool socket_send_awaitable::await_suspend(coroutine_handle<> hdl) { @@ -482,9 +473,9 @@ namespace asyncpp::io { } inline void socket_send_awaitable::await_resume() { - if (m_completion.result >= 0) return; - if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; } inline bool socket_recv_awaitable::await_suspend(coroutine_handle<> hdl) { @@ -494,9 +485,9 @@ namespace asyncpp::io { } inline size_t socket_recv_awaitable::await_resume() { - if (m_completion.result >= 0) return static_cast(m_completion.result); - if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return m_completion.result_size; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; return 0; } @@ -505,12 +496,12 @@ namespace asyncpp::io { auto that = static_cast(ptr); auto engine = that->m_socket.service().engine(); do { - if (that->m_completion.result <= 0) { + if (that->m_completion.result) { that->m_handle.resume(); break; } - that->m_buffer += that->m_completion.result; - that->m_remaining -= that->m_completion.result; + that->m_buffer += that->m_completion.result_size; + that->m_remaining -= that->m_completion.result_size; if (that->m_remaining == 0) { that->m_handle.resume(); break; @@ -522,18 +513,18 @@ namespace asyncpp::io { m_handle = hdl; auto engine = m_socket.service().engine(); while (engine->enqueue_recv(m_socket.native_handle(), m_buffer, m_remaining, &m_completion)) { - if (m_completion.result <= 0) return false; - m_buffer += m_completion.result; - m_remaining -= m_completion.result; + if (m_completion.result) return false; + m_buffer += m_completion.result_size; + m_remaining -= m_completion.result_size; if (m_remaining == 0) return false; } return true; } inline size_t socket_recv_exact_awaitable::await_resume() { - if (m_completion.result >= 0) return m_size - m_remaining; - if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return m_size - m_remaining; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; return m_size - m_remaining; } @@ -544,9 +535,8 @@ namespace asyncpp::io { } inline socket socket_accept_awaitable::await_resume() { - if (m_completion.result < 0) - throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - return socket::from_fd(m_socket.service(), m_completion.result); + if (!m_completion.result) return socket::from_fd(m_socket.service(), m_completion.result_handle); + throw std::system_error(m_completion.result); } inline bool socket_accept_error_code_awaitable::await_suspend(coroutine_handle<> hdl) { @@ -556,8 +546,8 @@ namespace asyncpp::io { } inline std::optional socket_accept_error_code_awaitable::await_resume() { - if (m_completion.result >= 0) return socket::from_fd(m_socket.service(), m_completion.result); - m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return socket::from_fd(m_socket.service(), m_completion.result_handle); + m_ec = m_completion.result; return std::nullopt; } @@ -569,9 +559,9 @@ namespace asyncpp::io { } inline size_t socket_send_to_awaitable::await_resume() { - if (m_completion.result >= 0) return static_cast(m_completion.result); - if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return m_completion.result_size; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; return 0; } @@ -583,9 +573,9 @@ namespace asyncpp::io { } inline std::pair socket_recv_from_awaitable::await_resume() { - if (m_completion.result >= 0) return {static_cast(m_completion.result), m_source}; - if (m_ec == nullptr) throw std::system_error(std::error_code(-m_completion.result, std::system_category())); - *m_ec = std::error_code(-m_completion.result, std::system_category()); + if (!m_completion.result) return {m_completion.result_size, m_source}; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; return {}; } @@ -602,8 +592,7 @@ namespace asyncpp::io { static void handle(void* ptr) { auto that = static_cast(ptr); - that->real_cb(that->result < 0 ? std::error_code(-that->result, std::system_category()) - : std::error_code()); + that->real_cb(that->result); delete that; }; }; @@ -625,10 +614,10 @@ namespace asyncpp::io { static void handle(void* ptr) { auto that = static_cast(ptr); - if (that->result < 0) - that->real_cb(std::error_code(-that->result, std::system_category())); + if (that->result) + that->real_cb(that->result); else - that->real_cb(socket::from_fd(that->service(), that->result)); + that->real_cb(socket::from_fd(that->service(), that->result_handle)); delete that; }; @@ -650,10 +639,10 @@ namespace asyncpp::io { static void handle(void* ptr) { auto that = static_cast(ptr); - if (that->result < 0) - that->real_cb(0, std::error_code(-that->result, std::system_category())); + if (that->result) + that->real_cb(0, that->result); else - that->real_cb(that->result, {}); + that->real_cb(that->result_size, {}); delete that; }; @@ -675,10 +664,10 @@ namespace asyncpp::io { static void handle(void* ptr) { auto that = static_cast(ptr); - if (that->result < 0) - that->real_cb(0, std::error_code(-that->result, std::system_category())); + if (that->result) + that->real_cb(0, that->result); else - that->real_cb(that->result, {}); + that->real_cb(that->result_size, {}); delete that; }; @@ -701,10 +690,10 @@ namespace asyncpp::io { static void handle(void* ptr) { auto that = static_cast(ptr); - if (that->result < 0) - that->real_cb(0, std::error_code(-that->result, std::system_category())); + if (that->result) + that->real_cb(0, that->result); else - that->real_cb(that->result, {}); + that->real_cb(that->result_size, {}); delete that; }; @@ -727,10 +716,10 @@ namespace asyncpp::io { static void handle(void* ptr) { auto that = static_cast(ptr); - if (that->result < 0) - that->real_cb(0, {}, std::error_code(-that->result, std::system_category())); + if (that->result) + that->real_cb(0, {}, that->result); else - that->real_cb(that->result, that->source, {}); + that->real_cb(that->result_size, that->source, {}); delete that; }; diff --git a/src/address.cpp b/src/address.cpp index 289f148..a1a4ee0 100644 --- a/src/address.cpp +++ b/src/address.cpp @@ -170,7 +170,9 @@ namespace asyncpp::io { switch (m_type) { case address_type::ipv4: return m_ipv4.to_sockaddr(); case address_type::ipv6: return m_ipv6.to_sockaddr(); +#ifndef _WIN32 case address_type::uds: return m_uds.to_sockaddr(); +#endif } return {}; } diff --git a/src/block_allocator.h b/src/block_allocator.h deleted file mode 100644 index d57901c..0000000 --- a/src/block_allocator.h +++ /dev/null @@ -1,106 +0,0 @@ -#pragma once -#include - -#include -#include -#include -#include -#include -#include - -namespace asyncpp::io::detail { - - template - class block_allocator { - - struct page { - page* next_page{}; - uint64_t usage{}; - alignas(T) std::array storage{}; - }; - - TMutex m_mtx{}; - page* m_first_page{}; - - public: - constexpr block_allocator() noexcept = default; - block_allocator(const block_allocator&) = delete; - block_allocator& operator=(const block_allocator&) = delete; - ~block_allocator() noexcept { - auto p = m_first_page; - while (p != nullptr) { - auto ptr = p; - p = p->next_page; - assert(ptr->usage == 0); - delete ptr; - } - } - void* allocate() noexcept { - std::unique_lock lck{m_mtx}; - page* p = m_first_page; - page** page_ptr = &m_first_page; - while (p != nullptr) { - if (p->usage != std::numeric_limits::max()) { - auto free_block = std::countr_one(p->usage); - assert(free_block < 64 && free_block >= 0); - p->usage |= (static_cast(1) << free_block); -#if ASYNCPP_HAS_ASAN - __asan_unpoison_memory_region(p->storage.data() + sizeof(T) * free_block, sizeof(T)); -#endif - return p->storage.data() + sizeof(T) * free_block; - } - page_ptr = &p->next_page; - p = p->next_page; - } - // No free blocks left - p = *page_ptr = new (std::nothrow) page{}; - if (p == nullptr) return nullptr; - p->usage |= 1; -#if ASYNCPP_HAS_ASAN - __asan_poison_memory_region(p->storage.data() + sizeof(T), p->storage.size() - sizeof(T)); -#endif - return p->storage.data(); - } - void deallocate(void* ptr) noexcept { - std::unique_lock lck{m_mtx}; - page* p = m_first_page; - while (p != nullptr) { - if (ptr >= p->storage.data() && ptr < p->storage.data() + p->storage.size()) { -#if ASYNCPP_HAS_ASAN - __asan_poison_memory_region(ptr, sizeof(T)); -#endif - const auto offset = static_cast(ptr) - p->storage.data(); - assert(offset % sizeof(T) == 0); - assert(offset < sizeof(T) * 64); - const auto idx = offset / sizeof(T); - assert((p->usage & static_cast(1) << idx) != 0); - p->usage &= ~(static_cast(1) << idx); - return; - } - p = p->next_page; - } - } - template - T* create(Args&&... args) { - auto ptr = allocate(); - if (ptr == nullptr) return nullptr; - if constexpr (std::is_nothrow_constructible_v) { - return new (ptr) T(std::forward(args)...); - } else { - try { - return new (ptr) T(std::forward(args)...); - } catch (...) { - this->deallocate(ptr); - throw; - } - } - // unreachable - } - void destroy(T* obj) { - if (obj != nullptr) { - obj->~T(); - this->deallocate(obj); - } - } - }; -} // namespace asyncpp::io::detail diff --git a/src/dns.cpp b/src/dns.cpp index 02245b6..d0068c6 100644 --- a/src/dns.cpp +++ b/src/dns.cpp @@ -813,6 +813,9 @@ namespace asyncpp::io::dns { client::client(asyncpp::io::io_service& service) : m_socket_ipv4(socket::create_udp(service, address_type::ipv4)), m_socket_ipv6(socket::create_udp(service, address_type::ipv6)) { + // Required for windows + m_socket_ipv4.bind({ipv4_address::any(), 0}); + m_socket_ipv6.bind({ipv6_address::any(), 0}); launch([](client* that) -> task<> { auto token = that->m_stop.get_token(); std::array buf; diff --git a/src/file.cpp b/src/file.cpp index 659c448..7a9a6f9 100644 --- a/src/file.cpp +++ b/src/file.cpp @@ -1,14 +1,5 @@ #include -#ifndef _WIN32 -#include -#include -#include -#include -#else -#include -#endif - namespace asyncpp::io { file::file(io_service& io) : m_io(&io), m_fd(detail::io_engine::invalid_file_handle) {} file::file(io_service& io, detail::io_engine::file_handle_t fd) : m_io(&io), m_fd(fd) {} @@ -22,7 +13,7 @@ namespace asyncpp::io { file::file(file&& other) noexcept : m_io(std::exchange(other.m_io, nullptr)), m_fd(std::exchange(other.m_fd, detail::io_engine::invalid_file_handle)) {} - file& file::operator=(file&& other) { + file& file::operator=(file&& other) noexcept { close(); m_io = std::exchange(other.m_io, nullptr); m_fd = std::exchange(other.m_fd, detail::io_engine::invalid_file_handle); @@ -31,45 +22,20 @@ namespace asyncpp::io { file::~file() { close(); } void file::open(const char* filename, std::ios_base::openmode mode) { -#ifndef _WIN32 - if ((mode & std::ios_base::ate) == std::ios_base::ate) throw std::logic_error("unsupported flag"); - int m = 0; - if ((mode & std::ios_base::app) == std::ios_base::app) m |= O_APPEND; - if ((mode & std::ios_base::in) == std::ios_base::in) - m |= ((mode & std::ios_base::out) == std::ios_base::out) ? O_RDWR : O_RDONLY; - else if ((mode & std::ios_base::out) == std::ios_base::out) - m |= O_WRONLY; - else - throw std::invalid_argument("neither std::ios::in, nor std::ios::out was specified"); - if ((mode & std::ios_base::trunc) == std::ios_base::trunc) m |= O_TRUNC; - auto res = ::open(filename, m, 0660); - if (res < 0) throw std::system_error(errno, std::system_category()); -#else - DWORD access_mode = 0; - if ((mode & std::ios_base::in) == std::ios_base::in) access_mode |= GENERIC_READ; - if ((mode & std::ios_base::out) == std::ios_base::out) access_mode |= GENERIC_WRITE; - if ((mode & (std::ios_base::in | std::ios_base::out)) == 0) - throw std::invalid_argument("neither std::ios::in, nor std::ios::out was specified"); - HANDLE h = CreateFileA(filename, access_mode, 0, NULL, CREATE_NEW, FILE_ATTRIBUTE_NORMAL, NULL); - // TODO: Remaining code -#endif + auto res = m_io->engine()->file_open(filename, mode); close(); m_fd = res; } void file::open(const std::string& filename, std::ios_base::openmode mode) { return open(filename.c_str(), mode); } void file::open(const std::filesystem::path& filename, std::ios_base::openmode mode) { - return open(filename.c_str(), mode); + return open(filename.string().c_str(), mode); } bool file::is_open() const noexcept { return m_io != nullptr && m_fd != detail::io_engine::invalid_file_handle; } void file::close() { if (m_fd != detail::io_engine::invalid_file_handle) { -#ifndef _WIN32 - ::close(m_fd); -#else - ::CloseHandle(m_fd); -#endif + m_io->engine()->file_close(m_fd); m_fd = detail::io_engine::invalid_file_handle; } } @@ -79,18 +45,5 @@ namespace asyncpp::io { std::swap(m_fd, other.m_fd); } - uint64_t file::size() { -#ifdef __APPLE__ - struct stat info {}; - auto res = fstat(m_fd, &info); -#elif defined(_WIN32) - struct _stat64 info {}; - auto res = _fstat64(m_fd, &info); -#else - struct stat64 info {}; - auto res = fstat64(m_fd, &info); -#endif - if (res < 0) throw std::system_error(errno, std::system_category()); - return info.st_size; - } + uint64_t file::size() { return m_io->engine()->file_size(m_fd); } } // namespace asyncpp::io diff --git a/src/io_engine.cpp b/src/io_engine.cpp index d2d3516..c853e5c 100644 --- a/src/io_engine.cpp +++ b/src/io_engine.cpp @@ -3,10 +3,12 @@ #include namespace asyncpp::io::detail { - // Select is always supported + // Select is always supported on posix std::unique_ptr create_io_engine_select(); // Only supported on Linux on kernel 5.1+ std::unique_ptr create_io_engine_uring(); + // Win32 completion queue + std::unique_ptr create_io_engine_iocp(); std::unique_ptr create_io_engine() { if (const auto env = getenv("ASYNCPP_IO_ENGINE"); env != nullptr) { @@ -15,10 +17,16 @@ namespace asyncpp::io::detail { return create_io_engine_uring(); else if (engine == "select") return create_io_engine_select(); + else if (engine == "iocp") + return create_io_engine_iocp(); else if (!engine.empty()) throw std::runtime_error("unknown io engine " + std::string(engine)); } +#ifdef _WIN32 + return create_io_engine_iocp(); +#else if (auto uring = create_io_engine_uring(); uring != nullptr) return uring; return create_io_engine_select(); +#endif } } // namespace asyncpp::io::detail diff --git a/src/io_engine_generic_unix.cpp b/src/io_engine_generic_unix.cpp new file mode 100644 index 0000000..a63e266 --- /dev/null +++ b/src/io_engine_generic_unix.cpp @@ -0,0 +1,181 @@ +#ifndef _WIN32 +#include "io_engine_generic_unix.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace asyncpp::io::detail { + + io_engine::socket_handle_t io_engine_generic_unix::socket_create(address_type domain, socket_type type) { + int afdomain = -1; + switch (domain) { + case address_type::ipv4: afdomain = AF_INET; break; + case address_type::ipv6: afdomain = AF_INET6; break; + case address_type::uds: afdomain = AF_UNIX; break; + } + int stype = -1; + switch (type) { + case socket_type::stream: stype = SOCK_STREAM; break; + case socket_type::dgram: stype = SOCK_DGRAM; break; + case socket_type::seqpacket: stype = SOCK_SEQPACKET; break; + } + if (afdomain == -1) throw std::system_error(std::make_error_code(std::errc::not_supported)); + if (stype == -1) throw std::system_error(std::make_error_code(std::errc::not_supported)); +#ifdef __APPLE__ + auto fd = ::socket(afdomain, stype, 0); + if (fd < 0) throw std::system_error(errno, std::system_category(), "socket failed"); + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0 || fcntl(fd, F_SETFD, FD_CLOEXEC) < 0) { + close(fd); + throw std::system_error(errno, std::system_category(), "fcntl failed"); + } +#else + auto fd = ::socket(afdomain, stype | SOCK_CLOEXEC | SOCK_NONBLOCK, 0); + if (fd < 0) throw std::system_error(errno, std::system_category(), "socket failed"); +#endif + if (domain == address_type::ipv6) { + int opt = 0; + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) { + close(fd); + throw std::system_error(errno, std::system_category(), "setsockopt failed"); + } + } + return fd; + } + + std::pair + io_engine_generic_unix::socket_create_connected_pair(address_type domain, socket_type type) { + int afdomain = -1; + switch (domain) { + case address_type::ipv4: afdomain = AF_INET; break; + case address_type::ipv6: afdomain = AF_INET6; break; + case address_type::uds: afdomain = AF_UNIX; break; + } + int stype = -1; + switch (type) { + case socket_type::stream: stype = SOCK_STREAM; break; + case socket_type::dgram: stype = SOCK_DGRAM; break; + case socket_type::seqpacket: stype = SOCK_SEQPACKET; break; + } + if (afdomain == -1) throw std::system_error(std::make_error_code(std::errc::not_supported)); + if (stype == -1) throw std::system_error(std::make_error_code(std::errc::not_supported)); + + int socks[2]; +#ifndef __APPLE__ + if (socketpair(afdomain, stype | SOCK_CLOEXEC | SOCK_NONBLOCK, 0, socks) != 0) + throw std::system_error(errno, std::system_category(), "socket failed"); +#else + if (socketpair(afdomain, stype, 0, socks) != 0) + throw std::system_error(errno, std::system_category(), "socket failed"); + int flags0 = fcntl(socks[0], F_GETFL, 0); + int flags1 = fcntl(socks[1], F_GETFL, 0); + if (flags0 < 0 || flags1 < 0 || // + fcntl(socks[0], F_SETFL, flags0 | O_NONBLOCK) < 0 || fcntl(socks[1], F_SETFL, flags1 | O_NONBLOCK) < 0 || // + fcntl(socks[0], F_SETFD, FD_CLOEXEC) < 0 || fcntl(socks[1], F_SETFD, FD_CLOEXEC) < 0) { + close(socks[0]); + close(socks[1]); + throw std::system_error(errno, std::system_category(), "fcntl failed"); + } +#endif + return {socks[0], socks[1]}; + } + + void io_engine_generic_unix::socket_close(socket_handle_t socket) { + if (socket >= 0) close(socket); + } + + void io_engine_generic_unix::socket_bind(socket_handle_t socket, endpoint ep) { + auto sa = ep.to_sockaddr(); + auto res = ::bind(socket, reinterpret_cast(&sa.first), sa.second); + if (res < 0) throw std::system_error(errno, std::system_category(), "bind failed"); + } + + void io_engine_generic_unix::socket_listen(socket_handle_t socket, size_t backlog) { + if (backlog == 0) backlog = 20; + auto res = ::listen(socket, backlog); + if (res < 0) throw std::system_error(errno, std::system_category(), "listen failed"); + } + + endpoint io_engine_generic_unix::socket_local_endpoint(socket_handle_t socket) { + sockaddr_storage sa; + socklen_t sa_size = sizeof(sa); + auto res = getsockname(socket, reinterpret_cast(&sa), &sa_size); + if (res >= 0) return endpoint(sa, sa_size); + throw std::system_error(errno, std::system_category(), "getsockname failed"); + } + + endpoint io_engine_generic_unix::socket_remote_endpoint(socket_handle_t socket) { + sockaddr_storage sa; + socklen_t sa_size = sizeof(sa); + auto res = getpeername(socket, reinterpret_cast(&sa), &sa_size); + if (res >= 0) + return endpoint(sa, sa_size); + else if (res < 0 && errno != ENOTCONN) + throw std::system_error(errno, std::system_category(), "getpeername failed"); + return {}; + } + + void io_engine_generic_unix::socket_enable_broadcast(socket_handle_t socket, bool enable) { + int opt = enable ? 1 : 0; + auto res = setsockopt(socket, SOL_SOCKET, SO_BROADCAST, &opt, sizeof(opt)); + if (res < 0) throw std::system_error(errno, std::system_category(), "setsockopt failed"); + } + + void io_engine_generic_unix::socket_shutdown(socket_handle_t socket, bool receive, bool send) { + int mode = 0; + if (receive && send) + mode = SHUT_RDWR; + else if (receive) + mode = SHUT_RD; + else if (send) + mode = SHUT_WR; + else + return; + auto res = ::shutdown(socket, mode); + if (res < 0 && errno != ENOTCONN) throw std::system_error(errno, std::system_category(), "shutdown failed"); + } + + io_engine::file_handle_t io_engine_generic_unix::file_open(const char* filename, std::ios_base::openmode mode) { + if ((mode & std::ios_base::ate) == std::ios_base::ate) throw std::logic_error("unsupported flag"); + int m = 0; + if ((mode & std::ios_base::app) == std::ios_base::app) m |= O_APPEND; + if ((mode & std::ios_base::in) == std::ios_base::in) + m |= ((mode & std::ios_base::out) == std::ios_base::out) ? O_RDWR : O_RDONLY; + else if ((mode & std::ios_base::out) == std::ios_base::out) + m |= O_WRONLY; + else + throw std::invalid_argument("neither std::ios::in, nor std::ios::out was specified"); + if ((mode & std::ios_base::trunc) == std::ios_base::trunc) m |= O_TRUNC; + auto res = ::open(filename, m, 0660); + if (res < 0) throw std::system_error(errno, std::system_category()); + return res; + } + + void io_engine_generic_unix::file_close(file_handle_t fd) { + if (fd >= 0) ::close(fd); + } + + uint64_t io_engine_generic_unix::file_size(file_handle_t fd) { +#ifdef __APPLE__ + struct stat info {}; + auto res = fstat(fd, &info); + if (res < 0) throw std::system_error(errno, std::system_category()); + return info.st_size; +#else + struct stat64 info {}; + auto res = fstat64(fd, &info); + if (res < 0) throw std::system_error(errno, std::system_category()); + return info.st_size; +#endif + } + +} // namespace asyncpp::io::detail + +#endif diff --git a/src/io_engine_generic_unix.h b/src/io_engine_generic_unix.h new file mode 100644 index 0000000..3d4cb28 --- /dev/null +++ b/src/io_engine_generic_unix.h @@ -0,0 +1,28 @@ +#ifndef _WIN32 +#include + +namespace asyncpp::io::detail { + + class io_engine_generic_unix : public io_engine { + public: + socket_handle_t socket_create(address_type domain, socket_type type) override; + std::pair socket_create_connected_pair(address_type domain, + socket_type type) override; + void socket_close(socket_handle_t socket) override; + void socket_bind(socket_handle_t socket, endpoint ep) override; + void socket_listen(socket_handle_t socket, size_t backlog) override; + endpoint socket_local_endpoint(socket_handle_t socket) override; + endpoint socket_remote_endpoint(socket_handle_t socket) override; + void socket_enable_broadcast(socket_handle_t socket, bool enable) override; + void socket_shutdown(socket_handle_t socket, bool receive, bool send) override; + + file_handle_t file_open(const char* filename, std::ios_base::openmode mode) override; + void file_close(file_handle_t fd) override; + uint64_t file_size(file_handle_t fd) override; + + private: + }; + +} // namespace asyncpp::io::detail + +#endif diff --git a/src/io_engine_iocp.cpp b/src/io_engine_iocp.cpp new file mode 100644 index 0000000..c709f9b --- /dev/null +++ b/src/io_engine_iocp.cpp @@ -0,0 +1,644 @@ +#include + +#ifndef _WIN32 +namespace asyncpp::io::detail { + std::unique_ptr create_io_engine_iocp() { return nullptr; } +} // namespace asyncpp::io::detail +#else + +#include +#include +#include + +#include +#include +#include +#include // This needs be included before the ones below, otherwise INETADDR_SETANY breaks + +#include +#include +#include + +extern "C" { +typedef struct _IO_STATUS_BLOCK { + union { + LONG Status; + PVOID Pointer; + }; + ULONG Information; +} IO_STATUS_BLOCK, *PIO_STATUS_BLOCK; + +typedef enum _FILE_INFORMATION_CLASS { + FileReplaceCompletionInformation = 61 +} FILE_INFORMATION_CLASS, + *PFILE_INFORMATION_CLASS; + +typedef struct _FILE_COMPLETION_INFORMATION { + HANDLE Port; + PVOID Key; +} FILE_COMPLETION_INFORMATION, *PFILE_COMPLETION_INFORMATION; + +NTSYSAPI NTSTATUS NTAPI NtSetInformationFile(IN HANDLE FileHandle, OUT PIO_STATUS_BLOCK IoStatusBlock, + IN PVOID FileInformation, IN ULONG Length, + IN FILE_INFORMATION_CLASS FileInformationClass); +} + +namespace asyncpp::io::detail { + + struct iocp_engine_state { + WSAOVERLAPPED overlapped; + HANDLE handle = io_engine::invalid_file_handle; + SOCKET accept_sock = INVALID_SOCKET; + union { + std::array accept_buffer{}; + struct { + endpoint* recv_from_ep; + sockaddr_storage recv_from_sa; + int recv_from_sa_len; + }; + }; + }; + static_assert(offsetof(iocp_engine_state, overlapped) == 0, + "Code assumes that overlapped is at the start of engine_data"); + + class io_engine_iocp : public io_engine { + public: + io_engine_iocp(); + io_engine_iocp(const io_engine_iocp&) = delete; + io_engine_iocp& operator=(const io_engine_iocp&) = delete; + ~io_engine_iocp(); + + std::string_view name() const noexcept override; + + size_t run(bool nowait) override; + void wake() override; + + socket_handle_t socket_create(address_type domain, socket_type type) override; + std::pair socket_create_connected_pair(address_type domain, + socket_type type) override; + void socket_register(socket_handle_t socket) override; + void socket_release(socket_handle_t socket) override; + void socket_close(socket_handle_t socket) override; + void socket_bind(socket_handle_t socket, endpoint ep) override; + void socket_listen(socket_handle_t socket, size_t backlog) override; + endpoint socket_local_endpoint(socket_handle_t socket) override; + endpoint socket_remote_endpoint(socket_handle_t socket) override; + void socket_enable_broadcast(socket_handle_t socket, bool enable) override; + void socket_shutdown(socket_handle_t socket, bool receive, bool send) override; + bool enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) override; + bool enqueue_accept(socket_handle_t socket, completion_data* cd) override; + bool enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) override; + bool enqueue_send(socket_handle_t socket, const void* buf, size_t len, completion_data* cd) override; + bool enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, + completion_data* cd) override; + bool enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, + completion_data* cd) override; + + file_handle_t file_open(const char* filename, std::ios_base::openmode mode) override; + void file_register(file_handle_t fd) override; + void file_release(file_handle_t fd) override; + void file_close(file_handle_t fd) override; + uint64_t file_size(file_handle_t fd) override; + bool enqueue_readv(file_handle_t fd, void* buf, size_t len, uint64_t offset, completion_data* cd) override; + bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, uint64_t offset, + completion_data* cd) override; + bool enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) override; + + bool cancel(completion_data* cd) override; + + private: + HANDLE m_completion_port = INVALID_HANDLE_VALUE; + std::atomic m_inflight_count{}; + }; + + std::unique_ptr create_io_engine_iocp() { return std::make_unique(); } + + io_engine_iocp::io_engine_iocp() { + WSADATA wsaData; + if (int res = WSAStartup(MAKEWORD(2, 2), &wsaData); res != 0) + throw std::runtime_error("failed to initialize WSA"); + m_completion_port = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); + if (m_completion_port == NULL) { + WSACleanup(); + throw std::runtime_error("failed to create completion port"); + } + } + + io_engine_iocp::~io_engine_iocp() { + if (m_completion_port != INVALID_HANDLE_VALUE) CloseHandle(m_completion_port); + WSACleanup(); + } + + std::string_view io_engine_iocp::name() const noexcept { return "iocp"; } + + size_t io_engine_iocp::run(bool nowait) { + DWORD timeout = 0; + if (!nowait) timeout = 10000; + + DWORD num_transfered; + ULONG_PTR key; + LPOVERLAPPED overlapped; + if (GetQueuedCompletionStatus(m_completion_port, &num_transfered, &key, &overlapped, timeout) == FALSE && + overlapped == nullptr) { + return m_inflight_count; + } + if (key == 1) return m_inflight_count; + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + auto state = reinterpret_cast(overlapped); + auto cd = reinterpret_cast(overlapped); + + DWORD num_bytes, flags; + auto res = GetOverlappedResult(state->handle, &state->overlapped, &num_bytes, FALSE); + if (res == TRUE) { + cd->result.clear(); + if (state->accept_sock != INVALID_SOCKET) { + if (setsockopt(state->accept_sock, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, + reinterpret_cast(&state->handle), sizeof(state->handle)) == SOCKET_ERROR) { + closesocket(state->accept_sock); + cd->result = std::error_code(GetLastError(), std::system_category()); + return true; + } + cd->result_handle = state->accept_sock; + } else { + cd->result_size = num_bytes; + if (state->recv_from_ep != nullptr) { + *state->recv_from_ep = endpoint(state->recv_from_sa, state->recv_from_sa_len); + } + } + } else { + auto err = GetLastError(); + if (state->accept_sock != INVALID_SOCKET) closesocket(state->accept_sock); + switch (err) { + case WSANOTINITIALISED: + case WSAENETDOWN: + case WSAENOTSOCK: + case WSA_INVALID_HANDLE: + case WSA_INVALID_PARAMETER: + case WSA_IO_INCOMPLETE: + case WSAEFAULT: throw std::system_error(err, std::system_category(), "GetOverlappedResult failed"); + default: cd->result = std::error_code(err, std::system_category()); + } + } + + if (cd->callback) cd->callback(cd->userdata); + + return m_inflight_count; + } + + void io_engine_iocp::wake() { + if (PostQueuedCompletionStatus(m_completion_port, 0, 1, NULL) == FALSE) + throw std::runtime_error("failed to wake cq"); + } + + io_engine::socket_handle_t io_engine_iocp::socket_create(address_type domain, socket_type type) { + int afdomain = -1; + switch (domain) { + case address_type::ipv4: afdomain = AF_INET; break; + case address_type::ipv6: afdomain = AF_INET6; break; + } + int stype = -1; + switch (type) { + case socket_type::stream: stype = SOCK_STREAM; break; + case socket_type::dgram: stype = SOCK_DGRAM; break; + case socket_type::seqpacket: stype = SOCK_SEQPACKET; break; + } + if (afdomain == -1) throw std::system_error(std::make_error_code(std::errc::not_supported)); + if (stype == -1) throw std::system_error(std::make_error_code(std::errc::not_supported)); + auto fd = WSASocket(afdomain, stype, 0, NULL, 0, WSA_FLAG_OVERLAPPED); + if (fd == INVALID_SOCKET) throw std::system_error(WSAGetLastError(), std::system_category(), "WSASocket"); + u_long mode = 1; + if (ioctlsocket(fd, FIONBIO, &mode) == SOCKET_ERROR) { + closesocket(fd); + throw std::system_error(WSAGetLastError(), std::system_category(), "ioctlsocket failed"); + } + if (domain == address_type::ipv6) { + DWORD opt = 0; + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&opt), sizeof(opt)) == + SOCKET_ERROR) { + closesocket(fd); + throw std::system_error(WSAGetLastError(), std::system_category(), "setsockopt failed"); + } + } + // Add socket to completion port + if (CreateIoCompletionPort((HANDLE)fd, m_completion_port, 0, 0) == NULL) { + closesocket(fd); + throw std::system_error(GetLastError(), std::system_category(), "CreateIoCompletionPort failed"); + } + return fd; + } + + std::pair + io_engine_iocp::socket_create_connected_pair(address_type domain, socket_type type) { + if (type != socket_type::stream) + throw std::system_error(std::make_error_code(std::errc::function_not_supported), "unsupported socket type"); + + auto close_and_throw = [](const char* name, auto... sockets) { + auto err = WSAGetLastError(); + (::closesocket(sockets), ...); + throw std::system_error(err, std::system_category(), name); + }; + + auto listener = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED); + if (listener == INVALID_SOCKET) close_and_throw("WSASocket"); + + int reuse = 1; + if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, (char*)&reuse, (socklen_t)sizeof(reuse)) == -1) + close_and_throw("setsockopt", listener); + + struct sockaddr_in inaddr {}; + inaddr.sin_family = AF_INET; + inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + if (bind(listener, reinterpret_cast(&inaddr), sizeof(inaddr)) == SOCKET_ERROR) + close_and_throw("bind", listener); + + inaddr = {}; + int addrlen = sizeof(inaddr); + if (getsockname(listener, reinterpret_cast(&inaddr), &addrlen) == SOCKET_ERROR) + close_and_throw("getsockname", listener); + // win32 getsockname may only set the port number + inaddr.sin_family = AF_INET; + inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + if (listen(listener, 1) == SOCKET_ERROR) close_and_throw("listen", listener); + + auto sock0 = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); + if (sock0 == INVALID_SOCKET) close_and_throw("WSASocket", listener); + if (connect(sock0, reinterpret_cast(&inaddr), sizeof(inaddr)) == SOCKET_ERROR) + close_and_throw("connect", listener, sock0); + + auto sock1 = accept(listener, NULL, NULL); + if (sock1 == INVALID_SOCKET) close_and_throw("accept", listener, sock0); + + closesocket(listener); + + u_long mode = 1; + if (ioctlsocket(sock0, FIONBIO, &mode) == SOCKET_ERROR) close_and_throw("ioctlsocket", listener, sock0, sock1); + mode = 1; + if (ioctlsocket(sock1, FIONBIO, &mode) == SOCKET_ERROR) close_and_throw("ioctlsocket", listener, sock0, sock1); + + // Add socket to completion port + if (CreateIoCompletionPort((HANDLE)sock0, m_completion_port, 0, 0) == NULL) + close_and_throw("CreateIoCompletionPort", listener, sock0, sock1); + if (CreateIoCompletionPort((HANDLE)sock1, m_completion_port, 0, 0) == NULL) + close_and_throw("CreateIoCompletionPort", listener, sock0, sock1); + + return {sock0, sock1}; + } + + void io_engine_iocp::socket_register(socket_handle_t socket) { + // Make socket non blocking (do we even need to do this ?) + u_long mode = 1; + if (ioctlsocket(socket, FIONBIO, &mode) == SOCKET_ERROR) + throw std::system_error(WSAGetLastError(), std::system_category(), "ioctlsocket failed"); + // Add socket to completion port + if (CreateIoCompletionPort((HANDLE)socket, m_completion_port, 0, 0) == NULL) + throw std::system_error(GetLastError(), std::system_category(), "CreateIoCompletionPort failed"); + } + + void io_engine_iocp::socket_release(socket_handle_t socket) { + // Unhook the socket from our completion port + // Note: Dark magic ahead + _IO_STATUS_BLOCK status{}; + FILE_COMPLETION_INFORMATION info{0, NULL}; + if (NtSetInformationFile((HANDLE)socket, &status, &info, sizeof(info), FileReplaceCompletionInformation) < 0) + throw std::system_error(std::make_error_code(std::errc::io_error), "NtSetInformationFile failed"); + } + + void io_engine_iocp::socket_close(socket_handle_t socket) { + if (socket != INVALID_SOCKET) closesocket(socket); + } + + void io_engine_iocp::socket_bind(socket_handle_t socket, endpoint ep) { + auto sa = ep.to_sockaddr(); + auto res = ::bind(socket, reinterpret_cast(&sa.first), sa.second); + if (res < 0) throw std::system_error(WSAGetLastError(), std::system_category()); + } + + void io_engine_iocp::socket_listen(socket_handle_t socket, size_t backlog) { + if (backlog == 0) backlog = 20; + auto res = ::listen(socket, backlog); + if (res == SOCKET_ERROR) throw std::system_error(WSAGetLastError(), std::system_category()); + } + + endpoint io_engine_iocp::socket_local_endpoint(socket_handle_t socket) { + sockaddr_storage sa; + int sa_size = sizeof(sa); + auto res = getsockname(socket, reinterpret_cast(&sa), &sa_size); + if (res >= 0) return endpoint(sa, sa_size); + throw std::system_error(WSAGetLastError(), std::system_category()); + } + + endpoint io_engine_iocp::socket_remote_endpoint(socket_handle_t socket) { + sockaddr_storage sa; + int sa_size = sizeof(sa); + auto res = getpeername(socket, reinterpret_cast(&sa), &sa_size); + if (res >= 0) + return endpoint(sa, sa_size); + else if (res == SOCKET_ERROR && WSAGetLastError() != WSAENOTCONN) + throw std::system_error(WSAGetLastError(), std::system_category()); + return {}; + } + + void io_engine_iocp::socket_enable_broadcast(socket_handle_t socket, bool enable) { + BOOL opt = enable ? TRUE : FALSE; + auto res = setsockopt(socket, SOL_SOCKET, SO_BROADCAST, reinterpret_cast(&opt), sizeof(opt)); + if (res == SOCKET_ERROR) throw std::system_error(WSAGetLastError(), std::system_category()); + } + + void io_engine_iocp::socket_shutdown(socket_handle_t socket, bool receive, bool send) { + int mode = 0; + if (receive && send) + mode = SD_BOTH; + else if (receive) + mode = SD_RECEIVE; + else if (send) + mode = SD_SEND; + else + return; + auto res = ::shutdown(socket, mode); + if (res == SOCKET_ERROR && WSAGetLastError() != WSAENOTCONN) + throw std::system_error(WSAGetLastError(), std::system_category()); + } + + bool io_engine_iocp::enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) { + auto sa = ep.to_sockaddr(); + LPFN_CONNECTEX lpfnConnectex = nullptr; + GUID b = WSAID_CONNECTEX; + DWORD n; + if (WSAIoctl(socket, SIO_GET_EXTENSION_FUNCTION_POINTER, &b, sizeof(b), &lpfnConnectex, sizeof(lpfnConnectex), + &n, NULL, NULL) == SOCKET_ERROR) { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + return true; + } + + // ConnectEx requires the socket to be bound + { + WSAPROTOCOL_INFO info{}; + int optlen = sizeof(info); + if (getsockopt(socket, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast(&info), &optlen) == + SOCKET_ERROR) { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + return true; + } + sockaddr_storage addr{}; + addr.ss_family = info.iAddressFamily; + INETADDR_SETANY(reinterpret_cast(&addr)); + auto res = ::bind(socket, reinterpret_cast(&addr), (int)INET_SOCKADDR_LENGTH(addr.ss_family)); + if (res < 0) { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + return true; + } + } + + auto state = cd->es_init(); + state->handle = (HANDLE)socket; + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + if (lpfnConnectex(socket, reinterpret_cast(&sa.first), sa.second, nullptr, 0, nullptr, + &state->overlapped) == TRUE || + WSAGetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + } + + bool io_engine_iocp::enqueue_accept(socket_handle_t socket, completion_data* cd) { + auto state = cd->es_init(); + state->handle = (HANDLE)socket; + + // Get the socket family to create a second socket for accepting + WSAPROTOCOL_INFO info{}; + int optlen = sizeof(info); + if (getsockopt(socket, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast(&info), &optlen) == SOCKET_ERROR) { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + return true; + } + + state->accept_sock = WSASocket(info.iAddressFamily, info.iSocketType, 0, NULL, 0, WSA_FLAG_OVERLAPPED); + if (state->accept_sock == INVALID_SOCKET) { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + return true; + } + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + DWORD received; + if (AcceptEx(socket, state->accept_sock, state->accept_buffer.data(), 0, sizeof(sockaddr_in6) + 16, + sizeof(sockaddr_in6) + 16, &received, &state->overlapped) == TRUE || + WSAGetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + closesocket(state->accept_sock); + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + + return false; + } + + bool io_engine_iocp::enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) { + auto state = cd->es_init(); + state->handle = (HANDLE)socket; + + WSABUF buffer; + buffer.buf = static_cast(buf); + buffer.len = len; + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + DWORD flags = 0; + if (WSARecv(socket, &buffer, 1, nullptr, &flags, &state->overlapped, nullptr) == 0 || + WSAGetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + } + + bool io_engine_iocp::enqueue_send(socket_handle_t socket, const void* buf, size_t len, completion_data* cd) { + auto state = cd->es_init(); + state->handle = (HANDLE)socket; + + WSABUF buffer; + buffer.buf = const_cast(static_cast(buf)); + buffer.len = len; + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + if (WSASend(socket, &buffer, 1, nullptr, 0, &state->overlapped, nullptr) == 0 || + WSAGetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + } + + bool io_engine_iocp::enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, + completion_data* cd) { + auto state = cd->es_init(); + state->handle = (HANDLE)socket; + memset(&state->recv_from_sa, 0, sizeof(state->recv_from_sa)); + state->recv_from_sa_len = sizeof(state->recv_from_sa); + state->recv_from_ep = source; + + WSABUF buffer; + buffer.buf = static_cast(buf); + buffer.len = len; + DWORD flags = 0; + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + if (WSARecvFrom(socket, &buffer, 1, nullptr, &flags, reinterpret_cast(&state->recv_from_sa), + &state->recv_from_sa_len, &state->overlapped, nullptr) == 0 || + WSAGetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + } + + bool io_engine_iocp::enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, + completion_data* cd) { + auto state = cd->es_init(); + state->handle = (HANDLE)socket; + + auto sa = dst.to_sockaddr(); + + WSABUF buffer; + buffer.buf = const_cast(static_cast(buf)); + buffer.len = len; + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + if (WSASendTo(socket, &buffer, 1, nullptr, 0, reinterpret_cast(&sa.first), sa.second, + &state->overlapped, nullptr) == 0 || + WSAGetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + cd->result = std::error_code(WSAGetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + } + + io_engine::file_handle_t io_engine_iocp::file_open(const char* filename, std::ios_base::openmode mode) { + DWORD access_mode = 0; + if ((mode & std::ios_base::in) == std::ios_base::in) access_mode |= GENERIC_READ; + if ((mode & (std::ios_base::out | std::ios_base::app)) != 0) access_mode |= GENERIC_WRITE; + if ((mode & (std::ios_base::in | std::ios_base::out | std::ios_base::app)) == 0) + throw std::invalid_argument("neither std::ios::in, nor std::ios::out was specified"); + HANDLE res = CreateFileA(filename, access_mode, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, NULL, + OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, NULL); + if (res == INVALID_HANDLE_VALUE) throw std::system_error(GetLastError(), std::system_category()); + if ((mode & std::ios_base::trunc) == std::ios_base::trunc) { + if (SetEndOfFile(res) == FALSE) { + auto err = GetLastError(); + CloseHandle(res); + throw std::system_error(err, std::system_category()); + } + } + if ((mode & (std::ios_base::ate | std::ios_base::app)) != 0) { + LARGE_INTEGER pos; + pos.QuadPart = 0; + if (SetFilePointerEx(res, pos, nullptr, FILE_END) == FALSE) { + auto err = GetLastError(); + CloseHandle(res); + throw std::system_error(err, std::system_category()); + } + } + if (CreateIoCompletionPort(res, m_completion_port, 0, 0) == NULL) { + auto err = GetLastError(); + CloseHandle(res); + throw std::system_error(err, std::system_category(), "CreateIoCompletionPort failed"); + } + return res; + } + + void io_engine_iocp::file_register(file_handle_t fd) { + // Add file to completion port + if (CreateIoCompletionPort(fd, m_completion_port, 0, 0) == NULL) + throw std::system_error(GetLastError(), std::system_category(), "CreateIoCompletionPort failed"); + } + + void io_engine_iocp::file_release(file_handle_t fd) { + // Unhook the file from our completion port + // Note: Dark magic ahead + _IO_STATUS_BLOCK status{}; + FILE_COMPLETION_INFORMATION info{0, NULL}; + if (NtSetInformationFile(fd, &status, &info, sizeof(info), FileReplaceCompletionInformation) < 0) + throw std::system_error(std::make_error_code(std::errc::io_error), "NtSetInformationFile failed"); + } + + void io_engine_iocp::file_close(file_handle_t fd) { ::CloseHandle(fd); } + + uint64_t io_engine_iocp::file_size(file_handle_t fd) { + DWORD high; + auto res = GetFileSize(fd, &high); + if (res == INVALID_FILE_SIZE && GetLastError() != NO_ERROR) + throw std::system_error(GetLastError(), std::system_category()); + return (static_cast(high) << 32) + res; + } + + bool io_engine_iocp::enqueue_readv(file_handle_t fd, void* buf, size_t len, uint64_t offset, completion_data* cd) { + auto state = cd->es_init(); + state->handle = fd; + state->overlapped.Offset = offset & 0xffffffff; + state->overlapped.OffsetHigh = offset >> 32; + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + if (ReadFile(fd, buf, len, nullptr, &state->overlapped) == TRUE || GetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + cd->result = std::error_code(GetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + } + + bool io_engine_iocp::enqueue_writev(file_handle_t fd, const void* buf, size_t len, uint64_t offset, + completion_data* cd) { + auto state = cd->es_init(); + state->handle = fd; + state->overlapped.Offset = offset & 0xffffffff; + state->overlapped.OffsetHigh = offset >> 32; + + m_inflight_count.fetch_add(1, std::memory_order::relaxed); + if (WriteFile(fd, buf, len, nullptr, &state->overlapped) == TRUE || GetLastError() == WSA_IO_PENDING) { + // IOCP always pushes even if it finishes synchronously + return false; + } else { + cd->result = std::error_code(GetLastError(), std::system_category()); + m_inflight_count.fetch_sub(1, std::memory_order::relaxed); + return true; + } + } + + bool io_engine_iocp::enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) { + // Looks like there is no async version of this + if (FlushFileBuffers(fd) == FALSE) + cd->result = std::error_code(GetLastError(), std::system_category()); + else + cd->result.clear(); + return true; + } + + bool io_engine_iocp::cancel(completion_data* cd) { + auto state = cd->es_get(); + auto res = CancelIoEx(state->handle, &state->overlapped); + return res == TRUE; + } + +} // namespace asyncpp::io::detail +#endif diff --git a/src/io_engine_select.cpp b/src/io_engine_select.cpp index 0557d86..cda47a9 100644 --- a/src/io_engine_select.cpp +++ b/src/io_engine_select.cpp @@ -1,19 +1,23 @@ #include +#ifdef _WIN32 +namespace asyncpp::io::detail { + std::unique_ptr create_io_engine_select() { return nullptr; } +} // namespace asyncpp::io::detail +#else +#include "io_engine_generic_unix.h" + #include #include #include -#ifndef _WIN32 #include #include #include +#include #include +#include #include -#else -#include -#include -#endif #ifdef __linux__ #define USE_EVENTFD @@ -53,7 +57,7 @@ namespace asyncpp::io::detail { }; } // namespace - class io_engine_select : public io_engine { + class io_engine_select : public io_engine_generic_unix { public: io_engine_select(); io_engine_select(const io_engine_select&) = delete; @@ -65,6 +69,8 @@ namespace asyncpp::io::detail { size_t run(bool nowait) override; void wake() override; + void socket_register(socket_handle_t socket) override; + void socket_release(socket_handle_t socket) override; bool enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) override; bool enqueue_accept(socket_handle_t socket, completion_data* cd) override; bool enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) override; @@ -74,8 +80,11 @@ namespace asyncpp::io::detail { bool enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, completion_data* cd) override; - bool enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) override; - bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, completion_data* cd) override; + void file_register(file_handle_t fd) override; + void file_release(file_handle_t fd) override; + bool enqueue_readv(file_handle_t fd, void* buf, size_t len, uint64_t offset, completion_data* cd) override; + bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, uint64_t offset, + completion_data* cd) override; bool enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) override; bool cancel(completion_data* cd) override; @@ -118,12 +127,9 @@ namespace asyncpp::io::detail { } io_engine_select::~io_engine_select() { -#ifdef _WIN32 -#else if (m_wake_fd >= 0) close(m_wake_fd); #ifndef USE_EVENTFD if (m_wake_fd_write >= 0) close(m_wake_fd_write); -#endif #endif } @@ -185,9 +191,9 @@ namespace asyncpp::io::detail { int result; socklen_t result_len = sizeof(result); if (getsockopt(e.socket, SOL_SOCKET, SO_ERROR, &result, &result_len) < 0) { - e.done->result = -errno; + e.done->result = std::error_code(errno, std::system_category()); } else { - e.done->result = -result; + e.done->result = std::error_code(result, std::system_category()); } m_done_callbacks.push_back(e.done); return true; @@ -199,12 +205,12 @@ namespace asyncpp::io::detail { e.state.send.len -= res; e.state.send.buf = static_cast(e.state.send.buf) + res; if (e.state.send.len == 0) { - e.done->result = 0; + e.done->result.clear(); m_done_callbacks.push_back(e.done); return true; } } else if (errno != EAGAIN) { - e.done->result = -errno; + e.done->result = std::error_code(errno, std::system_category()); m_done_callbacks.push_back(e.done); return true; } @@ -213,34 +219,43 @@ namespace asyncpp::io::detail { case op::accept: { if ((state & RDY_READ) == 0) return false; auto res = ::accept(e.socket, nullptr, nullptr); - if (res >= 0 || errno != EAGAIN) { - e.done->result = res >= 0 ? res : -errno; - m_done_callbacks.push_back(e.done); - return true; - } - return false; + if (res >= 0) { + e.done->result.clear(); + e.done->result_handle = res; + } else if (errno != EAGAIN) { + e.done->result = std::error_code(errno, std::system_category()); + } else + return false; + m_done_callbacks.push_back(e.done); + return true; } case op::recv: { if ((state & RDY_READ) == 0) return false; auto res = ::recv(e.socket, e.state.recv.buf, e.state.recv.len, 0); - if (res >= 0 || errno != EAGAIN) { - e.done->result = res >= 0 ? res : -errno; - m_done_callbacks.push_back(e.done); - return true; - } - return false; + if (res >= 0) { + e.done->result.clear(); + e.done->result_size = res; + } else if (errno != EAGAIN) { + e.done->result = std::error_code(errno, std::system_category()); + } else + return false; + m_done_callbacks.push_back(e.done); + return true; } case op::send_to: { if ((state & RDY_WRITE) == 0) return false; auto sa = e.state.send_to.destination.to_sockaddr(); auto res = ::sendto(e.socket, e.state.send.buf, e.state.send.len, 0, reinterpret_cast(&sa.first), sa.second); - if (res >= 0 || errno != EAGAIN) { - e.done->result = res >= 0 ? res : -errno; - m_done_callbacks.push_back(e.done); - return true; - } - return false; + if (res >= 0) { + e.done->result.clear(); + e.done->result_size = res; + } else if (errno != EAGAIN) { + e.done->result = std::error_code(errno, std::system_category()); + } else + return false; + m_done_callbacks.push_back(e.done); + return true; } case op::recv_from: { if ((state & RDY_READ) == 0) return false; @@ -248,18 +263,21 @@ namespace asyncpp::io::detail { socklen_t sa_len = sizeof(sa); auto res = ::recvfrom(e.socket, e.state.recv.buf, e.state.recv.len, 0, reinterpret_cast(&sa), &sa_len); - if (res >= 0 || errno != EAGAIN) { - e.done->result = res >= 0 ? res : -errno; - if (res >= 0 && e.state.recv_from.source) { + if (res >= 0) { + e.done->result.clear(); + e.done->result_size = res; + if (e.state.recv_from.source) { if (sa.ss_family == AF_INET || sa.ss_family == AF_INET6 || sa.ss_family == AF_UNIX) *e.state.recv_from.source = endpoint(sa, sa_len); else *e.state.recv_from.source = endpoint{}; } - m_done_callbacks.push_back(e.done); - return true; - } - return false; + } else if (errno != EAGAIN) { + e.done->result = std::error_code(errno, std::system_category()); + } else + return false; + m_done_callbacks.push_back(e.done); + return true; } default: return true; } @@ -274,12 +292,16 @@ namespace asyncpp::io::detail { #endif } + void io_engine_select::socket_register(socket_handle_t socket) {} + + void io_engine_select::socket_release(socket_handle_t socket) {} + bool io_engine_select::enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) { auto sa = ep.to_sockaddr(); auto res = ::connect(socket, reinterpret_cast(&sa.first), sa.second); if (res == 0 || errno != EINPROGRESS) { // Succeeded right away - cd->result = res ? -errno : 0; + cd->result = std::error_code(res ? errno : 0, std::system_category()); return true; } @@ -295,8 +317,12 @@ namespace asyncpp::io::detail { bool io_engine_select::enqueue_accept(socket_handle_t socket, completion_data* cd) { auto res = ::accept(socket, nullptr, nullptr); - if (res >= 0 || errno != EAGAIN) { - cd->result = res >= 0 ? res : -errno; + if (res >= 0) { + cd->result.clear(); + cd->result_handle = res; + return true; + } else if (errno != EAGAIN) { + cd->result = std::error_code(errno, std::system_category()); return true; } @@ -312,8 +338,12 @@ namespace asyncpp::io::detail { bool io_engine_select::enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) { auto res = ::recv(socket, buf, len, 0); - if (res >= 0 || errno != EAGAIN) { - cd->result = res >= 0 ? res : -errno; + if (res >= 0) { + cd->result.clear(); + cd->result_size = res; + return true; + } else if (errno != EAGAIN) { + cd->result = std::error_code(errno, std::system_category()); return true; } @@ -335,11 +365,11 @@ namespace asyncpp::io::detail { len -= res; buf = static_cast(buf) + res; } else if (errno != EAGAIN) { - cd->result = -errno; + cd->result = std::error_code(errno, std::system_category()); return true; } if (len == 0) { - cd->result = 0; + cd->result.clear(); return true; } @@ -360,15 +390,19 @@ namespace asyncpp::io::detail { sockaddr_storage sa; socklen_t sa_len = sizeof(sa); auto res = ::recvfrom(socket, buf, len, 0, reinterpret_cast(&sa), &sa_len); - if (res >= 0 || errno != EAGAIN) { - cd->result = res >= 0 ? res : -errno; - if (res >= 0 && source) { + if (res >= 0) { + cd->result.clear(); + cd->result_size = res; + if (source != nullptr) { if (sa.ss_family == AF_INET || sa.ss_family == AF_INET6) *source = endpoint(sa, sa_len); else *source = endpoint{}; } return true; + } else if (errno != EAGAIN) { + cd->result = std::error_code(errno, std::system_category()); + return true; } entry e{}; @@ -388,8 +422,12 @@ namespace asyncpp::io::detail { completion_data* cd) { auto sa = dst.to_sockaddr(); auto res = ::sendto(socket, buf, len, 0, reinterpret_cast(&sa.first), sa.second); - if (res >= 0 || errno != EAGAIN) { - cd->result = res >= 0 ? res : -errno; + if (res >= 0) { + cd->result.clear(); + cd->result_size = res; + return true; + } else if (errno != EAGAIN) { + cd->result = std::error_code(errno, std::system_category()); return true; } @@ -406,18 +444,33 @@ namespace asyncpp::io::detail { return false; } - bool io_engine_select::enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) { + void io_engine_select::file_register(file_handle_t fd) {} + + void io_engine_select::file_release(file_handle_t fd) {} + + bool io_engine_select::enqueue_readv(file_handle_t fd, void* buf, size_t len, uint64_t offset, + completion_data* cd) { // There is no way to do async file io on linux without uring, so just do the read inline auto res = pread(fd, buf, len, offset); - cd->result = res >= 0 ? res : -errno; + if (res >= 0) { + cd->result.clear(); + cd->result_size = res; + } else if (errno != EAGAIN) { + cd->result = std::error_code(errno, std::system_category()); + } return true; } - bool io_engine_select::enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, + bool io_engine_select::enqueue_writev(file_handle_t fd, const void* buf, size_t len, uint64_t offset, completion_data* cd) { // There is no way to do async file io on linux without uring, so just do the write inline auto res = pwrite(fd, buf, len, offset); - cd->result = res >= 0 ? res : -errno; + if (res >= 0) { + cd->result.clear(); + cd->result_size = res; + } else if (errno != EAGAIN) { + cd->result = std::error_code(errno, std::system_category()); + } return true; } @@ -428,7 +481,12 @@ namespace asyncpp::io::detail { #else auto res = fsync(fd); #endif - cd->result = res >= 0 ? res : -errno; + if (res >= 0) { + cd->result.clear(); + cd->result_size = res; + } else if (errno != EAGAIN) { + cd->result = std::error_code(errno, std::system_category()); + } return true; } @@ -438,7 +496,7 @@ namespace asyncpp::io::detail { if (it->done == cd) { it = m_inflight.erase(it); lck.unlock(); - cd->result = -ECANCELED; + cd->result = std::error_code(ECANCELED, std::system_category()); cd->callback(cd->userdata); return true; } @@ -447,3 +505,5 @@ namespace asyncpp::io::detail { } } // namespace asyncpp::io::detail + +#endif diff --git a/src/io_engine_uring.cpp b/src/io_engine_uring.cpp index 795e973..3aed09e 100644 --- a/src/io_engine_uring.cpp +++ b/src/io_engine_uring.cpp @@ -5,21 +5,22 @@ namespace asyncpp::io::detail { std::unique_ptr create_io_engine_uring() { return nullptr; } } // namespace asyncpp::io::detail #else +#include "io_engine_generic_unix.h" -#include #include -#include #include + +#include +#include +#include #include #include #include #include -#include "block_allocator.h" - namespace asyncpp::io::detail { - class io_engine_uring : public io_engine { + class io_engine_uring : public io_engine_generic_unix { public: io_engine_uring(struct io_uring ring) noexcept; io_engine_uring(const io_engine_uring&) = delete; @@ -31,6 +32,8 @@ namespace asyncpp::io::detail { size_t run(bool nowait) override; void wake() override; + void socket_register(socket_handle_t socket) override; + void socket_release(socket_handle_t socket) override; bool enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) override; bool enqueue_accept(socket_handle_t socket, completion_data* cd) override; bool enqueue_recv(socket_handle_t socket, void* buf, size_t len, completion_data* cd) override; @@ -40,25 +43,40 @@ namespace asyncpp::io::detail { bool enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, completion_data* cd) override; - bool enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) override; - bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, completion_data* cd) override; + void file_register(file_handle_t fd) override; + void file_release(file_handle_t fd) override; + bool enqueue_readv(file_handle_t fd, void* buf, size_t len, uint64_t offset, completion_data* cd) override; + bool enqueue_writev(file_handle_t fd, const void* buf, size_t len, uint64_t offset, + completion_data* cd) override; bool enqueue_fsync(file_handle_t fd, fsync_flags flags, completion_data* cd) override; bool cancel(completion_data* cd) override; private: - struct msghdr_info { - struct msghdr hdr {}; - sockaddr_storage sockaddr{}; - iovec data{}; - asyncpp::io::endpoint* real_endpoint{}; + enum class uring_op : uint8_t { + invalid, + connect, + accept, + recv, + send, + recv_from, + send_to, + readv, + writev, + fsync + }; + struct uring_engine_state { + struct msghdr hdr; + sockaddr_storage sockaddr; + iovec data; + asyncpp::io::endpoint* real_endpoint; + uring_op op; }; std::mutex m_sqe_mtx{}; std::mutex m_cqe_mtx{}; std::atomic m_inflight_count{}; struct io_uring m_ring {}; - block_allocator m_state_allocator{}; }; std::unique_ptr create_io_engine_uring() { @@ -106,18 +124,19 @@ namespace asyncpp::io::detail { m_inflight_count.fetch_sub(1, std::memory_order::relaxed); lck.unlock(); - info->result = opres; - - if (auto extra = static_cast(info->engine_state); extra != nullptr) { - if (extra->real_endpoint != nullptr) { - if (extra->sockaddr.ss_family == AF_INET || extra->sockaddr.ss_family == AF_INET6 || - extra->sockaddr.ss_family == AF_UNIX) - *extra->real_endpoint = endpoint(extra->sockaddr, extra->hdr.msg_namelen); - else - *extra->real_endpoint = endpoint(); - } - m_state_allocator.destroy(extra); - info->engine_state = nullptr; + auto state = info->es_get(); + info->result = std::error_code(opres < 0 ? -opres : 0, std::system_category()); + switch (state->op) { + case uring_op::accept: info->result_handle = opres; break; + default: info->result_size = static_cast(opres); break; + } + + if (state->real_endpoint != nullptr) { + if (state->sockaddr.ss_family == AF_INET || state->sockaddr.ss_family == AF_INET6 || + state->sockaddr.ss_family == AF_UNIX) + *state->real_endpoint = endpoint(state->sockaddr, state->hdr.msg_namelen); + else + *state->real_endpoint = endpoint(); } if (info->callback) info->callback(info->userdata); @@ -133,6 +152,10 @@ namespace asyncpp::io::detail { io_uring_submit(&m_ring); } + void io_engine_uring::socket_register(socket_handle_t socket) {} + + void io_engine_uring::socket_release(socket_handle_t socket) {} + bool io_engine_uring::enqueue_connect(socket_handle_t socket, endpoint ep, completion_data* cd) { auto sa = ep.to_sockaddr(); std::lock_guard lck{m_sqe_mtx}; @@ -176,7 +199,7 @@ namespace asyncpp::io::detail { bool io_engine_uring::enqueue_recv_from(socket_handle_t socket, void* buf, size_t len, endpoint* source, completion_data* cd) { - auto* info = m_state_allocator.create(); + auto* info = cd->es_init(); info->hdr.msg_name = &info->sockaddr; info->hdr.msg_namelen = sizeof(info->sockaddr); info->hdr.msg_iov = &info->data; @@ -185,8 +208,6 @@ namespace asyncpp::io::detail { info->data.iov_len = len; info->real_endpoint = source; - cd->engine_state = info; - std::lock_guard lck{m_sqe_mtx}; struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); io_uring_prep_recvmsg(sqe, socket, &info->hdr, 0); @@ -199,7 +220,7 @@ namespace asyncpp::io::detail { bool io_engine_uring::enqueue_send_to(socket_handle_t socket, const void* buf, size_t len, endpoint dst, completion_data* cd) { auto addr = dst.to_sockaddr(); - auto* info = m_state_allocator.create(); + auto* info = cd->es_init(); info->hdr.msg_name = &info->sockaddr; info->hdr.msg_namelen = addr.second; info->hdr.msg_iov = &info->data; @@ -207,9 +228,6 @@ namespace asyncpp::io::detail { info->sockaddr = addr.first; info->data.iov_base = const_cast(buf); info->data.iov_len = len; - info->real_endpoint = nullptr; - - cd->engine_state = info; std::lock_guard lck{m_sqe_mtx}; struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); @@ -220,13 +238,14 @@ namespace asyncpp::io::detail { return false; } - bool io_engine_uring::enqueue_readv(file_handle_t fd, void* buf, size_t len, off_t offset, completion_data* cd) { - auto* info = m_state_allocator.create(); + void io_engine_uring::file_register(file_handle_t fd) {} + + void io_engine_uring::file_release(file_handle_t fd) {} + + bool io_engine_uring::enqueue_readv(file_handle_t fd, void* buf, size_t len, uint64_t offset, completion_data* cd) { + auto* info = cd->es_init(); info->data.iov_base = buf; info->data.iov_len = len; - info->real_endpoint = nullptr; - - cd->engine_state = info; std::lock_guard lck{m_sqe_mtx}; struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); @@ -237,14 +256,11 @@ namespace asyncpp::io::detail { return false; } - bool io_engine_uring::enqueue_writev(file_handle_t fd, const void* buf, size_t len, off_t offset, + bool io_engine_uring::enqueue_writev(file_handle_t fd, const void* buf, size_t len, uint64_t offset, completion_data* cd) { - auto* info = m_state_allocator.create(); + auto* info = cd->es_init(); info->data.iov_base = const_cast(buf); info->data.iov_len = len; - info->real_endpoint = nullptr; - - cd->engine_state = info; std::lock_guard lck{m_sqe_mtx}; struct io_uring_sqe* sqe = io_uring_get_sqe(&m_ring); diff --git a/src/socket.cpp b/src/socket.cpp index 5f7d292..3ec52eb 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -1,52 +1,9 @@ #include -#include - -#ifndef _WIN32 -#include -#include -#include -#include -#else -#include -#include -#endif - -namespace { - - std::system_error sys_error(int code) { - return std::system_error(std::make_error_code(static_cast(code))); - } - -} // namespace - namespace asyncpp::io { socket socket::create_tcp(io_service& io, address_type addrtype) { - int domain = -1; - switch (addrtype) { - case address_type::ipv4: domain = AF_INET; break; - case address_type::ipv6: domain = AF_INET6; break; - case address_type::uds: domain = AF_UNIX; break; - } - if (domain == -1) throw sys_error(ENOTSUP); -#ifndef __APPLE__ - auto fd = ::socket(domain, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0); - if (fd < 0) throw sys_error(errno); -#else - auto fd = ::socket(domain, SOCK_STREAM, 0); - if (fd < 0) throw sys_error(errno); - int flags = fcntl(fd, F_GETFL, 0); - if (flags < 0 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0 || fcntl(fd, F_SETFD, FD_CLOEXEC) < 0) { - close(fd); - throw std::system_error(errno, std::system_category(), "fcntl failed"); - } -#endif - if (addrtype == address_type::ipv6) { - int opt = 0; - if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) - throw std::system_error(errno, std::system_category(), "setsockopt failed"); - } + auto fd = io.engine()->socket_create(addrtype, detail::io_engine::socket_type::stream); return socket(&io, fd); } @@ -60,30 +17,7 @@ namespace asyncpp::io { } socket socket::create_udp(io_service& io, address_type addrtype) { - int domain = -1; - switch (addrtype) { - case address_type::ipv4: domain = AF_INET; break; - case address_type::ipv6: domain = AF_INET6; break; - case address_type::uds: domain = AF_UNIX; break; - } - if (domain == -1) throw sys_error(ENOTSUP); -#ifndef __APPLE__ - auto fd = ::socket(domain, SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0); - if (fd < 0) throw sys_error(errno); -#else - auto fd = ::socket(domain, SOCK_DGRAM, 0); - if (fd < 0) throw sys_error(errno); - int flags = fcntl(fd, F_GETFL, 0); - if (flags < 0 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0 || fcntl(fd, F_SETFD, FD_CLOEXEC) < 0) { - close(fd); - throw std::system_error(errno, std::system_category(), "fcntl failed"); - } -#endif - if (addrtype == address_type::ipv6) { - int opt = 0; - if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) < 0) - throw std::system_error(errno, std::system_category(), "setsockopt failed"); - } + auto fd = io.engine()->socket_create(addrtype, detail::io_engine::socket_type::dgram); return socket(&io, fd); } @@ -101,99 +35,49 @@ namespace asyncpp::io { socket socket::from_fd(io_service& io, detail::io_engine::socket_handle_t fd) { if (fd < 0) throw std::logic_error("invalid socket"); -#ifdef _WIN32 - unsigned long mode = blocking ? 0 : 1; - if (ioctlsocket(fd, FIONBIO, &mode) != SOCKET_ERROR) - throw std::system_error(std::make_error_code(std::errc::io_error), "ioctlsocket failed"); -#else - int flags = fcntl(fd, F_GETFL, 0); - if (flags == -1) throw sys_error(errno); - if ((flags & O_NONBLOCK) != O_NONBLOCK && fcntl(fd, F_SETFL, flags | O_NONBLOCK) != 0) throw sys_error(errno); -#endif + io.engine()->socket_register(fd); socket sock(&io, fd); sock.update_endpoint_info(); return sock; } -#ifndef _WIN32 std::pair socket::connected_pair_tcp(io_service& io, address_type addrtype) { - int domain = -1; - switch (addrtype) { - case address_type::ipv4: domain = AF_INET; break; - case address_type::ipv6: domain = AF_INET6; break; - case address_type::uds: domain = AF_UNIX; break; - } - if (domain == -1) throw sys_error(ENOTSUP); - - int socks[2]; -#ifndef __APPLE__ - if (socketpair(domain, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0, socks) != 0) throw sys_error(errno); -#else - if (socketpair(domain, SOCK_STREAM, 0, socks) != 0) throw sys_error(errno); - int flags0 = fcntl(socks[0], F_GETFL, 0); - int flags1 = fcntl(socks[1], F_GETFL, 0); - if (flags0 < 0 || flags1 < 0 || // - fcntl(socks[0], F_SETFL, flags0 | O_NONBLOCK) < 0 || fcntl(socks[1], F_SETFL, flags1 | O_NONBLOCK) < 0 || // - fcntl(socks[0], F_SETFD, FD_CLOEXEC) < 0 || fcntl(socks[1], F_SETFD, FD_CLOEXEC) < 0) { - close(socks[0]); - close(socks[1]); - throw std::system_error(errno, std::system_category(), "pipe failed"); - } -#endif - std::pair res{socket(&io, socks[0]), socket(&io, socks[1])}; + auto socks = io.engine()->socket_create_connected_pair(addrtype, detail::io_engine::socket_type::stream); + std::pair res{socket(&io, socks.first), socket(&io, socks.second)}; res.first.update_endpoint_info(); res.second.update_endpoint_info(); return res; } std::pair socket::connected_pair_udp(io_service& io, address_type addrtype) { - int domain = -1; - switch (addrtype) { - case address_type::ipv4: domain = AF_INET; break; - case address_type::ipv6: domain = AF_INET6; break; - case address_type::uds: domain = AF_UNIX; break; - } - if (domain == -1) throw sys_error(ENOTSUP); - - int socks[2]; -#ifndef __APPLE__ - if (socketpair(domain, SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0, socks) != 0) throw sys_error(errno); -#else - if (socketpair(domain, SOCK_DGRAM, 0, socks) != 0) throw sys_error(errno); - int flags0 = fcntl(socks[0], F_GETFL, 0); - int flags1 = fcntl(socks[1], F_GETFL, 0); - if (flags0 < 0 || flags1 < 0 || // - fcntl(socks[0], F_SETFL, flags0 | O_NONBLOCK) < 0 || fcntl(socks[1], F_SETFL, flags1 | O_NONBLOCK) < 0 || // - fcntl(socks[0], F_SETFD, FD_CLOEXEC) < 0 || fcntl(socks[1], F_SETFD, FD_CLOEXEC) < 0) { - close(socks[0]); - close(socks[1]); - throw std::system_error(errno, std::system_category(), "pipe failed"); - } -#endif - return {socket(&io, socks[0]), socket(&io, socks[1])}; + auto socks = io.engine()->socket_create_connected_pair(addrtype, detail::io_engine::socket_type::stream); + std::pair res{socket(&io, socks.first), socket(&io, socks.second)}; + res.first.update_endpoint_info(); + res.second.update_endpoint_info(); + return res; } -#endif - socket::socket(io_service* io, int fd) noexcept : m_io{io}, m_fd{fd}, m_remote_ep{}, m_local_ep{} {} + socket::socket(io_service* io, detail::io_engine::socket_handle_t fd) noexcept + : m_io{io}, m_fd{fd}, m_remote_ep{}, m_local_ep{} {} socket::socket(socket&& other) noexcept : m_io{other.m_io}, m_fd{other.m_fd}, m_remote_ep{other.m_remote_ep}, m_local_ep{other.m_local_ep} { other.m_io = nullptr; - other.m_fd = -1; + other.m_fd = detail::io_engine::invalid_socket_handle; other.m_local_ep = {}; other.m_remote_ep = {}; } socket& socket::operator=(socket&& other) noexcept { - if (m_fd >= 0) { - close(m_fd); + if (m_fd != detail::io_engine::invalid_socket_handle) { + m_io->engine()->socket_close(m_fd); // TODO: Log errors returned from close ? - m_fd = -1; + m_fd = detail::io_engine::invalid_socket_handle; } m_io = other.m_io; other.m_io = nullptr; m_fd = other.m_fd; - other.m_fd = -1; + other.m_fd = detail::io_engine::invalid_socket_handle; m_local_ep = other.m_local_ep; other.m_local_ep = {}; m_remote_ep = other.m_remote_ep; @@ -203,68 +87,43 @@ namespace asyncpp::io { } socket::~socket() { - if (m_fd >= 0) { - close(m_fd); + if (m_fd != detail::io_engine::invalid_socket_handle) { + m_io->engine()->socket_close(m_fd); // TODO: Log errors returned from close ? - m_fd = -1; + m_fd = detail::io_engine::invalid_socket_handle; } } void socket::bind(const endpoint& ep) { - if (m_fd < 0) throw std::logic_error("invalid socket"); - - auto sa = ep.to_sockaddr(); - auto res = ::bind(m_fd, reinterpret_cast(&sa.first), sa.second); - if (res < 0) throw sys_error(errno); - + if (m_fd == detail::io_engine::invalid_socket_handle) throw std::logic_error("invalid socket"); + m_io->engine()->socket_bind(m_fd, ep); update_endpoint_info(); } void socket::listen(std::uint32_t backlog) { - if (m_fd < 0) throw std::logic_error("invalid socket"); - - if (backlog == 0) backlog = 20; - auto res = ::listen(m_fd, backlog); - if (res < 0) throw sys_error(errno); + if (m_fd == detail::io_engine::invalid_socket_handle) throw std::logic_error("invalid socket"); + m_io->engine()->socket_listen(m_fd, backlog); } void socket::allow_broadcast(bool enable) { - if (m_fd < 0) throw std::logic_error("invalid socket"); - - int opt = enable ? 1 : 0; - auto res = setsockopt(m_fd, SOL_SOCKET, SO_BROADCAST, &opt, sizeof(opt)); - if (res < 0) throw sys_error(errno); + if (m_fd == detail::io_engine::invalid_socket_handle) throw std::logic_error("invalid socket"); + m_io->engine()->socket_enable_broadcast(m_fd, enable); } void socket::close_send() { - if (m_fd < 0) throw std::logic_error("invalid socket"); - - auto res = ::shutdown(m_fd, SHUT_WR); - if (res < 0 && errno != ENOTCONN) throw sys_error(errno); + if (m_fd == detail::io_engine::invalid_socket_handle) throw std::logic_error("invalid socket"); + m_io->engine()->socket_shutdown(m_fd, false, true); } void socket::close_recv() { - if (m_fd < 0) throw std::logic_error("invalid socket"); - - auto res = ::shutdown(m_fd, SHUT_RD); - if (res < 0 && errno != ENOTCONN) throw sys_error(errno); + if (m_fd == detail::io_engine::invalid_socket_handle) throw std::logic_error("invalid socket"); + m_io->engine()->socket_shutdown(m_fd, true, false); } void socket::update_endpoint_info() { - sockaddr_storage sa; - socklen_t sa_size = sizeof(sa); - auto res = getpeername(m_fd, reinterpret_cast(&sa), &sa_size); - if (res >= 0) - m_remote_ep = endpoint(sa, sa_size); - else if (res < 0 && errno != ENOTCONN) - throw sys_error(errno); - else - m_remote_ep = {}; - - sa_size = sizeof(sa); - res = getsockname(m_fd, reinterpret_cast(&sa), &sa_size); - if (res < 0) throw sys_error(errno); - m_local_ep = endpoint(sa, sa_size); + auto io = m_io->engine(); + m_remote_ep = io->socket_remote_endpoint(m_fd); + m_local_ep = io->socket_local_endpoint(m_fd); } } // namespace asyncpp::io diff --git a/test/address.cpp b/test/address.cpp index 14ab8b9..79231ab 100644 --- a/test/address.cpp +++ b/test/address.cpp @@ -133,6 +133,7 @@ TEST(ASYNCPP_IO, IPv6ToString) { ASSERT_EQ(ipv6_address(0x0102030400000000, 0x090A0B0C00000000).to_string(), "102:304::90a:b0c:0:0"); } +#ifndef _WIN32 TEST(ASYNCPP_IO, UDSParse) { ASSERT_EQ(uds_address::parse(std::string_view("\0", 1)), std::nullopt); ASSERT_EQ(uds_address::parse(std::string_view("@")), std::nullopt); @@ -174,3 +175,4 @@ TEST(ASYNCPP_IO, UDSTypes) { ASSERT_FALSE(uds_address("").is_abstract()); ASSERT_TRUE(uds_address("").is_unnamed()); } +#endif diff --git a/test/socket.cpp b/test/socket.cpp index c805e7e..b8f74cf 100644 --- a/test/socket.cpp +++ b/test/socket.cpp @@ -12,7 +12,7 @@ using asyncpp::task; asyncpp::stop_token timeout(std::chrono::nanoseconds ts) { asyncpp::stop_source source; - asyncpp::timer::get_default().schedule([source](bool) { source.request_stop(); }, ts); + asyncpp::timer::get_default().schedule([source](bool) mutable { source.request_stop(); }, ts); return source.get_token(); } @@ -68,7 +68,7 @@ TEST(ASYNCPP_IO, SocketSelf) { }(service, server, st)); // Connect to said server auto client = socket::create_tcp(service, server.local_endpoint().type()); - co_await client.connect(server.local_endpoint(), st); + co_await client.connect(endpoint(ipv4_address::loopback(), server.local_endpoint().ipv4().port()), st); // and read until connection is closed while (true) { char buf[128]; @@ -114,17 +114,20 @@ TEST(ASYNCPP_IO, SocketValid) { ASSERT_TRUE(sock2.valid()); auto fd = sock2.release(); ASSERT_FALSE(sock2); - close(fd); + service->engine()->socket_close(fd); } -#ifdef __linux__ TEST(ASYNCPP_IO, SocketPair) { io_service service; std::string received; asyncpp::async_launch_scope scope; scope.invoke([&service, &received]() -> task<> { auto stop = timeout(std::chrono::seconds(2)); +#ifdef _WIN32 + auto pair = socket::connected_pair_tcp(service, address_type::ipv4); +#else auto pair = socket::connected_pair_tcp(service, address_type::uds); +#endif co_await pair.first.send("Hello", 5, stop); pair.first.close_send(); while (true) { @@ -142,4 +145,3 @@ TEST(ASYNCPP_IO, SocketPair) { ASSERT_EQ(received.size(), 5); ASSERT_EQ(received, "Hello"); } -#endif diff --git a/test/tls.cpp b/test/tls.cpp index 4100b05..604db1c 100644 --- a/test/tls.cpp +++ b/test/tls.cpp @@ -20,6 +20,9 @@ TEST(ASYNCPP_IO, TLSRoundtrip) { std::cout.sync_with_stdio(true); // Generate cert if missing if (!std::filesystem::exists("ssl.crt") || !std::filesystem::exists("ssl.key")) { +#ifdef _WIN32 + GTEST_SKIP() << "Can not generate certs on windows"; +#endif std::cout << "Generating temporary cert..." << std::endl; system("openssl req -x509 -newkey rsa:2048 -keyout ssl.key -out ssl.crt -sha256 -days 2 -nodes -subj " "\"/C=XX/ST=StateName/L=SomeCity/O=ASYNCPP/OU=ASYNCPP-TEST/CN=server1\""); @@ -113,7 +116,11 @@ TEST(ASYNCPP_IO, TLSRoundtrip) { TEST(ASYNCPP_IO, TLSClient) { std::cout.sync_with_stdio(true); tls::context ctx_client(tls::method::tls, tls::mode::client); - //ctx_client.set_verify(tls::verify_mode::none); +#ifdef _WIN32 + // I am too lazy to figure out the cert locations and + // we only want to test interaction with async io anyway + ctx_client.set_verify(tls::verify_mode::none); +#endif ctx_client.load_verify_locations("", "/etc/ssl/certs/"); ctx_client.set_alpn_protos({"http/1.1"}); tls::session ssl_client(ctx_client);