Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add poll support #25

Merged
merged 4 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 265 additions & 25 deletions src/windows-emulator/devices/afd_endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "afd_types.hpp"

#include "../windows_emulator.hpp"
#include "../syscall_utils.hpp"

#include <network/address.hpp>
#include <network/socket.hpp>
Expand Down Expand Up @@ -31,11 +32,159 @@ namespace
return win_emu.emu().read_memory<afd_creation_data>(data.buffer);
}

std::pair<AFD_POLL_INFO, std::vector<AFD_POLL_HANDLE_INFO>> get_poll_info(
windows_emulator& win_emu, const io_device_context& c)
{
constexpr auto info_size = offsetof(AFD_POLL_INFO, Handles);
if (!c.input_buffer || c.input_buffer_length < info_size)
{
throw std::runtime_error("Bad AFD poll data");
}

AFD_POLL_INFO poll_info{};
win_emu.emu().read_memory(c.input_buffer, &poll_info, info_size);

std::vector<AFD_POLL_HANDLE_INFO> handle_info{};

const emulator_object<AFD_POLL_HANDLE_INFO> handle_info_obj{win_emu.emu(), c.input_buffer + info_size};

if (c.input_buffer_length < (info_size + sizeof(AFD_POLL_HANDLE_INFO) * poll_info.NumberOfHandles))
{
throw std::runtime_error("Bad AFD poll handle data");
}

for (ULONG i = 0; i < poll_info.NumberOfHandles; ++i)
{
handle_info.emplace_back(handle_info_obj.read(i));
}

return {std::move(poll_info), std::move(handle_info)};
}

int16_t map_afd_request_events_to_socket(const ULONG poll_events)
{
int16_t socket_events{};

if (poll_events & (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE))
{
socket_events |= POLLRDNORM;
}

if (poll_events & AFD_POLL_RECEIVE_EXPEDITED)
{
socket_events |= POLLRDNORM;
}

if (poll_events & AFD_POLL_RECEIVE_EXPEDITED)
{
socket_events |= POLLRDBAND;
}

if (poll_events & (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND))
{
socket_events |= POLLWRNORM;
}

return socket_events;
}

ULONG map_socket_response_events_to_afd(const int16_t socket_events)
{
ULONG afd_events = 0;

if (socket_events & POLLRDNORM)
{
afd_events |= (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE);
}

if (socket_events & POLLRDBAND)
{
afd_events |= AFD_POLL_RECEIVE_EXPEDITED;
}

if (socket_events & POLLWRNORM)
{
afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND);
}

if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR))
{
afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT);
}
else if (socket_events & POLLHUP)
{
afd_events |= AFD_POLL_DISCONNECT;
}

if (socket_events & POLLNVAL)
{
afd_events |= AFD_POLL_LOCAL_CLOSE;
}

return afd_events;
}

NTSTATUS perform_poll(windows_emulator& win_emu, const io_device_context& c,
const std::span<const SOCKET> endpoints,
const std::span<const AFD_POLL_HANDLE_INFO> handles)
{
std::vector<pollfd> poll_data{};
poll_data.resize(endpoints.size());

for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i)
{
auto& pfd = poll_data.at(i);
auto& handle = handles[i];

pfd.fd = endpoints[i];
pfd.events = map_afd_request_events_to_socket(handle.PollEvents);
pfd.revents = pfd.events;
}

const auto count = poll(poll_data.data(), static_cast<uint32_t>(poll_data.size()), 0);
if (count <= 0)
{
return STATUS_PENDING;
}

constexpr auto info_size = offsetof(AFD_POLL_INFO, Handles);
const emulator_object<AFD_POLL_HANDLE_INFO> handle_info_obj{win_emu.emu(), c.input_buffer + info_size};

size_t current_index = 0;

for (size_t i = 0; i < endpoints.size(); ++i)
{
const auto& pfd = poll_data.at(i);
if (pfd.revents == 0)
{
continue;
}

auto entry = handle_info_obj.read(i);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents);
entry.Status = STATUS_SUCCESS;

handle_info_obj.write(entry, current_index++);
break;
}

assert(current_index == static_cast<size_t>(count));

emulator_object<AFD_POLL_INFO>{win_emu.emu(), c.input_buffer}.access([&](AFD_POLL_INFO& info)
{
info.NumberOfHandles = static_cast<ULONG>(current_index);
});

return STATUS_SUCCESS;
}

struct afd_endpoint : io_device
{
bool in_poll{};
std::optional<SOCKET> s{};
std::optional<io_device_context> delayed_ioctl{};
bool executing_delayed_ioctl_{};
std::optional<SOCKET> s_{};
std::optional<bool> require_poll_{};
std::optional<io_device_context> delayed_ioctl_{};
std::optional<std::chrono::steady_clock::time_point> timeout_{};

afd_endpoint()
{
Expand All @@ -47,9 +196,9 @@ namespace

~afd_endpoint() override
{
if (this->s)
if (this->s_)
{
closesocket(*this->s);
closesocket(*this->s_);
}
}

Expand All @@ -65,31 +214,70 @@ namespace

network::socket::set_blocking(sock, false);

s = sock;
s_ = sock;
}

void work(windows_emulator& win_emu) override
void delay_ioctrl(const io_device_context& c,
const std::optional<std::chrono::steady_clock::time_point> timeout = {},
const std::optional<bool> require_poll = {})
{
if (!this->delayed_ioctl || !this->s)
if (this->executing_delayed_ioctl_)
{
return;
}

const auto is_ready = network::socket::is_socket_ready(*this->s, this->in_poll);
if (!is_ready)
this->timeout_ = timeout;
this->require_poll_ = require_poll;
this->delayed_ioctl_ = c;
}

void clear_pending_state()
{
this->timeout_ = {};
this->require_poll_ = {};
this->delayed_ioctl_ = {};
}

void work(windows_emulator& win_emu) override
{
if (!this->delayed_ioctl_ || !this->s_)
{
return;
}

this->execute_ioctl(win_emu, *this->delayed_ioctl);
this->executing_delayed_ioctl_ = true;
const auto _ = utils::finally([&]
{
this->executing_delayed_ioctl_ = false;
});

if (this->require_poll_.has_value())
{
const auto is_ready = network::socket::is_socket_ready(*this->s_, *this->require_poll_);
if (!is_ready)
{
return;
}
}

const auto status = this->execute_ioctl(win_emu, *this->delayed_ioctl_);
if (status == STATUS_PENDING)
{
if (!this->timeout_ || this->timeout_ > std::chrono::steady_clock::now())
{
return;
}

write_io_status(this->delayed_ioctl_->io_status_block, STATUS_TIMEOUT);
}

auto* e = win_emu.process().events.get(this->delayed_ioctl->event);
auto* e = win_emu.process().events.get(this->delayed_ioctl_->event);
if (e)
{
e->signaled = true;
}

this->delayed_ioctl = {};
this->clear_pending_state();
}

void deserialize(utils::buffer_deserializer&) override
Expand Down Expand Up @@ -122,14 +310,15 @@ namespace
return this->ioctl_send_datagram(win_emu, c);
case AFD_RECEIVE_DATAGRAM:
return this->ioctl_receive_datagram(win_emu, c);
case AFD_POLL:
return this->ioctl_poll(win_emu, c);
case AFD_SET_CONTEXT:
return STATUS_SUCCESS;
case AFD_GET_INFORMATION:
return STATUS_SUCCESS;
default:
win_emu.logger.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code);
return STATUS_NOT_SUPPORTED;
}

win_emu.logger.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code);
return STATUS_NOT_SUPPORTED;
}

NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const
Expand All @@ -148,14 +337,67 @@ namespace

const network::address addr(address, address_size);

if (bind(*this->s, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR)
if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR)
{
return STATUS_ADDRESS_ALREADY_ASSOCIATED;
}

return STATUS_SUCCESS;
}

static std::vector<SOCKET> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO> handles)
{
auto& proc = win_emu.process();

std::vector<SOCKET> endpoints{};
endpoints.reserve(handles.size());

for (const auto& handle : handles)
{
auto* device = proc.devices.get(reinterpret_cast<uint64_t>(handle.Handle));
if (!device)
{
throw std::runtime_error("Bad device!");
}

const auto* endpoint = device->get_internal_device<afd_endpoint>();
if (!endpoint)
{
throw std::runtime_error("Device is not an AFD endpoint!");
}

endpoints.push_back(*endpoint->s_);
}

return endpoints;
}

NTSTATUS ioctl_poll(windows_emulator& win_emu, const io_device_context& c)
{
const auto [info, handles] = get_poll_info(win_emu, c);
const auto endpoints = resolve_endpoints(win_emu, handles);

const auto status = perform_poll(win_emu, c, endpoints, handles);
if (status != STATUS_PENDING)
{
return status;
}

if (!this->executing_delayed_ioctl_)
{
std::optional<std::chrono::steady_clock::time_point> timeout{};
if (info.Timeout.QuadPart)
{
timeout = convert_delay_interval_to_time_point(info.Timeout);
}

this->delay_ioctrl(c, timeout);
}

return STATUS_PENDING;
}

NTSTATUS ioctl_receive_datagram(windows_emulator& win_emu, const io_device_context& c)
{
auto& emu = win_emu.emu();
Expand Down Expand Up @@ -188,16 +430,15 @@ namespace
std::vector<char> data{};
data.resize(buffer.len);

const auto recevied_data = recvfrom(*this->s, data.data(), static_cast<int>(data.size()), 0,
const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast<int>(data.size()), 0,
reinterpret_cast<sockaddr*>(address.data()), &fromlength);

if (recevied_data < 0)
{
const auto error = GET_SOCKET_ERROR();
if (error == SOCK_WOULDBLOCK)
{
this->in_poll = true;
this->delayed_ioctl = c;
this->delay_ioctrl(c, {}, true);
return STATUS_PENDING;
}

Expand Down Expand Up @@ -225,7 +466,7 @@ namespace

NTSTATUS ioctl_send_datagram(windows_emulator& win_emu, const io_device_context& c)
{
auto& emu = win_emu.emu();
const auto& emu = win_emu.emu();

if (c.input_buffer_length < sizeof(AFD_SEND_DATAGRAM_INFO))
{
Expand All @@ -243,7 +484,7 @@ namespace

const auto data = emu.read_memory(buffer.buf, buffer.len);

const auto sent_data = sendto(*this->s, reinterpret_cast<const char*>(data.data()),
const auto sent_data = sendto(*this->s_, reinterpret_cast<const char*>(data.data()),
static_cast<int>(data.size()), 0 /* ? */, &target.get_addr(),
target.get_size());

Expand All @@ -252,8 +493,7 @@ namespace
const auto error = GET_SOCKET_ERROR();
if (error == SOCK_WOULDBLOCK)
{
this->in_poll = false;
this->delayed_ioctl = c;
this->delay_ioctrl(c, {}, false);
return STATUS_PENDING;
}

Expand Down
Loading