diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index 8820c4c..4c24a72 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -61,116 +61,130 @@ namespace return {std::move(poll_info), std::move(handle_info)}; } - std::optional perform_poll(windows_emulator& win_emu, const io_device_context& c, - const std::span endpoints, - const std::span handles) + int16_t map_afd_request_events_to_socket(const ULONG poll_events) { - std::vector poll_data{}; - poll_data.resize(endpoints.size()); + int16_t socket_events{}; - for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i) + if (poll_events & (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE)) { - auto& pfd = poll_data.at(i); - auto& handle = handles[i]; + socket_events |= POLLRDNORM; + } - if (handle.PollEvents & (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE)) - { - pfd.events |= POLLRDNORM; - } + if (poll_events & AFD_POLL_RECEIVE_EXPEDITED) + { + socket_events |= POLLRDNORM; + } - if (handle.PollEvents & AFD_POLL_RECEIVE_EXPEDITED) - { - pfd.events |= POLLRDNORM; - } + if (poll_events & AFD_POLL_RECEIVE_EXPEDITED) + { + socket_events |= POLLRDBAND; + } - if (handle.PollEvents & AFD_POLL_RECEIVE_EXPEDITED) - { - pfd.events |= POLLRDBAND; - } + if (poll_events & (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND)) + { + socket_events |= POLLWRNORM; + } - if (handle.PollEvents & (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND)) - { - pfd.events |= POLLWRNORM; - } + return socket_events; + } - pfd.fd = endpoints[i]; - pfd.events = POLLIN; - pfd.revents = pfd.events; - } + ULONG map_socket_response_events_to_afd(const int16_t socket_events) + { + ULONG afd_events = 0; - const auto count = poll(poll_data.data(), static_cast(poll_data.size()), 0); - if (count > 0) + if (socket_events & POLLRDNORM) { - constexpr auto info_size = offsetof(AFD_POLL_INFO, Handles); - const emulator_object handle_info_obj{win_emu.emu(), c.input_buffer + info_size}; + afd_events |= (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE); + } - size_t current_index = 0; + if (socket_events & POLLRDBAND) + { + afd_events |= AFD_POLL_RECEIVE_EXPEDITED; + } - for (size_t i = 0; i < endpoints.size(); ++i) - { - const auto& pfd = poll_data.at(i); - if (pfd.revents == 0) - { - continue; - } + if (socket_events & POLLWRNORM) + { + afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND); + } - ULONG events = 0; + if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR)) + { + afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT); + } + else if (socket_events & POLLHUP) + { + afd_events |= AFD_POLL_DISCONNECT; + } - if (pfd.revents & POLLRDNORM) - { - events |= (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE); - } + if (socket_events & POLLNVAL) + { + afd_events |= AFD_POLL_LOCAL_CLOSE; + } - if (pfd.revents & POLLRDBAND) - { - events |= AFD_POLL_RECEIVE_EXPEDITED; - } + return afd_events; + } - if (pfd.revents & POLLWRNORM) - { - events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND); - } + NTSTATUS perform_poll(windows_emulator& win_emu, const io_device_context& c, + const std::span endpoints, + const std::span handles) + { + std::vector poll_data{}; + poll_data.resize(endpoints.size()); - if ((pfd.revents & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR)) - { - events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT); - } - else if (pfd.revents & POLLHUP) - { - events |= AFD_POLL_DISCONNECT; - } + for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i) + { + auto& pfd = poll_data.at(i); + auto& handle = handles[i]; - if (pfd.revents & POLLNVAL) - { - events |= AFD_POLL_LOCAL_CLOSE; - } + pfd.fd = endpoints[i]; + pfd.events = map_afd_request_events_to_socket(handle.PollEvents); + pfd.revents = pfd.events; + } - auto entry = handle_info_obj.read(i); - entry.PollEvents = events; - entry.Status = STATUS_SUCCESS; + const auto count = poll(poll_data.data(), static_cast(poll_data.size()), 0); + if (count <= 0) + { + return STATUS_PENDING; + } - handle_info_obj.write(entry, current_index++); - break; - } + constexpr auto info_size = offsetof(AFD_POLL_INFO, Handles); + const emulator_object handle_info_obj{win_emu.emu(), c.input_buffer + info_size}; - assert(current_index == static_cast(count)); + size_t current_index = 0; - emulator_object{win_emu.emu(), c.input_buffer}.access([&](AFD_POLL_INFO& info) + for (size_t i = 0; i < endpoints.size(); ++i) + { + const auto& pfd = poll_data.at(i); + if (pfd.revents == 0) { - info.NumberOfHandles = count; - }); + continue; + } - return STATUS_SUCCESS; + auto entry = handle_info_obj.read(i); + entry.PollEvents = map_socket_response_events_to_afd(pfd.revents); + entry.Status = STATUS_SUCCESS; + + handle_info_obj.write(entry, current_index++); + break; } - return {}; + assert(current_index == static_cast(count)); + + emulator_object{win_emu.emu(), c.input_buffer}.access([&](AFD_POLL_INFO& info) + { + info.NumberOfHandles = static_cast(current_index); + }); + + return STATUS_SUCCESS; } struct afd_endpoint : io_device { - bool in_poll{}; - std::optional s{}; - std::optional delayed_ioctl{}; + bool executing_delayed_ioctl_{}; + std::optional s_{}; + std::optional require_poll_{}; + std::optional delayed_ioctl_{}; + std::optional timeout_{}; afd_endpoint() { @@ -182,9 +196,9 @@ namespace ~afd_endpoint() override { - if (this->s) + if (this->s_) { - closesocket(*this->s); + closesocket(*this->s_); } } @@ -200,31 +214,70 @@ namespace network::socket::set_blocking(sock, false); - s = sock; + s_ = sock; } - void work(windows_emulator& win_emu) override + void delay_ioctrl(const io_device_context& c, + const std::optional timeout = {}, + const std::optional require_poll = {}) { - if (!this->delayed_ioctl || !this->s) + if (this->executing_delayed_ioctl_) { return; } - const auto is_ready = network::socket::is_socket_ready(*this->s, this->in_poll); - if (!is_ready) + this->timeout_ = timeout; + this->require_poll_ = require_poll; + this->delayed_ioctl_ = c; + } + + void clear_pending_state() + { + this->timeout_ = {}; + this->require_poll_ = {}; + this->delayed_ioctl_ = {}; + } + + void work(windows_emulator& win_emu) override + { + if (!this->delayed_ioctl_ || !this->s_) { return; } - this->execute_ioctl(win_emu, *this->delayed_ioctl); + this->executing_delayed_ioctl_ = true; + const auto _ = utils::finally([&] + { + this->executing_delayed_ioctl_ = false; + }); + + if (this->require_poll_.has_value()) + { + const auto is_ready = network::socket::is_socket_ready(*this->s_, *this->require_poll_); + if (!is_ready) + { + return; + } + } + + const auto status = this->execute_ioctl(win_emu, *this->delayed_ioctl_); + if (status == STATUS_PENDING) + { + if (!this->timeout_ || this->timeout_ > std::chrono::steady_clock::now()) + { + return; + } - auto* e = win_emu.process().events.get(this->delayed_ioctl->event); + write_io_status(this->delayed_ioctl_->io_status_block, STATUS_TIMEOUT); + } + + auto* e = win_emu.process().events.get(this->delayed_ioctl_->event); if (e) { e->signaled = true; } - this->delayed_ioctl = {}; + this->clear_pending_state(); } void deserialize(utils::buffer_deserializer&) override @@ -258,7 +311,7 @@ namespace case AFD_RECEIVE_DATAGRAM: return this->ioctl_receive_datagram(win_emu, c); case AFD_POLL: - return ioctl_poll(win_emu, c); + return this->ioctl_poll(win_emu, c); case AFD_SET_CONTEXT: case AFD_GET_INFORMATION: return STATUS_SUCCESS; @@ -284,7 +337,7 @@ namespace const network::address addr(address, address_size); - if (bind(*this->s, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR) + if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR) { return STATUS_ADDRESS_ALREADY_ASSOCIATED; } @@ -314,51 +367,34 @@ namespace throw std::runtime_error("Device is not an AFD endpoint!"); } - endpoints.push_back(*endpoint->s); + endpoints.push_back(*endpoint->s_); } return endpoints; } - static bool is_poll_done(windows_emulator& win_emu, const io_device_context& c, emulator_thread& t, - const std::optional timeout) + NTSTATUS ioctl_poll(windows_emulator& win_emu, const io_device_context& c) { const auto [info, handles] = get_poll_info(win_emu, c); const auto endpoints = resolve_endpoints(win_emu, handles); const auto status = perform_poll(win_emu, c, endpoints, handles); - if (status) + if (status != STATUS_PENDING) { - t.pending_status = *status; - return true; + return status; } - if (timeout && *timeout < std::chrono::steady_clock::now()) + if (!this->executing_delayed_ioctl_) { - t.pending_status = STATUS_TIMEOUT; - return true; - } - - return false; - } - - static NTSTATUS ioctl_poll(windows_emulator& win_emu, const io_device_context& c) - { - const auto [info, handles] = get_poll_info(win_emu, c); - (void)resolve_endpoints(win_emu, handles); + std::optional timeout{}; + if (info.Timeout.QuadPart) + { + timeout = convert_delay_interval_to_time_point(info.Timeout); + } - std::optional timeout{}; - if (info.Timeout.QuadPart) - { - timeout = convert_delay_interval_to_time_point(info.Timeout); + this->delay_ioctrl(c, timeout); } - win_emu.process().active_thread->thread_blocker = [timeout, c](windows_emulator& emu, emulator_thread& t) - { - return is_poll_done(emu, c, t, timeout); - }; - - win_emu.yield_thread(); return STATUS_PENDING; } @@ -394,7 +430,7 @@ namespace std::vector data{}; data.resize(buffer.len); - const auto recevied_data = recvfrom(*this->s, data.data(), static_cast(data.size()), 0, + const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast(data.size()), 0, reinterpret_cast(address.data()), &fromlength); if (recevied_data < 0) @@ -402,8 +438,7 @@ namespace const auto error = GET_SOCKET_ERROR(); if (error == SOCK_WOULDBLOCK) { - this->in_poll = true; - this->delayed_ioctl = c; + this->delay_ioctrl(c, {}, true); return STATUS_PENDING; } @@ -431,7 +466,7 @@ namespace NTSTATUS ioctl_send_datagram(windows_emulator& win_emu, const io_device_context& c) { - auto& emu = win_emu.emu(); + const auto& emu = win_emu.emu(); if (c.input_buffer_length < sizeof(AFD_SEND_DATAGRAM_INFO)) { @@ -449,7 +484,7 @@ namespace const auto data = emu.read_memory(buffer.buf, buffer.len); - const auto sent_data = sendto(*this->s, reinterpret_cast(data.data()), + const auto sent_data = sendto(*this->s_, reinterpret_cast(data.data()), static_cast(data.size()), 0 /* ? */, &target.get_addr(), target.get_size()); @@ -458,8 +493,7 @@ namespace const auto error = GET_SOCKET_ERROR(); if (error == SOCK_WOULDBLOCK) { - this->in_poll = false; - this->delayed_ioctl = c; + this->delay_ioctrl(c, {}, false); return STATUS_PENDING; } diff --git a/src/windows-emulator/io_device.hpp b/src/windows-emulator/io_device.hpp index 01986f6..fe06c74 100644 --- a/src/windows-emulator/io_device.hpp +++ b/src/windows-emulator/io_device.hpp @@ -29,6 +29,17 @@ struct io_device_creation_data uint32_t length; }; +inline void write_io_status(const emulator_object io_status_block, const NTSTATUS status) +{ + if (io_status_block) + { + io_status_block.access([&](IO_STATUS_BLOCK& status_block) + { + status_block.Status = status; + }); + } +} + struct io_device { io_device() = default; @@ -64,15 +75,7 @@ struct io_device } const auto result = this->io_control(win_emu, c); - - if (c.io_status_block) - { - c.io_status_block.access([&](IO_STATUS_BLOCK& status) - { - status.Status = result; - }); - } - + write_io_status(c.io_status_block, result); return result; } }; diff --git a/src/windows-emulator/process_context.hpp b/src/windows-emulator/process_context.hpp index 18d8474..8fca160 100644 --- a/src/windows-emulator/process_context.hpp +++ b/src/windows-emulator/process_context.hpp @@ -222,9 +222,6 @@ class emulator_thread : ref_counted_object bool alerted{false}; std::optional await_time{}; - // TODO: Get rid of that! - std::function thread_blocker{}; - std::optional pending_status{}; std::optional gs_segment; diff --git a/src/windows-emulator/windows_emulator.cpp b/src/windows-emulator/windows_emulator.cpp index dbb6f9a..f23277c 100644 --- a/src/windows-emulator/windows_emulator.cpp +++ b/src/windows-emulator/windows_emulator.cpp @@ -686,18 +686,6 @@ bool emulator_thread::is_thread_ready(windows_emulator& win_emu) return false; } - if (this->thread_blocker) - { - const auto res = this->thread_blocker(win_emu, *this); - if (res) - { - this->thread_blocker = {}; - return true; - } - - return false; - } - if (this->waiting_for_alert) { if (this->alerted)