From db34ef29fd7dacd11180454c581ca906e6608d7f Mon Sep 17 00:00:00 2001 From: momo5502 Date: Thu, 7 Nov 2024 20:25:20 +0100 Subject: [PATCH] Support UDP sending --- src/common/empty.cpp | 0 src/common/network/address.cpp | 383 ++++++++++++++++++ src/common/network/address.hpp | 105 +++++ src/common/network/socket.cpp | 207 ++++++++++ src/common/network/socket.hpp | 60 +++ src/emulator/serialization.hpp | 34 +- src/windows-emulator/devices/afd_endpoint.cpp | 109 ++++- 7 files changed, 883 insertions(+), 15 deletions(-) delete mode 100644 src/common/empty.cpp create mode 100644 src/common/network/address.cpp create mode 100644 src/common/network/address.hpp create mode 100644 src/common/network/socket.cpp create mode 100644 src/common/network/socket.hpp diff --git a/src/common/empty.cpp b/src/common/empty.cpp deleted file mode 100644 index e69de29..0000000 diff --git a/src/common/network/address.cpp b/src/common/network/address.cpp new file mode 100644 index 0000000..837a638 --- /dev/null +++ b/src/common/network/address.cpp @@ -0,0 +1,383 @@ +#include "address.hpp" + +#include + +#include "../utils/finally.hpp" + +using namespace std::literals; + +namespace network +{ + void initialize_wsa() + { +#ifdef _WIN32 + static struct wsa_initializer + { + public: + wsa_initializer() + { + WSADATA wsa_data; + if (WSAStartup(MAKEWORD(2, 2), &wsa_data)) + { + throw std::runtime_error("Unable to initialize WSA"); + } + } + + ~wsa_initializer() + { + WSACleanup(); + } + } _; +#endif + } + + address::address() + { + initialize_wsa(); + ZeroMemory(&this->storage_, this->get_max_size()); + + this->address_.sa_family = AF_UNSPEC; + } + + address::address(const std::string& addr, const std::optional& family) + : address() + { + this->parse(addr, family); + } + + address::address(const sockaddr_in6& addr) + : address() + { + this->address6_ = addr; + } + + address::address(const sockaddr_in& addr) + : address() + { + this->address4_ = addr; + } + + address::address(const sockaddr* addr, const int length) + : address() + { + this->set_address(addr, length); + } + + void address::set_ipv4(const uint32_t ip) + { + in_addr addr{}; + addr.s_addr = ip; + this->set_ipv4(addr); + } + + bool address::operator==(const address& obj) const + { + if (this->address_.sa_family != obj.address_.sa_family) + { + return false; + } + + if (this->get_port() != obj.get_port()) + { + return false; + } + + if (this->address_.sa_family == AF_INET) + { + return this->address4_.sin_addr.s_addr == obj.address4_.sin_addr.s_addr; + } + else if (this->address_.sa_family == AF_INET6) + { + return !memcmp(this->address6_.sin6_addr.s6_addr, obj.address6_.sin6_addr.s6_addr, + sizeof(obj.address6_.sin6_addr.s6_addr)); + } + + return false; + } + + void address::set_ipv4(const in_addr& addr) + { + ZeroMemory(&this->address4_, sizeof(this->address4_)); + this->address4_.sin_family = AF_INET; + this->address4_.sin_addr = addr; + } + + void address::set_ipv6(const in6_addr& addr) + { + ZeroMemory(&this->address6_, sizeof(this->address6_)); + this->address6_.sin6_family = AF_INET6; + this->address6_.sin6_addr = addr; + } + + void address::set_address(const sockaddr* addr, const int length) + { + if (static_cast(length) >= sizeof(sockaddr_in) && addr->sa_family == AF_INET) + { + this->address4_ = *reinterpret_cast(addr); + } + else if (static_cast(length) == sizeof(sockaddr_in6) && addr->sa_family == AF_INET6) + { + this->address6_ = *reinterpret_cast(addr); + } + else + { + throw std::runtime_error("Invalid network address"); + } + } + + void address::set_port(const unsigned short port) + { + switch (this->address_.sa_family) + { + case AF_INET: + this->address4_.sin_port = htons(port); + break; + case AF_INET6: + this->address6_.sin6_port = htons(port); + break; + default: + throw std::runtime_error("Invalid address family"); + } + } + + unsigned short address::get_port() const + { + switch (this->address_.sa_family) + { + case AF_INET: + return ntohs(this->address4_.sin_port); + case AF_INET6: + return ntohs(this->address6_.sin6_port); + default: + return 0; + } + } + + std::string address::to_string() const + { + char buffer[1000] = {0}; + std::string addr; + + switch (this->address_.sa_family) + { + case AF_INET: + inet_ntop(this->address_.sa_family, &this->address4_.sin_addr, buffer, sizeof(buffer)); + addr = std::string(buffer); + break; + case AF_INET6: + inet_ntop(this->address_.sa_family, &this->address6_.sin6_addr, buffer, sizeof(buffer)); + addr = "[" + std::string(buffer) + "]"; + break; + default: + buffer[0] = '?'; + buffer[1] = 0; + addr = std::string(buffer); + break; + } + + return addr + ":"s + std::to_string(this->get_port()); + } + + bool address::is_local() const + { + if (this->address_.sa_family != AF_INET) + { + return false; + } + + // According to: https://en.wikipedia.org/wiki/Private_network + + uint8_t bytes[4]; + *reinterpret_cast(&bytes) = this->address4_.sin_addr.s_addr; + + // 10.X.X.X + if (bytes[0] == 10) + { + return true; + } + + // 192.168.X.X + if (bytes[0] == 192 + && bytes[1] == 168) + { + return true; + } + + // 172.16.X.X - 172.31.X.X + if (bytes[0] == 172 + && bytes[1] >= 16 + && bytes[1] < 32) + { + return true; + } + + // 127.0.0.1 + if (this->address4_.sin_addr.s_addr == 0x0100007F) + { + return true; + } + + return false; + } + + sockaddr& address::get_addr() + { + return this->address_; + } + + const sockaddr& address::get_addr() const + { + return this->address_; + } + + sockaddr_in& address::get_in_addr() + { + return this->address4_; + } + + sockaddr_in6& address::get_in6_addr() + { + return this->address6_; + } + + const sockaddr_in& address::get_in_addr() const + { + return this->address4_; + } + + const sockaddr_in6& address::get_in6_addr() const + { + return this->address6_; + } + + int address::get_size() const + { + switch (this->address_.sa_family) + { + case AF_INET: + return static_cast(sizeof(this->address4_)); + case AF_INET6: + return static_cast(sizeof(this->address6_)); + default: + return static_cast(sizeof(this->address_)); + } + } + + int address::get_max_size() const + { + const auto s = sizeof(this->address_); + const auto s4 = sizeof(this->address4_); + const auto s6 = sizeof(this->address6_); + const auto sstore = sizeof(this->storage_); + const auto max_size = std::max(sstore, std::max(s, std::max(s4, s6))); + static_assert(max_size == sstore); + + return max_size; + } + + bool address::is_ipv4() const + { + return this->address_.sa_family == AF_INET; + } + + bool address::is_ipv6() const + { + return this->address_.sa_family == AF_INET6; + } + + bool address::is_supported() const + { + return is_ipv4() || is_ipv6(); + } + + void address::parse(std::string addr, const std::optional& family) + { + std::optional port_value{}; + + const auto pos = addr.find_last_of(':'); + if (pos != std::string::npos) + { + auto port = addr.substr(pos + 1); + port_value = uint16_t(atoi(port.data())); + addr = addr.substr(0, pos); + } + + this->resolve(addr, family); + + if (port_value) + { + this->set_port(*port_value); + } + } + + void address::resolve(const std::string& hostname, const std::optional& family) + { + const auto port = this->get_port(); + auto port_reset_action = utils::finally([this, port]() + { + this->set_port(port); + }); + + const auto result = resolve_multiple(hostname); + for (const auto& addr : result) + { + if (addr.is_supported() && (!family || addr.get_addr().sa_family == *family)) + { + this->set_address(&addr.get_addr(), addr.get_size()); + return; + } + } + + port_reset_action.cancel(); + throw std::runtime_error{"Unable to resolve hostname: " + hostname}; + } + + std::vector
address::resolve_multiple(const std::string& hostname) + { + std::vector
results{}; + + addrinfo* result = nullptr; + if (!getaddrinfo(hostname.data(), nullptr, nullptr, &result)) + { + const auto _2 = utils::finally([&result]() + { + freeaddrinfo(result); + }); + + for (auto* i = result; i; i = i->ai_next) + { + if (i->ai_family == AF_INET || i->ai_family == AF_INET6) + { + address a{}; + a.set_address(i->ai_addr, static_cast(i->ai_addrlen)); + results.emplace_back(std::move(a)); + } + } + } + + return results; + } +} + +std::size_t std::hash::operator()(const network::address& a) const noexcept +{ + const uint32_t family = a.get_addr().sa_family; + const uint32_t port = a.get_port(); + + std::size_t hash = std::hash{}(family); + hash ^= std::hash{}(port); + switch (a.get_addr().sa_family) + { + case AF_INET: + hash ^= std::hash{}(a.get_in_addr().sin_addr.s_addr); + break; + case AF_INET6: + hash ^= std::hash{}(std::string_view{ + reinterpret_cast(a.get_in6_addr().sin6_addr.s6_addr), + sizeof(a.get_in6_addr().sin6_addr.s6_addr) + }); + break; + } + + return hash; +} diff --git a/src/common/network/address.hpp b/src/common/network/address.hpp new file mode 100644 index 0000000..d7a5ab0 --- /dev/null +++ b/src/common/network/address.hpp @@ -0,0 +1,105 @@ +#pragma once + +#if _WIN32 +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_NO_POSIX_ERROR_CODES +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN +#include +#include +#include +#else + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define ZeroMemory(x, y) memset(x, 0, y) + +#endif + +#include +#include +#include + +#ifdef _WIN32 +#pragma comment(lib, "ws2_32.lib") +#endif + +namespace network +{ + void initialize_wsa(); + + class address + { + public: + address(); + address(const std::string& addr, const std::optional& family = {}); + address(const sockaddr_in& addr); + address(const sockaddr_in6& addr); + address(const sockaddr* addr, int length); + + void set_ipv4(uint32_t ip); + void set_ipv4(const in_addr& addr); + void set_ipv6(const in6_addr& addr); + void set_address(const sockaddr* addr, int length); + + void set_port(unsigned short port); + [[nodiscard]] unsigned short get_port() const; + + sockaddr& get_addr(); + sockaddr_in& get_in_addr(); + sockaddr_in6& get_in6_addr(); + + const sockaddr& get_addr() const; + const sockaddr_in& get_in_addr() const; + const sockaddr_in6& get_in6_addr() const; + + int get_size() const; + int get_max_size() const; + + bool is_ipv4() const; + bool is_ipv6() const; + bool is_supported() const; + + [[nodiscard]] bool is_local() const; + [[nodiscard]] std::string to_string() const; + + bool operator==(const address& obj) const; + + bool operator!=(const address& obj) const + { + return !(*this == obj); + } + + static std::vector
resolve_multiple(const std::string& hostname); + + private: + union + { + sockaddr address_; + sockaddr_in address4_; + sockaddr_in6 address6_; + sockaddr_storage storage_; + }; + + void parse(std::string addr, const std::optional& family = {}); + void resolve(const std::string& hostname, const std::optional& family = {}); + }; +} + +namespace std +{ + template <> + struct hash + { + std::size_t operator()(const network::address& a) const noexcept; + }; +} diff --git a/src/common/network/socket.cpp b/src/common/network/socket.cpp new file mode 100644 index 0000000..2d16217 --- /dev/null +++ b/src/common/network/socket.cpp @@ -0,0 +1,207 @@ +#include "socket.hpp" + +#include + +#ifdef _WIN32 +#define poll WSAPoll +#define SOCK_WOULDBLOCK WSAEWOULDBLOCK +#else +#define SOCK_WOULDBLOCK EWOULDBLOCK +#endif + +using namespace std::literals; + +namespace network +{ + socket::socket(const int af) + : address_family_(af) + { + initialize_wsa(); + this->socket_ = ::socket(af, SOCK_DGRAM, IPPROTO_UDP); + + if (af == AF_INET6) + { + int i = 1; + setsockopt(this->socket_, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&i), + static_cast(sizeof(i))); + } + } + + socket::~socket() + { + if (this->socket_ != INVALID_SOCKET) + { +#ifdef _WIN32 + closesocket(this->socket_); +#else + close(this->socket_); +#endif + } + } + + socket::socket(socket&& obj) noexcept + { + this->operator=(std::move(obj)); + } + + socket& socket::operator=(socket&& obj) noexcept + { + if (this != &obj) + { + this->~socket(); + this->socket_ = obj.socket_; + this->port_ = obj.port_; + this->address_family_ = obj.address_family_; + + obj.socket_ = INVALID_SOCKET; + obj.address_family_ = AF_UNSPEC; + } + + return *this; + } + + bool socket::bind_port(const address& target) + { + const auto result = bind(this->socket_, &target.get_addr(), target.get_size()) == 0; + if (result) + { + this->port_ = target.get_port(); + } + + return result; + } + + bool socket::send(const address& target, const void* data, const size_t size) const + { + const int res = sendto(this->socket_, static_cast(data), static_cast(size), 0, + &target.get_addr(), + target.get_size()); + return res == static_cast(size); + } + + bool socket::send(const address& target, const std::string& data) const + { + return this->send(target, data.data(), data.size()); + } + + bool socket::receive(address& source, std::string& data) const + { + char buffer[0x2000]; + socklen_t len = source.get_max_size(); + + const auto result = recvfrom(this->socket_, buffer, static_cast(sizeof(buffer)), 0, &source.get_addr(), + &len); + if (result == SOCKET_ERROR) + { + return false; + } + + data.assign(buffer, buffer + result); + return true; + } + + bool socket::set_blocking(const bool blocking) + { +#ifdef _WIN32 + unsigned long mode = blocking ? 0 : 1; + return ioctlsocket(this->socket_, FIONBIO, &mode) == 0; +#else + int flags = fcntl(this->socket_, F_GETFL, 0); + if (flags == -1) return false; + flags = blocking ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK); + return fcntl(this->socket_, F_SETFL, flags) == 0; +#endif + } + + bool socket::sleep(const std::chrono::milliseconds timeout) const + { + /*fd_set fdr; + FD_ZERO(&fdr); + FD_SET(this->socket_, &fdr); + + const auto msec = timeout.count(); + + timeval tv{}; + tv.tv_sec = static_cast(msec / 1000ll); + tv.tv_usec = static_cast((msec % 1000) * 1000); + + const auto retval = select(static_cast(this->socket_) + 1, &fdr, nullptr, nullptr, &tv); + if (retval == SOCKET_ERROR) + { + std::this_thread::sleep_for(1ms); + return socket_is_ready; + } + + if (retval > 0) + { + return socket_is_ready; + } + + return !socket_is_ready;*/ + + std::vector sockets{}; + sockets.push_back(this); + + return sleep_sockets(sockets, timeout); + } + + bool socket::sleep_until(const std::chrono::high_resolution_clock::time_point time_point) const + { + const auto duration = time_point - std::chrono::high_resolution_clock::now(); + return this->sleep(std::chrono::duration_cast(duration)); + } + + SOCKET socket::get_socket() const + { + return this->socket_; + } + + uint16_t socket::get_port() const + { + return this->port_; + } + + int socket::get_address_family() const + { + return this->address_family_; + } + + bool socket::sleep_sockets(const std::span& sockets, const std::chrono::milliseconds timeout) + { + std::vector pfds{}; + pfds.resize(sockets.size()); + + for (size_t i = 0; i < sockets.size(); ++i) + { + auto& pfd = pfds.at(i); + const auto& socket = sockets[i]; + + pfd.fd = socket->get_socket(); + pfd.events = POLLIN; + pfd.revents = 0; + } + + const auto retval = poll(pfds.data(), static_cast(pfds.size()), + static_cast(timeout.count())); + + if (retval == SOCKET_ERROR) + { + std::this_thread::sleep_for(1ms); + return socket_is_ready; + } + + if (retval > 0) + { + return socket_is_ready; + } + + return !socket_is_ready; + } + + bool socket::sleep_sockets_until(const std::span& sockets, + const std::chrono::high_resolution_clock::time_point time_point) + { + const auto duration = time_point - std::chrono::high_resolution_clock::now(); + return sleep_sockets(sockets, std::chrono::duration_cast(duration)); + } +} diff --git a/src/common/network/socket.hpp b/src/common/network/socket.hpp new file mode 100644 index 0000000..3eda5cb --- /dev/null +++ b/src/common/network/socket.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include "address.hpp" + +#include +#include + +#ifdef _WIN32 +using socklen_t = int; +#define GET_SOCKET_ERROR() (WSAGetLastError()) +#else + using SOCKET = int; +#define INVALID_SOCKET (SOCKET)(~0) +#define SOCKET_ERROR (-1) +#define GET_SOCKET_ERROR() (errno) +#endif + +namespace network +{ + class socket + { + public: + socket() = default; + + socket(int af); + ~socket(); + + socket(const socket& obj) = delete; + socket& operator=(const socket& obj) = delete; + + socket(socket&& obj) noexcept; + socket& operator=(socket&& obj) noexcept; + + bool bind_port(const address& target); + + [[maybe_unused]] bool send(const address& target, const void* data, size_t size) const; + [[maybe_unused]] bool send(const address& target, const std::string& data) const; + bool receive(address& source, std::string& data) const; + + bool set_blocking(bool blocking); + + static constexpr bool socket_is_ready = true; + bool sleep(std::chrono::milliseconds timeout) const; + bool sleep_until(std::chrono::high_resolution_clock::time_point time_point) const; + + SOCKET get_socket() const; + uint16_t get_port() const; + + int get_address_family() const; + + static bool sleep_sockets(const std::span& sockets, std::chrono::milliseconds timeout); + static bool sleep_sockets_until(const std::span& sockets, + std::chrono::high_resolution_clock::time_point time_point); + + private: + int address_family_{AF_UNSPEC}; + uint16_t port_ = 0; + SOCKET socket_ = INVALID_SOCKET; + }; +} diff --git a/src/emulator/serialization.hpp b/src/emulator/serialization.hpp index 0b6f916..f40625e 100644 --- a/src/emulator/serialization.hpp +++ b/src/emulator/serialization.hpp @@ -63,15 +63,16 @@ namespace utils { public: template - buffer_deserializer(const std::span buffer) - : buffer_(reinterpret_cast(buffer.data()), buffer.size() * sizeof(T)) + buffer_deserializer(const std::span buffer, bool no_debugging = false) + : no_debugging_(no_debugging) + , buffer_(reinterpret_cast(buffer.data()), buffer.size() * sizeof(T)) { static_assert(std::is_trivially_copyable_v, "Type must be trivially copyable"); } template - buffer_deserializer(const std::vector& buffer) - : buffer_deserializer(std::span(buffer)) + buffer_deserializer(const std::vector& buffer, bool no_debugging = false) + : buffer_deserializer(std::span(buffer), no_debugging) { } @@ -79,6 +80,7 @@ namespace utils { #ifndef NDEBUG const uint64_t real_old_size = this->offset_; + (void)real_old_size; #endif if (this->offset_ + length > this->buffer_.size()) @@ -91,19 +93,22 @@ namespace utils #ifndef NDEBUG - uint64_t old_size{}; - if (this->offset_ + sizeof(old_size) > this->buffer_.size()) + if (!this->no_debugging_) { - throw std::runtime_error("Out of bounds read from byte buffer"); - } + uint64_t old_size{}; + if (this->offset_ + sizeof(old_size) > this->buffer_.size()) + { + throw std::runtime_error("Out of bounds read from byte buffer"); + } - memcpy(&old_size, this->buffer_.data() + this->offset_, sizeof(old_size)); - if (old_size != real_old_size) - { - throw std::runtime_error("Reading from serialized buffer mismatches written data!"); - } + memcpy(&old_size, this->buffer_.data() + this->offset_, sizeof(old_size)); + if (old_size != real_old_size) + { + throw std::runtime_error("Reading from serialized buffer mismatches written data!"); + } - this->offset_ += sizeof(old_size); + this->offset_ += sizeof(old_size); + } #endif return result; @@ -276,6 +281,7 @@ namespace utils } private: + bool no_debugging_{false}; size_t offset_{0}; std::span buffer_{}; std::unordered_map> factories_{}; diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index 3657cbe..1c2f301 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -1,14 +1,121 @@ #include "afd_endpoint.hpp" -#include "windows-emulator/windows_emulator.hpp" +#include "../windows_emulator.hpp" +#include +#include + + +typedef LONG TDI_STATUS; +typedef PVOID CONNECTION_CONTEXT; + +typedef struct _TDI_CONNECTION_INFORMATION +{ + LONG UserDataLength; + PVOID UserData; + LONG OptionsLength; + PVOID Options; + LONG RemoteAddressLength; + PVOID RemoteAddress; +} TDI_CONNECTION_INFORMATION, *PTDI_CONNECTION_INFORMATION; + +typedef struct _TDI_REQUEST +{ + union + { + HANDLE AddressHandle; + CONNECTION_CONTEXT ConnectionContext; + HANDLE ControlChannel; + } Handle; + + PVOID RequestNotifyObject; + PVOID RequestContext; + TDI_STATUS TdiStatus; +} TDI_REQUEST, *PTDI_REQUEST; + +typedef struct _TDI_REQUEST_SEND_DATAGRAM +{ + TDI_REQUEST Request; + PTDI_CONNECTION_INFORMATION SendDatagramInformation; +} TDI_REQUEST_SEND_DATAGRAM, *PTDI_REQUEST_SEND_DATAGRAM; + +typedef struct _AFD_SEND_DATAGRAM_INFO +{ + LPWSABUF BufferArray; + ULONG BufferCount; + ULONG AfdFlags; + TDI_REQUEST_SEND_DATAGRAM TdiRequest; + TDI_CONNECTION_INFORMATION TdiConnInfo; +} AFD_SEND_DATAGRAM_INFO, *PAFD_SEND_DATAGRAM_INFO; namespace { struct afd_endpoint : stateless_device { + network::socket s{AF_INET}; + NTSTATUS io_control(const io_device_context& c) override { c.win_emu.logger.print(color::cyan, "AFD IOCTL: %X\n", c.io_control_code); + + switch (c.io_control_code) + { + case 0x12003: + return this->ioctl_bind(c); + case 0x12023: + return this->ioctl_send_datagram(c); + case 0x12047: // ? + case 0x1207B: // ? + return STATUS_SUCCESS; + } + + return STATUS_SUCCESS; + } + + NTSTATUS ioctl_bind(const io_device_context& c) + { + std::vector data{}; + data.resize(c.input_buffer_length); + c.emu.read_memory(c.input_buffer, data.data(), c.input_buffer_length); + + utils::buffer_deserializer deserializer{data, true}; + deserializer.read(); // IDK :( + const network::address addr = deserializer.read(); + + if (!this->s.bind_port(addr)) + { + return STATUS_ADDRESS_ALREADY_ASSOCIATED; + } + + return STATUS_SUCCESS; + } + + NTSTATUS ioctl_send_datagram(const io_device_context& c) + { + if (c.input_buffer_length < sizeof(AFD_SEND_DATAGRAM_INFO)) + { + return STATUS_BUFFER_TOO_SMALL; + } + + const auto send_info = emulator_object{c.emu, c.input_buffer}.read(); + const auto buffer = emulator_object{c.emu, send_info.BufferArray}.read(0); + + std::vector address{}; + address.resize(send_info.TdiConnInfo.RemoteAddressLength); + c.emu.read_memory(reinterpret_cast(send_info.TdiConnInfo.RemoteAddress), address.data(), + address.size()); + + const network::address target(reinterpret_cast(address.data()), + static_cast(address.size())); + + std::vector data{}; + data.resize(buffer.len); + c.emu.read_memory(reinterpret_cast(buffer.buf), data.data(), data.size()); + + if (!s.send(target, data.data(), data.size())) + { + return STATUS_CONNECTION_REFUSED; + } + return STATUS_SUCCESS; } };