Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add proper timeouts to all computer api calls #154

Merged
merged 5 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 7 additions & 32 deletions src/CommonLib/Processors/ComputerSessionProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Win32;
using SharpHoundCommonLib.OutputTypes;
using SharpHoundRPC;
using SharpHoundRPC.NetAPINative;

namespace SharpHoundCommonLib.Processors {
Expand Down Expand Up @@ -56,7 +57,7 @@ public async Task<SessionAPIResult> ReadUserSessions(string computerName, string

_log.LogDebug("Running NetSessionEnum for {ObjectName}", computerName);

var apiTask = Task.Run(() => {
var result = await Task.Run(() => {
NetAPIResult<IEnumerable<NetSessionEnumResults>> result;
if (_doLocalAdminSessionEnum) {
// If we are authenticating using a local admin, we need to impersonate for this
Expand All @@ -77,24 +78,11 @@ public async Task<SessionAPIResult> ReadUserSessions(string computerName, string
}

return result;
});

if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) {
await SendComputerStatus(new CSVComputerStatus {
Status = "Timeout",
Task = "NetSessionEnum",
ComputerName = computerName
});
ret.Collected = false;
ret.FailureReason = "Timeout";
return ret;
}

var result = apiTask.Result;
}).TimeoutAfter(timeout);

if (result.IsFailed) {
await SendComputerStatus(new CSVComputerStatus {
Status = result.Status.ToString(),
Status = result.GetErrorStatus(),
Task = "NetSessionEnum",
ComputerName = computerName
});
Expand Down Expand Up @@ -186,7 +174,7 @@ public async Task<SessionAPIResult> ReadUserSessionsPrivileged(string computerNa

_log.LogDebug("Running NetWkstaUserEnum for {ObjectName}", computerName);

var apiTask = Task.Run(() => {
var result = await Task.Run(() => {
NetAPIResult<IEnumerable<NetWkstaUserEnumResults>>
result;
if (_doLocalAdminSessionEnum) {
Expand All @@ -208,24 +196,11 @@ public async Task<SessionAPIResult> ReadUserSessionsPrivileged(string computerNa
}

return result;
});
}).TimeoutAfter(timeout);

if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) {
await SendComputerStatus(new CSVComputerStatus {
Status = "Timeout",
Task = "NetWkstaUserEnum",
ComputerName = computerName
});
ret.Collected = false;
ret.FailureReason = "Timeout";
return ret;
}

var result = apiTask.Result;

if (result.IsFailed) {
await SendComputerStatus(new CSVComputerStatus {
Status = result.Status.ToString(),
Status = result.GetErrorStatus(),
Task = "NetWkstaUserEnum",
ComputerName = computerName
});
Expand Down
40 changes: 29 additions & 11 deletions src/CommonLib/Processors/LocalGroupProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Enums;
using SharpHoundCommonLib.OutputTypes;
using SharpHoundRPC;
using SharpHoundRPC.Shared;
using SharpHoundRPC.Wrappers;

Expand Down Expand Up @@ -33,7 +34,7 @@ public virtual SharpHoundRPC.Result<ISAMServer> OpenSamServer(string computerNam
return SharpHoundRPC.Result<ISAMServer>.Fail(result.SError);
}

return SharpHoundRPC.Result<ISAMServer>.Ok(result.Value);
return SharpHoundRPC.Result<ISAMServer>.Ok(result.Value);
}

public IAsyncEnumerable<LocalGroupAPIResult> GetLocalGroups(ResolvedSearchResult result)
Expand All @@ -48,12 +49,17 @@ public IAsyncEnumerable<LocalGroupAPIResult> GetLocalGroups(ResolvedSearchResult
/// <param name="computerObjectId">The objectsid of the computer in the domain</param>
/// <param name="computerDomain">The domain the computer belongs too</param>
/// <param name="isDomainController">Is the computer a domain controller</param>
/// <param name="timeout"></param>
/// <returns></returns>
public async IAsyncEnumerable<LocalGroupAPIResult> GetLocalGroups(string computerName, string computerObjectId,
string computerDomain, bool isDomainController)
string computerDomain, bool isDomainController, TimeSpan timeout = default)
{
if (timeout == default) {
timeout = TimeSpan.FromMinutes(2);
}

//Open a handle to the server
var openServerResult = OpenSamServer(computerName);
var openServerResult = await Task.Run(() => OpenSamServer(computerName)).TimeoutAfter(timeout);
if (openServerResult.IsFailed)
{
_log.LogTrace("OpenServer failed on {ComputerName}: {Error}", computerName, openServerResult.SError);
Expand All @@ -71,9 +77,8 @@ await SendComputerStatus(new CSVComputerStatus

//Try to get the machine sid for the computer if its not already cached
SecurityIdentifier machineSid;
if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid))
{
var getMachineSidResult = server.GetMachineSid();
if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid)) {
var getMachineSidResult = await Task.Run(() => server.GetMachineSid()).TimeoutAfter(timeout);
if (getMachineSidResult.IsFailed)
{
_log.LogTrace("GetMachineSid failed on {ComputerName}: {Error}", computerName, getMachineSidResult.SError);
Expand All @@ -97,7 +102,7 @@ await SendComputerStatus(new CSVComputerStatus
}

//Get all available domains in the server
var getDomainsResult = server.GetDomains();
var getDomainsResult = await Task.Run(() => server.GetDomains()).TimeoutAfter(timeout);
if (getDomainsResult.IsFailed)
{
_log.LogTrace("GetDomains failed on {ComputerName}: {Error}", computerName, getDomainsResult.SError);
Expand All @@ -118,7 +123,7 @@ await SendComputerStatus(new CSVComputerStatus
continue;

//Open a handle to the domain
var openDomainResult = server.OpenDomain(domainResult.Name);
var openDomainResult = await Task.Run(() => server.OpenDomain(domainResult.Name)).TimeoutAfter(timeout);
if (openDomainResult.IsFailed)
{
_log.LogTrace("Failed to open domain {Domain} on {ComputerName}: {Error}", domainResult.Name, computerName, openDomainResult.SError);
Expand All @@ -128,13 +133,16 @@ await SendComputerStatus(new CSVComputerStatus
ComputerName = computerName,
Status = openDomainResult.SError
});
if (openDomainResult.IsTimeout) {
yield break;
}
continue;
}

var domain = openDomainResult.Value;

//Open a handle to the available aliases
var getAliasesResult = domain.GetAliases();
var getAliasesResult = await Task.Run(() => domain.GetAliases()).TimeoutAfter(timeout);

if (getAliasesResult.IsFailed)
{
Expand All @@ -145,6 +153,10 @@ await SendComputerStatus(new CSVComputerStatus
ComputerName = computerName,
Status = getAliasesResult.SError
});

if (getAliasesResult.IsTimeout) {
yield break;
}
continue;
}

Expand All @@ -163,7 +175,7 @@ await SendComputerStatus(new CSVComputerStatus
};

//Open a handle to the alias
var openAliasResult = domain.OpenAlias(alias.Rid);
var openAliasResult = await Task.Run(() => domain.OpenAlias(alias.Rid)).TimeoutAfter(timeout);
if (openAliasResult.IsFailed)
{
_log.LogTrace("Failed to open alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error);
Expand All @@ -176,12 +188,15 @@ await SendComputerStatus(new CSVComputerStatus
ret.Collected = false;
ret.FailureReason = $"SamOpenAliasInDomain failed with status {openAliasResult.SError}";
yield return ret;
if (openAliasResult.IsTimeout) {
yield break;
}
continue;
}

var localGroup = openAliasResult.Value;
//Call GetMembersInAlias to get raw group members
var getMembersResult = localGroup.GetMembers();
var getMembersResult = await Task.Run(() => localGroup.GetMembers()).TimeoutAfter(timeout);
if (getMembersResult.IsFailed)
{
_log.LogTrace("Failed to get members in alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error);
Expand All @@ -194,6 +209,9 @@ await SendComputerStatus(new CSVComputerStatus
ret.Collected = false;
ret.FailureReason = $"SamGetMembersInAlias failed with status {getMembersResult.SError}";
yield return ret;
if (getMembersResult.IsTimeout) {
yield break;
}
continue;
}

Expand Down
19 changes: 14 additions & 5 deletions src/CommonLib/Processors/UserRightsAssignmentProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Security.Principal;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Enums;
using SharpHoundCommonLib.OutputTypes;
using SharpHoundRPC;
using SharpHoundRPC.Shared;
using SharpHoundRPC.Wrappers;

Expand Down Expand Up @@ -47,11 +49,15 @@ public IAsyncEnumerable<UserRightsAssignmentAPIResult> GetUserRightsAssignments(
/// <param name="computerDomain"></param>
/// <param name="isDomainController">Is the computer a domain controller</param>
/// <param name="desiredPrivileges"></param>
/// <param name="timeout"></param>
/// <returns></returns>
public async IAsyncEnumerable<UserRightsAssignmentAPIResult> GetUserRightsAssignments(string computerName,
string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null)
string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null, TimeSpan timeout = default)
{
var policyOpenResult = OpenLSAPolicy(computerName);
if (timeout == default) {
timeout = TimeSpan.FromMinutes(2);
}
var policyOpenResult = await Task.Run(() => OpenLSAPolicy(computerName)).TimeoutAfter(timeout);
if (!policyOpenResult.IsSuccess)
{
_log.LogDebug("LSAOpenPolicy failed on {ComputerName} with status {Status}", computerName,
Expand All @@ -71,7 +77,7 @@ await SendComputerStatus(new CSVComputerStatus
SecurityIdentifier machineSid;
if (!Cache.GetMachineSid(computerObjectId, out var temp))
{
var getMachineSidResult = server.GetLocalDomainInformation();
var getMachineSidResult = await Task.Run(() => server.GetLocalDomainInformation()).TimeoutAfter(timeout);
if (getMachineSidResult.IsFailed)
{
_log.LogWarning("Failed to get machine sid for {Server}: {Status}. Abandoning URA collection",
Expand Down Expand Up @@ -103,7 +109,7 @@ await SendComputerStatus(new CSVComputerStatus
};

//Ask for all principals with the specified privilege.
var enumerateAccountsResult = server.GetResolvedPrincipalsWithPrivilege(privilege);
var enumerateAccountsResult = await Task.Run(() => server.GetResolvedPrincipalsWithPrivilege(privilege)).TimeoutAfter(timeout);
if (enumerateAccountsResult.IsFailed)
{
_log.LogDebug(
Expand All @@ -118,6 +124,9 @@ await SendComputerStatus(new CSVComputerStatus
ret.FailureReason =
$"LSAEnumerateAccountsWithUserRights returned {enumerateAccountsResult.SError}";
yield return ret;
if (enumerateAccountsResult.IsTimeout) {
yield break;
}
continue;
}

Expand Down
34 changes: 34 additions & 0 deletions src/SharpHoundRPC/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using System;
using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks;
using SharpHoundRPC.NetAPINative;

namespace SharpHoundRPC
{
Expand Down Expand Up @@ -32,5 +35,36 @@ public static byte[] GetBytes(this SecurityIdentifier identifier)
identifier.GetBinaryForm(bytes, 0);
return bytes;
}

public static async Task<Result<T>> TimeoutAfter<T>(this Task<Result<T>> task, TimeSpan timeout) {

using (var timeoutCancellationTokenSource = new CancellationTokenSource()) {

var completedTask = await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token));
if (completedTask == task) {
timeoutCancellationTokenSource.Cancel();
return await task; // Very important in order to propagate exceptions
}

var result = Result<T>.Fail("Timeout");
result.IsTimeout = true;
return result;
}
}

public static async Task<NetAPIResult<T>> TimeoutAfter<T>(this Task<NetAPIResult<T>> task, TimeSpan timeout) {

using (var timeoutCancellationTokenSource = new CancellationTokenSource()) {

var completedTask = await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token));
if (completedTask == task) {
timeoutCancellationTokenSource.Cancel();
return await task; // Very important in order to propagate exceptions
}

var result = NetAPIResult<T>.Fail("Timeout");
return result;
}
}
}
}
8 changes: 8 additions & 0 deletions src/SharpHoundRPC/NetAPINative/NetAPIResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,13 @@ public static implicit operator NetAPIResult<T>(string error)
{
return Fail(error);
}

public string GetErrorStatus() {
if (!string.IsNullOrEmpty(Error)) {
return Error;
}

return Status.ToString();
}
}
}
1 change: 1 addition & 0 deletions src/SharpHoundRPC/Result.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public class Result<T>
public T Value { get; private set; }
public string Error { get; private set; }
public bool IsFailed => !IsSuccess;
public bool IsTimeout { get; set; }

public static Result<T> Ok(T value)
{
Expand Down
35 changes: 35 additions & 0 deletions test/unit/ComputerSessionProcessorTest.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using CommonLibTest.Facades;
using Moq;
using Newtonsoft.Json;
using SharpHoundCommonLib;
using SharpHoundCommonLib.OutputTypes;
using SharpHoundCommonLib.Processors;
using SharpHoundRPC;
using SharpHoundRPC.NetAPINative;
using Xunit;
using Xunit.Abstractions;
Expand Down Expand Up @@ -227,5 +230,37 @@
Assert.Equal(2, test.Results.Length);
Assert.Equal(expected, test.Results);
}

[Fact]
public async Task ComputerSessionProcessor_TestTimeout() {
var nativeMethods = new Mock<NativeMethods>();
nativeMethods.Setup(x => x.NetSessionEnum(It.IsAny<string>())).Callback(() => {
Thread.Sleep(200);
}).Returns(Array.Empty<NetSessionEnumResults>());
nativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny<string>())).Callback(() => {
Thread.Sleep(200);
}).Returns(Array.Empty<NetWkstaUserEnumResults>());
var processor = new ComputerSessionProcessor(new MockLdapUtils(),"", nativeMethods.Object);
var receivedStatus = new List<CSVComputerStatus>();
var machineDomainSid = $"{Consts.MockDomainSid}-1000";
processor.ComputerStatusEvent += async status => {

Check warning on line 246 in test/unit/ComputerSessionProcessorTest.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 246 in test/unit/ComputerSessionProcessorTest.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
receivedStatus.Add(status);
};
var results = await processor.ReadUserSessions("primary.testlab.local", machineDomainSid, "testlab.local",
TimeSpan.FromMilliseconds(1));
Assert.Empty(results.Results);
Assert.Single(receivedStatus);
var status = receivedStatus[0];
Assert.Equal("Timeout", status.Status);

receivedStatus.Clear();

results = await processor.ReadUserSessionsPrivileged("primary.testlab.local", machineDomainSid, "testlab.local",
TimeSpan.FromMilliseconds(1));
Assert.Empty(results.Results);
Assert.Single(receivedStatus);
status = receivedStatus[0];
Assert.Equal("Timeout", status.Status);
}
}
}
Loading
Loading