Skip to content

Commit

Permalink
feat: add new timeoutafter methods and convert computer calls to use …
Browse files Browse the repository at this point in the history
…this appropriately
  • Loading branch information
rvazarkar committed Aug 15, 2024
1 parent d8e6757 commit 6bc07c3
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 46 deletions.
35 changes: 5 additions & 30 deletions src/CommonLib/Processors/ComputerSessionProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Win32;
using SharpHoundCommonLib.OutputTypes;
using SharpHoundRPC;
using SharpHoundRPC.NetAPINative;

namespace SharpHoundCommonLib.Processors {
Expand Down Expand Up @@ -56,7 +57,7 @@ public async Task<SessionAPIResult> ReadUserSessions(string computerName, string

_log.LogDebug("Running NetSessionEnum for {ObjectName}", computerName);

var apiTask = Task.Run(() => {
var result = await Task.Run(() => {
NetAPIResult<IEnumerable<NetSessionEnumResults>> result;
if (_doLocalAdminSessionEnum) {
// If we are authenticating using a local admin, we need to impersonate for this
Expand All @@ -77,20 +78,7 @@ public async Task<SessionAPIResult> ReadUserSessions(string computerName, string
}

return result;
});

if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) {
await SendComputerStatus(new CSVComputerStatus {
Status = "Timeout",
Task = "NetSessionEnum",
ComputerName = computerName
});
ret.Collected = false;
ret.FailureReason = "Timeout";
return ret;
}

var result = apiTask.Result;
}).TimeoutAfter(timeout);

if (result.IsFailed) {
await SendComputerStatus(new CSVComputerStatus {
Expand Down Expand Up @@ -186,7 +174,7 @@ public async Task<SessionAPIResult> ReadUserSessionsPrivileged(string computerNa

_log.LogDebug("Running NetWkstaUserEnum for {ObjectName}", computerName);

var apiTask = Task.Run(() => {
var result = await Task.Run(() => {
NetAPIResult<IEnumerable<NetWkstaUserEnumResults>>
result;
if (_doLocalAdminSessionEnum) {
Expand All @@ -208,21 +196,8 @@ public async Task<SessionAPIResult> ReadUserSessionsPrivileged(string computerNa
}

return result;
});
}).TimeoutAfter(timeout);

if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) {
await SendComputerStatus(new CSVComputerStatus {
Status = "Timeout",
Task = "NetWkstaUserEnum",
ComputerName = computerName
});
ret.Collected = false;
ret.FailureReason = "Timeout";
return ret;
}

var result = apiTask.Result;

if (result.IsFailed) {
await SendComputerStatus(new CSVComputerStatus {
Status = result.Status.ToString(),
Expand Down
40 changes: 29 additions & 11 deletions src/CommonLib/Processors/LocalGroupProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Enums;
using SharpHoundCommonLib.OutputTypes;
using SharpHoundRPC;
using SharpHoundRPC.Shared;
using SharpHoundRPC.Wrappers;

Expand Down Expand Up @@ -33,7 +34,7 @@ public virtual SharpHoundRPC.Result<ISAMServer> OpenSamServer(string computerNam
return SharpHoundRPC.Result<ISAMServer>.Fail(result.SError);
}

return SharpHoundRPC.Result<ISAMServer>.Ok(result.Value);
return SharpHoundRPC.Result<ISAMServer>.Ok(result.Value);
}

public IAsyncEnumerable<LocalGroupAPIResult> GetLocalGroups(ResolvedSearchResult result)
Expand All @@ -48,12 +49,17 @@ public IAsyncEnumerable<LocalGroupAPIResult> GetLocalGroups(ResolvedSearchResult
/// <param name="computerObjectId">The objectsid of the computer in the domain</param>
/// <param name="computerDomain">The domain the computer belongs too</param>
/// <param name="isDomainController">Is the computer a domain controller</param>
/// <param name="timeout"></param>
/// <returns></returns>
public async IAsyncEnumerable<LocalGroupAPIResult> GetLocalGroups(string computerName, string computerObjectId,
string computerDomain, bool isDomainController)
string computerDomain, bool isDomainController, TimeSpan timeout = default)
{
if (timeout == default) {
timeout = TimeSpan.FromMinutes(2);
}

//Open a handle to the server
var openServerResult = OpenSamServer(computerName);
var openServerResult = await Task.Run(() => OpenSamServer(computerName)).TimeoutAfter(timeout);
if (openServerResult.IsFailed)
{
_log.LogTrace("OpenServer failed on {ComputerName}: {Error}", computerName, openServerResult.SError);
Expand All @@ -71,9 +77,8 @@ await SendComputerStatus(new CSVComputerStatus

//Try to get the machine sid for the computer if its not already cached
SecurityIdentifier machineSid;
if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid))
{
var getMachineSidResult = server.GetMachineSid();
if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid)) {
var getMachineSidResult = await Task.Run(() => server.GetMachineSid()).TimeoutAfter(timeout);
if (getMachineSidResult.IsFailed)
{
_log.LogTrace("GetMachineSid failed on {ComputerName}: {Error}", computerName, getMachineSidResult.SError);
Expand All @@ -97,7 +102,7 @@ await SendComputerStatus(new CSVComputerStatus
}

//Get all available domains in the server
var getDomainsResult = server.GetDomains();
var getDomainsResult = await Task.Run(() => server.GetDomains()).TimeoutAfter(timeout);
if (getDomainsResult.IsFailed)
{
_log.LogTrace("GetDomains failed on {ComputerName}: {Error}", computerName, getDomainsResult.SError);
Expand All @@ -118,7 +123,7 @@ await SendComputerStatus(new CSVComputerStatus
continue;

//Open a handle to the domain
var openDomainResult = server.OpenDomain(domainResult.Name);
var openDomainResult = await Task.Run(() => server.OpenDomain(domainResult.Name)).TimeoutAfter(timeout);
if (openDomainResult.IsFailed)
{
_log.LogTrace("Failed to open domain {Domain} on {ComputerName}: {Error}", domainResult.Name, computerName, openDomainResult.SError);
Expand All @@ -128,13 +133,16 @@ await SendComputerStatus(new CSVComputerStatus
ComputerName = computerName,
Status = openDomainResult.SError
});
if (openDomainResult.IsTimeout) {
yield break;
}
continue;
}

var domain = openDomainResult.Value;

//Open a handle to the available aliases
var getAliasesResult = domain.GetAliases();
var getAliasesResult = await Task.Run(() => domain.GetAliases()).TimeoutAfter(timeout);

if (getAliasesResult.IsFailed)
{
Expand All @@ -145,6 +153,10 @@ await SendComputerStatus(new CSVComputerStatus
ComputerName = computerName,
Status = getAliasesResult.SError
});

if (getAliasesResult.IsTimeout) {
yield break;
}
continue;
}

Expand All @@ -163,7 +175,7 @@ await SendComputerStatus(new CSVComputerStatus
};

//Open a handle to the alias
var openAliasResult = domain.OpenAlias(alias.Rid);
var openAliasResult = await Task.Run(() => domain.OpenAlias(alias.Rid)).TimeoutAfter(timeout);
if (openAliasResult.IsFailed)
{
_log.LogTrace("Failed to open alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error);
Expand All @@ -176,12 +188,15 @@ await SendComputerStatus(new CSVComputerStatus
ret.Collected = false;
ret.FailureReason = $"SamOpenAliasInDomain failed with status {openAliasResult.SError}";
yield return ret;
if (openAliasResult.IsTimeout) {
yield break;
}
continue;
}

var localGroup = openAliasResult.Value;
//Call GetMembersInAlias to get raw group members
var getMembersResult = localGroup.GetMembers();
var getMembersResult = await Task.Run(() => localGroup.GetMembers()).TimeoutAfter(timeout);
if (getMembersResult.IsFailed)
{
_log.LogTrace("Failed to get members in alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error);
Expand All @@ -194,6 +209,9 @@ await SendComputerStatus(new CSVComputerStatus
ret.Collected = false;
ret.FailureReason = $"SamGetMembersInAlias failed with status {getMembersResult.SError}";
yield return ret;
if (getMembersResult.IsTimeout) {
yield break;
}
continue;
}

Expand Down
19 changes: 14 additions & 5 deletions src/CommonLib/Processors/UserRightsAssignmentProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Security.Principal;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Enums;
using SharpHoundCommonLib.OutputTypes;
using SharpHoundRPC;
using SharpHoundRPC.Shared;
using SharpHoundRPC.Wrappers;

Expand Down Expand Up @@ -47,11 +49,15 @@ public IAsyncEnumerable<UserRightsAssignmentAPIResult> GetUserRightsAssignments(
/// <param name="computerDomain"></param>
/// <param name="isDomainController">Is the computer a domain controller</param>
/// <param name="desiredPrivileges"></param>
/// <param name="timeout"></param>
/// <returns></returns>
public async IAsyncEnumerable<UserRightsAssignmentAPIResult> GetUserRightsAssignments(string computerName,
string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null)
string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null, TimeSpan timeout = default)
{
var policyOpenResult = OpenLSAPolicy(computerName);
if (timeout == default) {
timeout = TimeSpan.FromMinutes(2);
}
var policyOpenResult = await Task.Run(() => OpenLSAPolicy(computerName)).TimeoutAfter(timeout);
if (!policyOpenResult.IsSuccess)
{
_log.LogDebug("LSAOpenPolicy failed on {ComputerName} with status {Status}", computerName,
Expand All @@ -71,7 +77,7 @@ await SendComputerStatus(new CSVComputerStatus
SecurityIdentifier machineSid;
if (!Cache.GetMachineSid(computerObjectId, out var temp))
{
var getMachineSidResult = server.GetLocalDomainInformation();
var getMachineSidResult = await Task.Run(() => server.GetLocalDomainInformation()).TimeoutAfter(timeout);
if (getMachineSidResult.IsFailed)
{
_log.LogWarning("Failed to get machine sid for {Server}: {Status}. Abandoning URA collection",
Expand Down Expand Up @@ -103,7 +109,7 @@ await SendComputerStatus(new CSVComputerStatus
};

//Ask for all principals with the specified privilege.
var enumerateAccountsResult = server.GetResolvedPrincipalsWithPrivilege(privilege);
var enumerateAccountsResult = await Task.Run(() => server.GetResolvedPrincipalsWithPrivilege(privilege)).TimeoutAfter(timeout);
if (enumerateAccountsResult.IsFailed)
{
_log.LogDebug(
Expand All @@ -118,6 +124,9 @@ await SendComputerStatus(new CSVComputerStatus
ret.FailureReason =
$"LSAEnumerateAccountsWithUserRights returned {enumerateAccountsResult.SError}";
yield return ret;
if (enumerateAccountsResult.IsTimeout) {
yield break;
}
continue;
}

Expand Down
34 changes: 34 additions & 0 deletions src/SharpHoundRPC/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using System;
using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks;
using SharpHoundRPC.NetAPINative;

namespace SharpHoundRPC
{
Expand Down Expand Up @@ -32,5 +35,36 @@ public static byte[] GetBytes(this SecurityIdentifier identifier)
identifier.GetBinaryForm(bytes, 0);
return bytes;
}

public static async Task<Result<T>> TimeoutAfter<T>(this Task<Result<T>> task, TimeSpan timeout) {

using (var timeoutCancellationTokenSource = new CancellationTokenSource()) {

var completedTask = await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token));
if (completedTask == task) {
timeoutCancellationTokenSource.Cancel();
return await task; // Very important in order to propagate exceptions
}

var result = Result<T>.Fail("Timeout");
result.IsTimeout = true;
return result;
}
}

public static async Task<NetAPIResult<T>> TimeoutAfter<T>(this Task<NetAPIResult<T>> task, TimeSpan timeout) {

using (var timeoutCancellationTokenSource = new CancellationTokenSource()) {

var completedTask = await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token));
if (completedTask == task) {
timeoutCancellationTokenSource.Cancel();
return await task; // Very important in order to propagate exceptions
}

var result = NetAPIResult<T>.Fail("Timeout");
return result;
}
}
}
}
1 change: 1 addition & 0 deletions src/SharpHoundRPC/Result.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public class Result<T>
public T Value { get; private set; }
public string Error { get; private set; }
public bool IsFailed => !IsSuccess;
public bool IsTimeout { get; set; }

public static Result<T> Ok(T value)
{
Expand Down

0 comments on commit 6bc07c3

Please sign in to comment.