diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index dfad092..4c24a72 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -2,6 +2,7 @@ #include "afd_types.hpp" #include "../windows_emulator.hpp" +#include "../syscall_utils.hpp" #include #include @@ -31,11 +32,159 @@ namespace return win_emu.emu().read_memory(data.buffer); } + std::pair> get_poll_info( + windows_emulator& win_emu, const io_device_context& c) + { + constexpr auto info_size = offsetof(AFD_POLL_INFO, Handles); + if (!c.input_buffer || c.input_buffer_length < info_size) + { + throw std::runtime_error("Bad AFD poll data"); + } + + AFD_POLL_INFO poll_info{}; + win_emu.emu().read_memory(c.input_buffer, &poll_info, info_size); + + std::vector handle_info{}; + + const emulator_object handle_info_obj{win_emu.emu(), c.input_buffer + info_size}; + + if (c.input_buffer_length < (info_size + sizeof(AFD_POLL_HANDLE_INFO) * poll_info.NumberOfHandles)) + { + throw std::runtime_error("Bad AFD poll handle data"); + } + + for (ULONG i = 0; i < poll_info.NumberOfHandles; ++i) + { + handle_info.emplace_back(handle_info_obj.read(i)); + } + + return {std::move(poll_info), std::move(handle_info)}; + } + + int16_t map_afd_request_events_to_socket(const ULONG poll_events) + { + int16_t socket_events{}; + + if (poll_events & (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE)) + { + socket_events |= POLLRDNORM; + } + + if (poll_events & AFD_POLL_RECEIVE_EXPEDITED) + { + socket_events |= POLLRDNORM; + } + + if (poll_events & AFD_POLL_RECEIVE_EXPEDITED) + { + socket_events |= POLLRDBAND; + } + + if (poll_events & (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND)) + { + socket_events |= POLLWRNORM; + } + + return socket_events; + } + + ULONG map_socket_response_events_to_afd(const int16_t socket_events) + { + ULONG afd_events = 0; + + if (socket_events & POLLRDNORM) + { + afd_events |= (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE); + } + + if (socket_events & POLLRDBAND) + { + afd_events |= AFD_POLL_RECEIVE_EXPEDITED; + } + + if (socket_events & POLLWRNORM) + { + afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND); + } + + if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR)) + { + afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT); + } + else if (socket_events & POLLHUP) + { + afd_events |= AFD_POLL_DISCONNECT; + } + + if (socket_events & POLLNVAL) + { + afd_events |= AFD_POLL_LOCAL_CLOSE; + } + + return afd_events; + } + + NTSTATUS perform_poll(windows_emulator& win_emu, const io_device_context& c, + const std::span endpoints, + const std::span handles) + { + std::vector poll_data{}; + poll_data.resize(endpoints.size()); + + for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i) + { + auto& pfd = poll_data.at(i); + auto& handle = handles[i]; + + pfd.fd = endpoints[i]; + pfd.events = map_afd_request_events_to_socket(handle.PollEvents); + pfd.revents = pfd.events; + } + + const auto count = poll(poll_data.data(), static_cast(poll_data.size()), 0); + if (count <= 0) + { + return STATUS_PENDING; + } + + constexpr auto info_size = offsetof(AFD_POLL_INFO, Handles); + const emulator_object handle_info_obj{win_emu.emu(), c.input_buffer + info_size}; + + size_t current_index = 0; + + for (size_t i = 0; i < endpoints.size(); ++i) + { + const auto& pfd = poll_data.at(i); + if (pfd.revents == 0) + { + continue; + } + + auto entry = handle_info_obj.read(i); + entry.PollEvents = map_socket_response_events_to_afd(pfd.revents); + entry.Status = STATUS_SUCCESS; + + handle_info_obj.write(entry, current_index++); + break; + } + + assert(current_index == static_cast(count)); + + emulator_object{win_emu.emu(), c.input_buffer}.access([&](AFD_POLL_INFO& info) + { + info.NumberOfHandles = static_cast(current_index); + }); + + return STATUS_SUCCESS; + } + struct afd_endpoint : io_device { - bool in_poll{}; - std::optional s{}; - std::optional delayed_ioctl{}; + bool executing_delayed_ioctl_{}; + std::optional s_{}; + std::optional require_poll_{}; + std::optional delayed_ioctl_{}; + std::optional timeout_{}; afd_endpoint() { @@ -47,9 +196,9 @@ namespace ~afd_endpoint() override { - if (this->s) + if (this->s_) { - closesocket(*this->s); + closesocket(*this->s_); } } @@ -65,31 +214,70 @@ namespace network::socket::set_blocking(sock, false); - s = sock; + s_ = sock; } - void work(windows_emulator& win_emu) override + void delay_ioctrl(const io_device_context& c, + const std::optional timeout = {}, + const std::optional require_poll = {}) { - if (!this->delayed_ioctl || !this->s) + if (this->executing_delayed_ioctl_) { return; } - const auto is_ready = network::socket::is_socket_ready(*this->s, this->in_poll); - if (!is_ready) + this->timeout_ = timeout; + this->require_poll_ = require_poll; + this->delayed_ioctl_ = c; + } + + void clear_pending_state() + { + this->timeout_ = {}; + this->require_poll_ = {}; + this->delayed_ioctl_ = {}; + } + + void work(windows_emulator& win_emu) override + { + if (!this->delayed_ioctl_ || !this->s_) { return; } - this->execute_ioctl(win_emu, *this->delayed_ioctl); + this->executing_delayed_ioctl_ = true; + const auto _ = utils::finally([&] + { + this->executing_delayed_ioctl_ = false; + }); + + if (this->require_poll_.has_value()) + { + const auto is_ready = network::socket::is_socket_ready(*this->s_, *this->require_poll_); + if (!is_ready) + { + return; + } + } + + const auto status = this->execute_ioctl(win_emu, *this->delayed_ioctl_); + if (status == STATUS_PENDING) + { + if (!this->timeout_ || this->timeout_ > std::chrono::steady_clock::now()) + { + return; + } + + write_io_status(this->delayed_ioctl_->io_status_block, STATUS_TIMEOUT); + } - auto* e = win_emu.process().events.get(this->delayed_ioctl->event); + auto* e = win_emu.process().events.get(this->delayed_ioctl_->event); if (e) { e->signaled = true; } - this->delayed_ioctl = {}; + this->clear_pending_state(); } void deserialize(utils::buffer_deserializer&) override @@ -122,14 +310,15 @@ namespace return this->ioctl_send_datagram(win_emu, c); case AFD_RECEIVE_DATAGRAM: return this->ioctl_receive_datagram(win_emu, c); + case AFD_POLL: + return this->ioctl_poll(win_emu, c); case AFD_SET_CONTEXT: - return STATUS_SUCCESS; case AFD_GET_INFORMATION: return STATUS_SUCCESS; + default: + win_emu.logger.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code); + return STATUS_NOT_SUPPORTED; } - - win_emu.logger.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code); - return STATUS_NOT_SUPPORTED; } NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const @@ -148,7 +337,7 @@ namespace const network::address addr(address, address_size); - if (bind(*this->s, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR) + if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR) { return STATUS_ADDRESS_ALREADY_ASSOCIATED; } @@ -156,6 +345,59 @@ namespace return STATUS_SUCCESS; } + static std::vector resolve_endpoints(windows_emulator& win_emu, + const std::span handles) + { + auto& proc = win_emu.process(); + + std::vector endpoints{}; + endpoints.reserve(handles.size()); + + for (const auto& handle : handles) + { + auto* device = proc.devices.get(reinterpret_cast(handle.Handle)); + if (!device) + { + throw std::runtime_error("Bad device!"); + } + + const auto* endpoint = device->get_internal_device(); + if (!endpoint) + { + throw std::runtime_error("Device is not an AFD endpoint!"); + } + + endpoints.push_back(*endpoint->s_); + } + + return endpoints; + } + + NTSTATUS ioctl_poll(windows_emulator& win_emu, const io_device_context& c) + { + const auto [info, handles] = get_poll_info(win_emu, c); + const auto endpoints = resolve_endpoints(win_emu, handles); + + const auto status = perform_poll(win_emu, c, endpoints, handles); + if (status != STATUS_PENDING) + { + return status; + } + + if (!this->executing_delayed_ioctl_) + { + std::optional timeout{}; + if (info.Timeout.QuadPart) + { + timeout = convert_delay_interval_to_time_point(info.Timeout); + } + + this->delay_ioctrl(c, timeout); + } + + return STATUS_PENDING; + } + NTSTATUS ioctl_receive_datagram(windows_emulator& win_emu, const io_device_context& c) { auto& emu = win_emu.emu(); @@ -188,7 +430,7 @@ namespace std::vector data{}; data.resize(buffer.len); - const auto recevied_data = recvfrom(*this->s, data.data(), static_cast(data.size()), 0, + const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast(data.size()), 0, reinterpret_cast(address.data()), &fromlength); if (recevied_data < 0) @@ -196,8 +438,7 @@ namespace const auto error = GET_SOCKET_ERROR(); if (error == SOCK_WOULDBLOCK) { - this->in_poll = true; - this->delayed_ioctl = c; + this->delay_ioctrl(c, {}, true); return STATUS_PENDING; } @@ -225,7 +466,7 @@ namespace NTSTATUS ioctl_send_datagram(windows_emulator& win_emu, const io_device_context& c) { - auto& emu = win_emu.emu(); + const auto& emu = win_emu.emu(); if (c.input_buffer_length < sizeof(AFD_SEND_DATAGRAM_INFO)) { @@ -243,7 +484,7 @@ namespace const auto data = emu.read_memory(buffer.buf, buffer.len); - const auto sent_data = sendto(*this->s, reinterpret_cast(data.data()), + const auto sent_data = sendto(*this->s_, reinterpret_cast(data.data()), static_cast(data.size()), 0 /* ? */, &target.get_addr(), target.get_size()); @@ -252,8 +493,7 @@ namespace const auto error = GET_SOCKET_ERROR(); if (error == SOCK_WOULDBLOCK) { - this->in_poll = false; - this->delayed_ioctl = c; + this->delay_ioctrl(c, {}, false); return STATUS_PENDING; } diff --git a/src/windows-emulator/devices/afd_types.hpp b/src/windows-emulator/devices/afd_types.hpp index a168380..944ae5e 100644 --- a/src/windows-emulator/devices/afd_types.hpp +++ b/src/windows-emulator/devices/afd_types.hpp @@ -70,6 +70,47 @@ typedef struct _AFD_RECV_DATAGRAM_INFO PULONG AddressLength; } AFD_RECV_DATAGRAM_INFO, *PAFD_RECV_DATAGRAM_INFO; +typedef struct _AFD_POLL_HANDLE_INFO +{ + HANDLE Handle; + ULONG PollEvents; + NTSTATUS Status; +} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO; + +typedef struct _AFD_POLL_INFO +{ + LARGE_INTEGER Timeout; + ULONG NumberOfHandles; + BOOLEAN Unique; + AFD_POLL_HANDLE_INFO Handles[1]; +} AFD_POLL_INFO, *PAFD_POLL_INFO; + +#define AFD_POLL_RECEIVE_BIT 0 +#define AFD_POLL_RECEIVE (1 << AFD_POLL_RECEIVE_BIT) +#define AFD_POLL_RECEIVE_EXPEDITED_BIT 1 +#define AFD_POLL_RECEIVE_EXPEDITED (1 << AFD_POLL_RECEIVE_EXPEDITED_BIT) +#define AFD_POLL_SEND_BIT 2 +#define AFD_POLL_SEND (1 << AFD_POLL_SEND_BIT) +#define AFD_POLL_DISCONNECT_BIT 3 +#define AFD_POLL_DISCONNECT (1 << AFD_POLL_DISCONNECT_BIT) +#define AFD_POLL_ABORT_BIT 4 +#define AFD_POLL_ABORT (1 << AFD_POLL_ABORT_BIT) +#define AFD_POLL_LOCAL_CLOSE_BIT 5 +#define AFD_POLL_LOCAL_CLOSE (1 << AFD_POLL_LOCAL_CLOSE_BIT) +#define AFD_POLL_CONNECT_BIT 6 +#define AFD_POLL_CONNECT (1 << AFD_POLL_CONNECT_BIT) +#define AFD_POLL_ACCEPT_BIT 7 +#define AFD_POLL_ACCEPT (1 << AFD_POLL_ACCEPT_BIT) +#define AFD_POLL_CONNECT_FAIL_BIT 8 +#define AFD_POLL_CONNECT_FAIL (1 << AFD_POLL_CONNECT_FAIL_BIT) +#define AFD_POLL_QOS_BIT 9 +#define AFD_POLL_QOS (1 << AFD_POLL_QOS_BIT) +#define AFD_POLL_GROUP_QOS_BIT 10 +#define AFD_POLL_GROUP_QOS (1 << AFD_POLL_GROUP_QOS_BIT) + +#define AFD_NUM_POLL_EVENTS 11 +#define AFD_POLL_ALL ((1 << AFD_NUM_POLL_EVENTS) - 1) + #define _AFD_REQUEST(ioctl) \ ((((ULONG)(ioctl)) >> 2) & 0x03FF) #define _AFD_BASE(ioctl) \ diff --git a/src/windows-emulator/io_device.hpp b/src/windows-emulator/io_device.hpp index 01986f6..fe06c74 100644 --- a/src/windows-emulator/io_device.hpp +++ b/src/windows-emulator/io_device.hpp @@ -29,6 +29,17 @@ struct io_device_creation_data uint32_t length; }; +inline void write_io_status(const emulator_object io_status_block, const NTSTATUS status) +{ + if (io_status_block) + { + io_status_block.access([&](IO_STATUS_BLOCK& status_block) + { + status_block.Status = status; + }); + } +} + struct io_device { io_device() = default; @@ -64,15 +75,7 @@ struct io_device } const auto result = this->io_control(win_emu, c); - - if (c.io_status_block) - { - c.io_status_block.access([&](IO_STATUS_BLOCK& status) - { - status.Status = result; - }); - } - + write_io_status(c.io_status_block, result); return result; } }; diff --git a/src/windows-emulator/process_context.hpp b/src/windows-emulator/process_context.hpp index 4bb6a23..8fca160 100644 --- a/src/windows-emulator/process_context.hpp +++ b/src/windows-emulator/process_context.hpp @@ -26,6 +26,8 @@ #define GDT_LIMIT 0x1000 #define GDT_ENTRY_SIZE 0x8 +class windows_emulator; + struct ref_counted_object { uint32_t ref_count{1}; @@ -234,7 +236,7 @@ class emulator_thread : ref_counted_object return this->await_time.has_value() && this->await_time.value() < std::chrono::steady_clock::now(); } - bool is_thread_ready(process_context& context); + bool is_thread_ready(windows_emulator& win_emu); void save(x64_emulator& emu) { diff --git a/src/windows-emulator/windows_emulator.cpp b/src/windows-emulator/windows_emulator.cpp index f8b89fd..f23277c 100644 --- a/src/windows-emulator/windows_emulator.cpp +++ b/src/windows-emulator/windows_emulator.cpp @@ -520,7 +520,7 @@ namespace auto& emu = win_emu.emu(); auto& context = win_emu.process(); - if (!thread.is_thread_ready(context)) + if (!thread.is_thread_ready(win_emu)) { return false; } @@ -679,7 +679,7 @@ void emulator_thread::mark_as_ready(const NTSTATUS status) this->waiting_for_alert = false; } -bool emulator_thread::is_thread_ready(process_context& context) +bool emulator_thread::is_thread_ready(windows_emulator& win_emu) { if (this->exit_status.has_value()) { @@ -704,7 +704,7 @@ bool emulator_thread::is_thread_ready(process_context& context) if (this->await_object.has_value()) { - if (is_object_signaled(context, *this->await_object)) + if (is_object_signaled(win_emu.process(), *this->await_object)) { this->mark_as_ready(STATUS_WAIT_0); return true;