Skip to content

Commit

Permalink
Account for SOCKADDR in control packets
Browse files Browse the repository at this point in the history
In multipeer UDP mode, we expect userspace to prepend CC packets
with SOCKADDR to know where to send the control packet. Likewise,
when we receive the control packet, we prepend it with remote SOCKADDR
before pushing to userspace.

#84

Co-authored-by: Leon Dang <[email protected]>

Signed-off-by: Leon Dang <[email protected]>
Signed-off-by: Lev Stipakov <[email protected]>
  • Loading branch information
lstipakov committed Sep 18, 2024
1 parent 71e0a20 commit 72fc278
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 35 deletions.
62 changes: 49 additions & 13 deletions Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,39 +187,75 @@ OvpnEvtIoWrite(WDFQUEUE queue, WDFREQUEST request, size_t length)
// acquire spinlock, since we access device->TransportSocket
KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock);

OVPN_TX_BUFFER* buffer = NULL;
OVPN_TX_BUFFER* txBuf = NULL;

if (device->Socket.Socket == NULL) {
status = STATUS_INVALID_DEVICE_STATE;
LOG_ERROR("TransportSocket is not initialized");
goto error;
}

// fetch tx buffer
GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &buffer));

// get request buffer
PVOID requestBuffer;
size_t requestBufferLength;
GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestRetrieveInputBuffer(request, 0, &requestBuffer, &requestBufferLength));
PVOID buf;
size_t bufLen;
GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestRetrieveInputBuffer(request, 0, &buf, &bufLen));

PSOCKADDR sa = NULL;

if (device->Mode == OVPN_MODE_MP) {
// buffer is prepended with SOCKADDR

sa = (PSOCKADDR)buf;
switch (sa->sa_family) {
case AF_INET:
if (bufLen <= sizeof(SOCKADDR_IN)) {
status = STATUS_INVALID_MESSAGE;
LOG_ERROR("Message too short", TraceLoggingValue(bufLen, "msgLen"), TraceLoggingValue(sizeof(SOCKADDR_IN), "minLen"));
goto error;
}

buf = (char*)buf + sizeof(SOCKADDR_IN);
bufLen -= sizeof(SOCKADDR_IN);
break;

case AF_INET6:
if (bufLen <= sizeof(SOCKADDR_IN6)) {
status = STATUS_INVALID_MESSAGE;
LOG_ERROR("Message too short", TraceLoggingValue(bufLen, "msgLen"), TraceLoggingValue(sizeof(SOCKADDR_IN6), "minLen"));
goto error;
}

buf = (char*)buf + sizeof(SOCKADDR_IN6);
bufLen -= sizeof(SOCKADDR_IN6);
break;

default:
LOG_ERROR("Invalid address family", TraceLoggingValue(sa->sa_family, "AF"));
status = STATUS_INVALID_ADDRESS;
goto error;
}
}

// fetch tx buffer
GOTO_IF_NOT_NT_SUCCESS(error, status, OvpnTxBufferPoolGet(device->TxBufferPool, &txBuf));

// copy data from request to tx buffer
PUCHAR buf = OvpnTxBufferPut(buffer, requestBufferLength);
RtlCopyMemory(buf, requestBuffer, requestBufferLength);
PUCHAR data = OvpnTxBufferPut(txBuf, bufLen);
RtlCopyMemory(data, buf, bufLen);

buffer->IoQueue = device->PendingWritesQueue;
txBuf->IoQueue = device->PendingWritesQueue;

// move request to manual queue
GOTO_IF_NOT_NT_SUCCESS(error, status, WdfRequestForwardToIoQueue(request, device->PendingWritesQueue));

// send
LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer));
LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, txBuf, sa));

goto done_not_complete;

error:
if (buffer != NULL) {
OvpnTxBufferPoolPut(buffer);
if (txBuf != NULL) {
OvpnTxBufferPoolPut(txBuf);
}

ULONG_PTR bytesCopied = 0;
Expand Down
69 changes: 51 additions & 18 deletions socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,29 @@ OvpnSocketSyncOp(_In_z_ CHAR* opName, OP op, SUCCESS success)
static
_Requires_shared_lock_held_(device->SpinLock)
VOID
OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len)
OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR buf, SIZE_T len, _In_opt_ PSOCKADDR remote)
{
SIZE_T hdrLen = 0, totalLen = len;

// in UDP and MP mode we prepend CC packet with remote sockaddr before pushing it to userspace
if (device->Mode == OVPN_MODE_MP && remote != NULL) {
switch (remote->sa_family) {
case AF_INET:
hdrLen = sizeof(SOCKADDR_IN);
break;

case AF_INET6:
hdrLen = sizeof(SOCKADDR_IN6);
break;

default:
LOG_ERROR("Invalid remote address family", TraceLoggingValue(remote->sa_family, "AF"));
InterlockedIncrementNoFence(&device->Stats.LostInControlPackets);
return;
}
totalLen += hdrLen;
}

WDFREQUEST request;
NTSTATUS status = WdfIoQueueRetrieveNextRequest(device->PendingReadsQueue, &request);
if (!NT_SUCCESS(status)) {
Expand All @@ -113,17 +134,22 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR
return;
}

if (sizeof(buffer->Data) >= len) {
// copy control packet to buffer
RtlCopyMemory(buffer->Data, buf, len);
buffer->Len = len;
if (sizeof(buffer->Data) >= totalLen) {
if (hdrLen > 0) {
// prepend with sockaddr
RtlCopyMemory(buffer->Data, remote, hdrLen);
}

// copy control packet payload
RtlCopyMemory(buffer->Data + hdrLen, buf, totalLen - hdrLen);
buffer->Len = totalLen;

// enqueue buffer, it will be dequeued when read request arrives
OvpnBufferQueueEnqueue(device->ControlRxBufferQueue, &buffer->QueueListEntry);
}
else {
LOG_ERROR("Buffer too small, packet len <pktlen>, buf len <buflen>",
TraceLoggingValue(len, "pktlen"), TraceLoggingValue(sizeof(buffer->Data), "buflen"));
TraceLoggingValue(totalLen, "pktlen"), TraceLoggingValue(sizeof(buffer->Data), "buflen"));

OvpnRxBufferPoolPut(buffer);
}
Expand All @@ -133,19 +159,26 @@ OvpnSocketControlPacketReceived(_In_ POVPN_DEVICE device, _In_reads_(len) PUCHAR
PVOID readBuffer;
size_t readBufferLength;

ULONG_PTR bytesSent = len;
ULONG_PTR bytesSent = totalLen;

LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, len, &readBuffer, &readBufferLength));
LOG_IF_NOT_NT_SUCCESS(status = WdfRequestRetrieveOutputBuffer(request, totalLen, &readBuffer, &readBufferLength));
if (NT_SUCCESS(status)) {
// copy control packet to read request buffer
RtlCopyMemory(readBuffer, buf, len);

if (hdrLen > 0) {
// prepend with sockaddr
RtlCopyMemory(readBuffer, remote, hdrLen);
}

// copy control packet payload
RtlCopyMemory((PCHAR)readBuffer + hdrLen, buf, totalLen - hdrLen);

InterlockedIncrementNoFence(&device->Stats.ReceivedControlPackets);
} else {
InterlockedIncrementNoFence(&device->Stats.LostInControlPackets);

if (status == STATUS_BUFFER_TOO_SMALL) {
LOG_ERROR("Buffer too small, packet len <pktlen>, buf len <buflen>",
TraceLoggingValue(len, "pktlen"), TraceLoggingValue(readBufferLength, "buflen"));
TraceLoggingValue(totalLen, "pktlen"), TraceLoggingValue(readBufferLength, "buflen"));
}

bytesSent = 0;
Expand Down Expand Up @@ -238,7 +271,7 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_
}

VOID
OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch)
OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLength) PUCHAR buf, SIZE_T packetLength, BOOLEAN irqlDispatch, _In_opt_ PSOCKADDR remoteAddr)
{
// If we're at dispatch level, we can use a small optimization and use function
// which is not calling KeRaiseIRQL to raise the IRQL to DISPATCH_LEVEL before attempting to acquire the lock
Expand All @@ -255,7 +288,7 @@ OvpnSocketProcessIncomingPacket(_In_ POVPN_DEVICE device, _In_reads_(packetLengt
OvpnSocketDataPacketReceived(device, op, buf, packetLength);
}
else {
OvpnSocketControlPacketReceived(device, buf, packetLength);
OvpnSocketControlPacketReceived(device, buf, packetLength, remoteAddr);
}

// don't forget to release spinlock
Expand Down Expand Up @@ -330,7 +363,7 @@ OvpnSocketUdpReceiveFromEvent(_In_ PVOID socketContext, ULONG flags, _In_opt_ PW
buf = packetBuf;
}

OvpnSocketProcessIncomingPacket(device, buf, dataIndication->Buffer.Length, flags & WSK_FLAG_AT_DISPATCH_LEVEL);
OvpnSocketProcessIncomingPacket(device, buf, dataIndication->Buffer.Length, flags & WSK_FLAG_AT_DISPATCH_LEVEL, dataIndication->RemoteAddress);

dataIndication = dataIndication->Next;
}
Expand Down Expand Up @@ -412,7 +445,7 @@ OvpnSocketTcpReceiveEvent(_In_opt_ PVOID socketContext, _In_ ULONG flags, _In_op
buf = tcpState->PacketBuf;
}

OvpnSocketProcessIncomingPacket(device, buf, tcpState->PacketLength, flags & WSK_FLAG_AT_DISPATCH_LEVEL);
OvpnSocketProcessIncomingPacket(device, buf, tcpState->PacketLength, flags & WSK_FLAG_AT_DISPATCH_LEVEL, NULL);

mdlDataLen -= bytesRemained;
dataIndicationLen -= bytesRemained;
Expand Down Expand Up @@ -704,7 +737,7 @@ OvpnSocketSendComplete(_In_ PDEVICE_OBJECT deviceObj, _In_ PIRP irp, _In_ PVOID

NTSTATUS
_Use_decl_annotations_
OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer) {
OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer, SOCKADDR* sa) {
OVPN_DEVICE* device = (OVPN_DEVICE*)OvpnTxBufferPoolGetContext(buffer->Pool);

PWSK_SOCKET socket = ovpnSocket->Socket;
Expand Down Expand Up @@ -742,11 +775,11 @@ OvpnSocketSend(OvpnSocket* ovpnSocket, OVPN_TX_BUFFER* buffer) {
}
else if (buffer->WskBufList.Buffer.Length != 0) {
PWSK_PROVIDER_DATAGRAM_DISPATCH datagramDispatch = (PWSK_PROVIDER_DATAGRAM_DISPATCH)socket->Dispatch;
LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendMessages(socket, &buffer->WskBufList, 0, NULL, 0, NULL, irp));
LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendMessages(socket, &buffer->WskBufList, 0, sa, 0, NULL, irp));
} else {
WSK_BUF wskBuf{ buffer->Mdl, FIELD_OFFSET(OVPN_TX_BUFFER, Head) + (ULONG)(buffer->Data - buffer->Head), buffer->Len };
PWSK_PROVIDER_DATAGRAM_DISPATCH datagramDispatch = (PWSK_PROVIDER_DATAGRAM_DISPATCH)socket->Dispatch;
LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendTo(socket, &wskBuf, 0, NULL, 0, NULL, irp));
LOG_IF_NOT_NT_SUCCESS(status = datagramDispatch->WskSendTo(socket, &wskBuf, 0, sa, 0, NULL, irp));
}

return status;
Expand Down
2 changes: 1 addition & 1 deletion socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ OvpnSocketClose(_In_opt_ PWSK_SOCKET socket);

_Must_inspect_result_
NTSTATUS
OvpnSocketSend(_In_ OvpnSocket* ovpnSocket, _In_ OVPN_TX_BUFFER* buffer);
OvpnSocketSend(_In_ OvpnSocket* ovpnSocket, _In_ OVPN_TX_BUFFER* buffer, _In_opt_ SOCKADDR* sa);

_Must_inspect_result_
NTSTATUS
Expand Down
2 changes: 1 addition & 1 deletion timer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ static VOID OvpnTimerXmit(WDFTIMER timer)

if (NT_SUCCESS(status)) {
// start async send, completion handler will return ciphertext buffer to the pool
LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer));
LOG_IF_NOT_NT_SUCCESS(status = OvpnSocketSend(&device->Socket, buffer, NULL));
if (NT_SUCCESS(status)) {
LOG_INFO("Ping sent");
}
Expand Down
4 changes: 2 additions & 2 deletions txqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET
if (NT_SUCCESS(status)) {
// start async send, this will return ciphertext buffer to the pool
if (device->Socket.Tcp) {
status = OvpnSocketSend(&device->Socket, buffer);
status = OvpnSocketSend(&device->Socket, buffer, NULL);
}
else {
// for UDP we use SendMessages to send multiple datagrams at once
Expand Down Expand Up @@ -195,7 +195,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue)

if (!device->Socket.Tcp) {
// this will use WskSendMessages to send buffers list which we constructed before
LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead));
LOG_IF_NOT_NT_SUCCESS(OvpnSocketSend(&device->Socket, txBufferHead, NULL));
}
}
}
Expand Down

0 comments on commit 72fc278

Please sign in to comment.