diff --git a/src/common/network/socket.cpp b/src/common/network/socket.cpp index 0e68ba5..7f385df 100644 --- a/src/common/network/socket.cpp +++ b/src/common/network/socket.cpp @@ -2,13 +2,6 @@ #include -#ifdef _WIN32 -#define poll WSAPoll -#define SOCK_WOULDBLOCK WSAEWOULDBLOCK -#else -#define SOCK_WOULDBLOCK EWOULDBLOCK -#endif - using namespace std::literals; namespace network @@ -103,15 +96,20 @@ namespace network } bool socket::set_blocking(const bool blocking) + { + return socket::set_blocking(this->socket_, blocking); + } + + bool socket::set_blocking(SOCKET s, const bool blocking) { #ifdef _WIN32 unsigned long mode = blocking ? 0 : 1; - return ioctlsocket(this->socket_, FIONBIO, &mode) == 0; + return ioctlsocket(s, FIONBIO, &mode) == 0; #else - int flags = fcntl(this->socket_, F_GETFL, 0); + int flags = fcntl(s, F_GETFL, 0); if (flags == -1) return false; flags = blocking ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK); - return fcntl(this->socket_, F_SETFL, flags) == 0; + return fcntl(s, F_SETFL, flags) == 0; #endif } @@ -200,6 +198,30 @@ namespace network return !socket_is_ready; } + bool socket::is_socket_ready(const SOCKET s, const bool in_poll) + { + pollfd pfd{}; + + pfd.fd = s; + pfd.events = in_poll ? POLLIN : POLLOUT; + pfd.revents = 0; + + const auto retval = poll(&pfd, 1, 0); + + if (retval == SOCKET_ERROR) + { + std::this_thread::sleep_for(1ms); + return socket_is_ready; + } + + if (retval > 0) + { + return socket_is_ready; + } + + return !socket_is_ready; + } + bool socket::sleep_sockets_until(const std::span& sockets, const std::chrono::high_resolution_clock::time_point time_point) { diff --git a/src/common/network/socket.hpp b/src/common/network/socket.hpp index b66029c..7c8cded 100644 --- a/src/common/network/socket.hpp +++ b/src/common/network/socket.hpp @@ -8,12 +8,15 @@ #ifdef _WIN32 using socklen_t = int; #define GET_SOCKET_ERROR() (WSAGetLastError()) +#define poll WSAPoll +#define SOCK_WOULDBLOCK WSAEWOULDBLOCK #else - using SOCKET = int; +using SOCKET = int; #define INVALID_SOCKET (SOCKET)(~0) #define SOCKET_ERROR (-1) #define GET_SOCKET_ERROR() (errno) #define closesocket close +#define SOCK_WOULDBLOCK EWOULDBLOCK #endif namespace network @@ -39,6 +42,7 @@ namespace network bool receive(address& source, std::string& data) const; bool set_blocking(bool blocking); + static bool set_blocking(SOCKET s, bool blocking); static constexpr bool socket_is_ready = true; bool sleep(std::chrono::milliseconds timeout) const; @@ -53,6 +57,8 @@ namespace network static bool sleep_sockets_until(const std::span& sockets, std::chrono::high_resolution_clock::time_point time_point); + static bool is_socket_ready(SOCKET s, bool in_poll); + private: int address_family_{AF_UNSPEC}; uint16_t port_ = 0; diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index 97f3b9d..59522eb 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -6,6 +6,8 @@ #include #include +#include + namespace { struct afd_creation_data @@ -19,19 +21,21 @@ namespace // ... }; - afd_creation_data get_creation_data(const io_device_creation_data& data) + afd_creation_data get_creation_data(windows_emulator& win_emu, const io_device_creation_data& data) { if (!data.buffer || data.length < sizeof(afd_creation_data)) { throw std::runtime_error("Bad AFD creation data"); } - return emulator_object{data.emu, data.buffer}.read(); + return win_emu.emu().read_memory(data.buffer); } struct afd_endpoint : io_device { + bool in_poll{}; std::optional s{}; + std::optional delayed_ioctl{}; afd_endpoint() { @@ -49,9 +53,9 @@ namespace } } - void create(const io_device_creation_data& data) override + void create(windows_emulator& win_emu, const io_device_creation_data& data) override { - const auto creation_data = get_creation_data(data); + const auto creation_data = get_creation_data(win_emu, data); // TODO: values map to windows values; might not be the case for other platforms const auto sock = socket(creation_data.address_family, creation_data.type, creation_data.protocol); if (sock == INVALID_SOCKET) @@ -59,9 +63,35 @@ namespace throw std::runtime_error("Failed to create socket!"); } + network::socket::set_blocking(sock, false); + s = sock; } + void work(windows_emulator& win_emu) override + { + if (!this->delayed_ioctl || !this->s) + { + return; + } + + const auto is_ready = network::socket::is_socket_ready(*this->s, this->in_poll); + if (!is_ready) + { + return; + } + + this->execute_ioctl(win_emu, *this->delayed_ioctl); + + auto* e = win_emu.process().events.get(this->delayed_ioctl->event); + if (e) + { + e->signaled = true; + } + + this->delayed_ioctl = {}; + } + void deserialize(utils::buffer_deserializer&) override { // TODO @@ -72,43 +102,41 @@ namespace // TODO } - NTSTATUS io_control(const io_device_context& c) override + NTSTATUS io_control(windows_emulator& win_emu, const io_device_context& c) override { - c.io_status_block.write({}); - if (_AFD_BASE(c.io_control_code) != FSCTL_AFD_BASE) { - c.win_emu.logger.print(color::cyan, "Bad AFD IOCTL: %X\n", c.io_control_code); + win_emu.logger.print(color::cyan, "Bad AFD IOCTL: %X\n", c.io_control_code); return STATUS_NOT_SUPPORTED; } - c.win_emu.logger.print(color::cyan, "AFD IOCTL: %X\n", c.io_control_code); + win_emu.logger.print(color::cyan, "AFD IOCTL: %X\n", c.io_control_code); const auto request = _AFD_REQUEST(c.io_control_code); switch (request) { case AFD_BIND: - return this->ioctl_bind(c); + return this->ioctl_bind(win_emu, c); case AFD_SEND_DATAGRAM: - return this->ioctl_send_datagram(c); + return this->ioctl_send_datagram(win_emu, c); case AFD_RECEIVE_DATAGRAM: - return this->ioctl_receive_datagram(c); + return this->ioctl_receive_datagram(win_emu, c); case AFD_SET_CONTEXT: return STATUS_SUCCESS; case AFD_GET_INFORMATION: return STATUS_SUCCESS; } - c.win_emu.logger.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code); + win_emu.logger.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code); return STATUS_NOT_SUPPORTED; } - NTSTATUS ioctl_bind(const io_device_context& c) const + NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const { std::vector data{}; data.resize(c.input_buffer_length); - c.emu.read_memory(c.input_buffer, data.data(), c.input_buffer_length); + win_emu.emu().read_memory(c.input_buffer, data.data(), c.input_buffer_length); constexpr auto address_offset = 4; @@ -130,22 +158,24 @@ namespace return STATUS_SUCCESS; } - NTSTATUS ioctl_receive_datagram(const io_device_context& c) const + NTSTATUS ioctl_receive_datagram(windows_emulator& win_emu, const io_device_context& c) { + auto& emu = win_emu.emu(); + if (c.input_buffer_length < sizeof(AFD_RECV_DATAGRAM_INFO)) { return STATUS_BUFFER_TOO_SMALL; } - const auto receive_info = emulator_object{c.emu, c.input_buffer}.read(); - const auto buffer = emulator_object{c.emu, receive_info.BufferArray}.read(0); + const auto receive_info = emu.read_memory(c.input_buffer); + const auto buffer = emu.read_memory(receive_info.BufferArray); std::vector address{}; ULONG address_length = 0x1000; if (receive_info.AddressLength) { - address_length = c.emu.read_memory(receive_info.AddressLength); + address_length = emu.read_memory(receive_info.AddressLength); } address.resize(std::clamp(address_length, 1UL, 0x1000UL)); @@ -165,16 +195,24 @@ namespace if (recevied_data < 0) { + const auto error = GET_SOCKET_ERROR(); + if (error == SOCK_WOULDBLOCK) + { + this->in_poll = true; + this->delayed_ioctl = c; + return STATUS_PENDING; + } + return STATUS_UNSUCCESSFUL; } - c.emu.write_memory(reinterpret_cast(buffer.buf), data.data(), - std::min(data.size(), static_cast(recevied_data))); + emu.write_memory(reinterpret_cast(buffer.buf), data.data(), + std::min(data.size(), static_cast(recevied_data))); if (receive_info.Address && address_length) { - c.emu.write_memory(reinterpret_cast(receive_info.Address), address.data(), - std::min(address.size(), static_cast(address_length))); + emu.write_memory(reinterpret_cast(receive_info.Address), address.data(), + std::min(address.size(), static_cast(address_length))); } if (c.io_status_block) @@ -187,27 +225,29 @@ namespace return STATUS_SUCCESS; } - NTSTATUS ioctl_send_datagram(const io_device_context& c) const + NTSTATUS ioctl_send_datagram(windows_emulator& win_emu, const io_device_context& c) { + auto& emu = win_emu.emu(); + if (c.input_buffer_length < sizeof(AFD_SEND_DATAGRAM_INFO)) { return STATUS_BUFFER_TOO_SMALL; } - const auto send_info = emulator_object{c.emu, c.input_buffer}.read(); - const auto buffer = emulator_object{c.emu, send_info.BufferArray}.read(0); + const auto send_info = emu.read_memory(c.input_buffer); + const auto buffer = emu.read_memory(send_info.BufferArray); std::vector address{}; address.resize(send_info.TdiConnInfo.RemoteAddressLength); - c.emu.read_memory(reinterpret_cast(send_info.TdiConnInfo.RemoteAddress), address.data(), - address.size()); + emu.read_memory(reinterpret_cast(send_info.TdiConnInfo.RemoteAddress), address.data(), + address.size()); const network::address target(reinterpret_cast(address.data()), static_cast(address.size())); std::vector data{}; data.resize(buffer.len); - c.emu.read_memory(reinterpret_cast(buffer.buf), data.data(), data.size()); + emu.read_memory(reinterpret_cast(buffer.buf), data.data(), data.size()); const auto sent_data = sendto(*this->s, reinterpret_cast(data.data()), static_cast(data.size()), 0 /* ? */, &target.get_addr(), @@ -215,7 +255,15 @@ namespace if (sent_data < 0) { - return STATUS_CONNECTION_REFUSED; + const auto error = GET_SOCKET_ERROR(); + if (error == SOCK_WOULDBLOCK) + { + this->in_poll = false; + this->delayed_ioctl = c; + return STATUS_PENDING; + } + + return STATUS_UNSUCCESSFUL; } if (c.io_status_block) diff --git a/src/windows-emulator/handles.hpp b/src/windows-emulator/handles.hpp index 11c1f09..31cd553 100644 --- a/src/windows-emulator/handles.hpp +++ b/src/windows-emulator/handles.hpp @@ -104,8 +104,19 @@ class handle_store using index_type = uint32_t; using value_map = std::map; + bool block_mutation(bool blocked) + { + std::swap(this->block_mutation_, blocked); + return blocked; + } + handle store(T value) { + if (this->block_mutation_) + { + throw std::runtime_error("Mutation of handle store is blocked!"); + } + auto index = this->find_free_index(); this->store_.emplace(index, std::move(value)); @@ -159,6 +170,11 @@ class handle_store bool erase(const typename value_map::iterator& entry) { + if (this->block_mutation_) + { + throw std::runtime_error("Mutation of handle store is blocked!"); + } + if (entry == this->store_.end()) { return false; @@ -203,11 +219,13 @@ class handle_store void serialize(utils::buffer_serializer& buffer) const { + buffer.write(this->block_mutation_); buffer.write_map(this->store_); } void deserialize(utils::buffer_deserializer& buffer) { + buffer.read(this->block_mutation_); buffer.read_map(this->store_); } @@ -305,7 +323,7 @@ class handle_store return index; } - + bool block_mutation_{false}; value_map store_{}; }; diff --git a/src/windows-emulator/io_device.cpp b/src/windows-emulator/io_device.cpp index f28cbae..2714491 100644 --- a/src/windows-emulator/io_device.cpp +++ b/src/windows-emulator/io_device.cpp @@ -5,7 +5,7 @@ namespace { struct dummy_device : stateless_device { - NTSTATUS io_control(const io_device_context&) override + NTSTATUS io_control(windows_emulator&, const io_device_context&) override { return STATUS_SUCCESS; } diff --git a/src/windows-emulator/io_device.hpp b/src/windows-emulator/io_device.hpp index a7d36ec..1e27560 100644 --- a/src/windows-emulator/io_device.hpp +++ b/src/windows-emulator/io_device.hpp @@ -12,10 +12,6 @@ struct process_context; struct io_device_context { - windows_emulator& win_emu; - x64_emulator& emu; - process_context& proc; - handle event; emulator_pointer /*PIO_APC_ROUTINE*/ apc_routine; emulator_pointer apc_context; @@ -29,10 +25,6 @@ struct io_device_context struct io_device_creation_data { - windows_emulator& win_emu; - x64_emulator& emu; - process_context& proc; - uint64_t buffer; uint32_t length; }; @@ -48,20 +40,46 @@ struct io_device io_device(const io_device&) = delete; io_device& operator=(const io_device&) = delete; - virtual NTSTATUS io_control(const io_device_context& context) = 0; + virtual NTSTATUS io_control(windows_emulator& win_emu, const io_device_context& context) = 0; - virtual void create(const io_device_creation_data& data) + virtual void create(windows_emulator& win_emu, const io_device_creation_data& data) { + (void)win_emu; (void)data; } + virtual void work(windows_emulator& win_emu) + { + (void)win_emu; + } + virtual void serialize(utils::buffer_serializer& buffer) const = 0; virtual void deserialize(utils::buffer_deserializer& buffer) = 0; + + NTSTATUS execute_ioctl(windows_emulator& win_emu, const io_device_context& c) + { + if (c.io_status_block) + { + c.io_status_block.write({}); + } + + const auto result = this->io_control(win_emu, c); + + if (c.io_status_block) + { + c.io_status_block.access([&](IO_STATUS_BLOCK& status) + { + status.Status = result; + }); + } + + return result; + } }; struct stateless_device : io_device { - void create(const io_device_creation_data&) final + void create(windows_emulator&, const io_device_creation_data&) final { } @@ -81,17 +99,23 @@ class io_device_container : public io_device public: io_device_container() = default; - io_device_container(std::wstring device, const io_device_creation_data& data) + io_device_container(std::wstring device, windows_emulator& win_emu, const io_device_creation_data& data) : device_name_(std::move(device)) { this->setup(); - this->device_->create(data); + this->device_->create(win_emu, data); + } + + NTSTATUS io_control(windows_emulator& win_emu, const io_device_context& context) override + { + this->assert_validity(); + return this->device_->io_control(win_emu, context); } - NTSTATUS io_control(const io_device_context& context) override + void work(windows_emulator& win_emu) override { this->assert_validity(); - return this->device_->io_control(context); + return this->device_->work(win_emu); } void serialize(utils::buffer_serializer& buffer) const override diff --git a/src/windows-emulator/syscalls.cpp b/src/windows-emulator/syscalls.cpp index f1ebce7..b7a476c 100644 --- a/src/windows-emulator/syscalls.cpp +++ b/src/windows-emulator/syscalls.cpp @@ -1666,9 +1666,6 @@ namespace } const io_device_context context{ - .win_emu = c.win_emu, - .emu = c.emu, - .proc = c.proc, .event = event, .apc_routine = apc_routine, .apc_context = apc_context, @@ -1680,7 +1677,7 @@ namespace .output_buffer_length = output_buffer_length, }; - return device->io_control(context); + return device->execute_ioctl(c.win_emu, context); } NTSTATUS handle_NtQueryWnfStateData() @@ -2073,15 +2070,12 @@ namespace if (filename.starts_with(device_prefix)) { const io_device_creation_data data{ - .win_emu = c.win_emu, - .emu = c.emu, - .proc = c.proc, .buffer = ea_buffer, .length = ea_length, }; auto device_name = filename.substr(device_prefix.size()); - io_device_container container{std::move(device_name), data}; + io_device_container container{std::move(device_name), c.win_emu, data}; const auto handle = c.proc.devices.store(std::move(container)); file_handle.write(handle); @@ -2412,15 +2406,12 @@ namespace { if (alertable) { - puts("Alertable NtWaitForSingleObject not supported yet!"); - c.emu.stop(); - return STATUS_NOT_SUPPORTED; + c.win_emu.logger.print(color::gray, "Alertable NtWaitForSingleObject not supported yet!\n"); } if (h.value.type != handle_types::thread && h.value.type != handle_types::event) { - puts("Unsupported handle type for NtWaitForSingleObject!"); - c.emu.stop(); + c.win_emu.logger.print(color::gray, "Unsupported handle type for NtWaitForSingleObject!\n"); return STATUS_NOT_SUPPORTED; } diff --git a/src/windows-emulator/windows_emulator.cpp b/src/windows-emulator/windows_emulator.cpp index 99425db..410fa93 100644 --- a/src/windows-emulator/windows_emulator.cpp +++ b/src/windows-emulator/windows_emulator.cpp @@ -3,6 +3,7 @@ #include "context_frame.hpp" #include +#include constexpr auto MAX_INSTRUCTIONS_PER_TIME_SLICE = 100000; @@ -113,6 +114,7 @@ namespace kusd.XState.EnabledFeatures = 0x000000000000001f; kusd.XState.EnabledVolatileFeatures = 0x000000000000000f; kusd.XState.Size = 0x000003c0; + kusd.QpcFrequency = 1000; constexpr std::wstring_view root_dir{L"C:\\WINDOWS"}; memcpy(&kusd.NtSystemRoot.arr[0], root_dir.data(), root_dir.size() * 2); @@ -496,8 +498,28 @@ namespace dispatch_exception_pointers(emu, dispatcher, pointers); } - bool switch_to_thread(const logger& logger, x64_emulator& emu, process_context& context, emulator_thread& thread) + void perform_context_switch_work(windows_emulator& win_emu) { + auto& devices = win_emu.process().devices; + + // Crappy mechanism to prevent mutation while iterating. + const auto was_blocked = devices.block_mutation(true); + const auto _ = utils::finally([&] + { + devices.block_mutation(was_blocked); + }); + + for (auto& device : devices) + { + device.second.work(win_emu); + } + } + + bool switch_to_thread(windows_emulator& win_emu, emulator_thread& thread) + { + auto& emu = win_emu.emu(); + auto& context = win_emu.process(); + if (!thread.is_thread_ready(context)) { return false; @@ -511,11 +533,9 @@ namespace return true; } - logger.print(color::green, "Performing thread switch...\n"); - - if (active_thread) { + win_emu.logger.print(color::green, "Performing thread switch...\n"); active_thread->save(emu); } @@ -528,26 +548,30 @@ namespace return true; } - bool switch_to_thread(const logger& logger, x64_emulator& emu, process_context& context, const handle thread_handle) + bool switch_to_thread(windows_emulator& win_emu, const handle thread_handle) { - auto* thread = context.threads.get(thread_handle); + auto* thread = win_emu.process().threads.get(thread_handle); if (!thread) { throw std::runtime_error("Bad thread handle"); } - return switch_to_thread(logger, emu, context, *thread); + return switch_to_thread(win_emu, *thread); } - bool switch_to_next_thread(const logger& logger, x64_emulator& emu, process_context& context) + bool switch_to_next_thread(windows_emulator& win_emu) { + perform_context_switch_work(win_emu); + + auto& context = win_emu.process(); + bool next_thread = false; for (auto& thread : context.threads) { if (next_thread) { - if (switch_to_thread(logger, emu, context, thread.second)) + if (switch_to_thread(win_emu, thread.second)) { return true; } @@ -563,7 +587,7 @@ namespace for (auto& thread : context.threads) { - if (switch_to_thread(logger, emu, context, thread.second)) + if (switch_to_thread(win_emu, thread.second)) { return true; } @@ -583,10 +607,10 @@ namespace case handle_types::event: { - const auto* e = c.events.get(h); + auto* e = c.events.get(h); if (e) { - return e->signaled; + return e->is_signaled(); } break; @@ -782,13 +806,13 @@ void windows_emulator::setup_process(const emulator_settings& settings) context.default_register_set = emu.save_registers(); const auto main_thread_id = context.create_thread(emu, context.executable->entry_point, 0, 0); - switch_to_thread(this->logger, emu, context, main_thread_id); + switch_to_thread(*this, main_thread_id); } void windows_emulator::perform_thread_switch() { this->switch_thread = false; - while (!switch_to_next_thread(this->logger, this->emu(), this->process())) + while (!switch_to_next_thread(*this)) { // TODO: Optimize that std::this_thread::sleep_for(1ms);