Skip to content

Commit

Permalink
Move peer management routines to peer.cpp
Browse files Browse the repository at this point in the history
Signed-off-by: Lev Stipakov <[email protected]>
  • Loading branch information
lstipakov committed Nov 1, 2024
1 parent de9646b commit f4963ca
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 241 deletions.
211 changes: 3 additions & 208 deletions Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,9 @@ OvpnStopVPN(_In_ POVPN_DEVICE device)
{
LOG_ENTER();

OvpnFlushPeers(device);
OvpnCleanupPeerTable(device, &device->PeersByVpn6);
OvpnCleanupPeerTable(device, &device->PeersByVpn4);
OvpnCleanupPeerTable(device, &device->Peers);

KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock);
PWSK_SOCKET socket = device->Socket.Socket;
Expand Down Expand Up @@ -762,210 +764,3 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) {

return status;
}

_Use_decl_annotations_
NTSTATUS
OvpnAddPeerToTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer)
{
NTSTATUS status;
BOOLEAN newElem;

auto irql = ExAcquireSpinLockExclusive(&device->SpinLock);

RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem);

if (newElem) {
status = STATUS_SUCCESS;
InterlockedIncrement(&peer->RefCounter);
}
else {
LOG_ERROR("Unable to add new peer");
status = STATUS_NO_MEMORY;
}

ExReleaseSpinLockExclusive(&device->SpinLock, irql);

return status;
}

_Use_decl_annotations_
VOID
OvpnFlushPeers(POVPN_DEVICE device) {
OvpnCleanupPeerTable(device, &device->PeersByVpn6);
OvpnCleanupPeerTable(device, &device->PeersByVpn4);
OvpnCleanupPeerTable(device, &device->Peers);
}

_Use_decl_annotations_
VOID
OvpnCleanupPeerTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* peers)
{
auto irql = ExAcquireSpinLockExclusive(&device->SpinLock);

while (!RtlIsGenericTableEmpty(peers)) {
PVOID ptr = RtlGetElementGenericTable(peers, 0);
OvpnPeerContext* peer = *(OvpnPeerContext**)ptr;
RtlDeleteElementGenericTable(peers, ptr);

OvpnPeerCtxRelease(peer);
}

ExReleaseSpinLockExclusive(&device->SpinLock, irql);
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnGetFirstPeer(POVPN_DEVICE device)
{
auto irql = ExAcquireSpinLockShared(&device->SpinLock);

OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
OvpnPeerContext* peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;

if (peer != nullptr) {
InterlockedIncrement(&peer->RefCounter);
}

ExReleaseSpinLockShared(&device->SpinLock, irql);

return peer;
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId)
{
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;

auto kirql = ExAcquireSpinLockShared(&device->SpinLock);

if (device->Mode == OVPN_MODE_P2P) {
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
}
else {
OvpnPeerContext p{};
p.PeerId = PeerId;

auto* pp = &p;
ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp);
}

peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;

if (peer) {
InterlockedIncrement(&peer->RefCounter);
}

ExReleaseSpinLockShared(&device->SpinLock, kirql);

return peer;
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr)
{
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;

auto kirql = ExAcquireSpinLockShared(&device->SpinLock);

if (device->Mode == OVPN_MODE_P2P) {
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
}
else {
OvpnPeerContext p{};
p.VpnAddrs.IPv4 = addr;

auto* pp = &p;
ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp);
}

peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;
if (peer) {
InterlockedIncrement(&peer->RefCounter);
}

ExReleaseSpinLockShared(&device->SpinLock, kirql);

return peer;
}

_Use_decl_annotations_
OvpnPeerContext*
OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr)
{
OvpnPeerContext* peer = nullptr;
OvpnPeerContext** ptr = nullptr;

auto kirql = ExAcquireSpinLockShared(&device->SpinLock);

if (device->Mode == OVPN_MODE_P2P) {
ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0);
}
else {
OvpnPeerContext p{};
p.VpnAddrs.IPv6 = addr;

auto* pp = &p;
ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp);
}

peer = ptr ? (OvpnPeerContext*)*ptr : nullptr;
if (peer) {
InterlockedIncrement(&peer->RefCounter);
}

ExReleaseSpinLockShared(&device->SpinLock, kirql);

return peer;
}

VOID
OvpnDeletePeerFromTable(POVPN_DEVICE device, RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName)
{
auto peerId = peer->PeerId;
auto pp = &peer;

auto kirql = ExAcquireSpinLockExclusive(&device->SpinLock);

if (RtlDeleteElementGenericTable(table, pp)) {
LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id"));

if (InterlockedDecrement(&peer->RefCounter) == 0) {
OvpnPeerCtxFree(peer);
LOG_INFO("Peer freed", TraceLoggingValue(peerId, "peer-id"));
}
}
else {
LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id"));
}

ExReleaseSpinLockExclusive(&device->SpinLock, kirql);
}

_Use_decl_annotations_
NTSTATUS
OvpnDeletePeer(POVPN_DEVICE device, INT32 peerId)
{
NTSTATUS status = STATUS_SUCCESS;

LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id"));

// get peer from main table
OvpnPeerContext* peer = OvpnFindPeer(device, peerId);
if (peer == NULL) {
status = STATUS_NOT_FOUND;
LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id"));
}
else {
OvpnDeletePeerFromTable(device, &device->PeersByVpn4, peer, "vpn4");
OvpnDeletePeerFromTable(device, &device->PeersByVpn6, peer, "vpn6");
OvpnDeletePeerFromTable(device, &device->Peers, peer, "peers");

OvpnPeerCtxRelease(peer);
}

return status;
}
32 changes: 0 additions & 32 deletions Driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,35 +107,3 @@ struct OVPN_DEVICE {
typedef OVPN_DEVICE * POVPN_DEVICE;

WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_DEVICE, OvpnGetDeviceContext)

struct OvpnPeerContext;

_Must_inspect_result_
NTSTATUS
OvpnAddPeerToTable(POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer);

VOID
OvpnFlushPeers(_In_ POVPN_DEVICE device);

VOID
OvpnCleanupPeerTable(_In_ POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE*);

_Must_inspect_result_
OvpnPeerContext*
OvpnGetFirstPeer(_In_ POVPN_DEVICE device);

_Must_inspect_result_
OvpnPeerContext*
OvpnFindPeer(_In_ POVPN_DEVICE device, INT32 PeerId);

_Must_inspect_result_
OvpnPeerContext*
OvpnFindPeerVPN4(_In_ POVPN_DEVICE device, _In_ IN_ADDR addr);

_Must_inspect_result_
OvpnPeerContext*
OvpnFindPeerVPN6(_In_ POVPN_DEVICE device, _In_ IN6_ADDR addr);

_Must_inspect_result_
NTSTATUS
OvpnDeletePeer(_In_ POVPN_DEVICE device, INT32 peerId);
Loading

0 comments on commit f4963ca

Please sign in to comment.