diff --git a/src/CommonLib/Processors/ComputerSessionProcessor.cs b/src/CommonLib/Processors/ComputerSessionProcessor.cs index 71034549..ef695a4d 100644 --- a/src/CommonLib/Processors/ComputerSessionProcessor.cs +++ b/src/CommonLib/Processors/ComputerSessionProcessor.cs @@ -8,6 +8,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Win32; using SharpHoundCommonLib.OutputTypes; +using SharpHoundRPC; using SharpHoundRPC.NetAPINative; namespace SharpHoundCommonLib.Processors { @@ -56,7 +57,7 @@ public async Task ReadUserSessions(string computerName, string _log.LogDebug("Running NetSessionEnum for {ObjectName}", computerName); - var apiTask = Task.Run(() => { + var result = await Task.Run(() => { NetAPIResult> result; if (_doLocalAdminSessionEnum) { // If we are authenticating using a local admin, we need to impersonate for this @@ -77,20 +78,7 @@ public async Task ReadUserSessions(string computerName, string } return result; - }); - - if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) { - await SendComputerStatus(new CSVComputerStatus { - Status = "Timeout", - Task = "NetSessionEnum", - ComputerName = computerName - }); - ret.Collected = false; - ret.FailureReason = "Timeout"; - return ret; - } - - var result = apiTask.Result; + }).TimeoutAfter(timeout); if (result.IsFailed) { await SendComputerStatus(new CSVComputerStatus { @@ -186,7 +174,7 @@ public async Task ReadUserSessionsPrivileged(string computerNa _log.LogDebug("Running NetWkstaUserEnum for {ObjectName}", computerName); - var apiTask = Task.Run(() => { + var result = await Task.Run(() => { NetAPIResult> result; if (_doLocalAdminSessionEnum) { @@ -208,21 +196,8 @@ public async Task ReadUserSessionsPrivileged(string computerNa } return result; - }); + }).TimeoutAfter(timeout); - if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) { - await SendComputerStatus(new CSVComputerStatus { - Status = "Timeout", - Task = "NetWkstaUserEnum", - ComputerName = computerName - }); - ret.Collected = false; - ret.FailureReason = "Timeout"; - return ret; - } - - var result = apiTask.Result; - if (result.IsFailed) { await SendComputerStatus(new CSVComputerStatus { Status = result.Status.ToString(), diff --git a/src/CommonLib/Processors/LocalGroupProcessor.cs b/src/CommonLib/Processors/LocalGroupProcessor.cs index a1c35090..6b66a526 100644 --- a/src/CommonLib/Processors/LocalGroupProcessor.cs +++ b/src/CommonLib/Processors/LocalGroupProcessor.cs @@ -6,6 +6,7 @@ using Microsoft.Extensions.Logging; using SharpHoundCommonLib.Enums; using SharpHoundCommonLib.OutputTypes; +using SharpHoundRPC; using SharpHoundRPC.Shared; using SharpHoundRPC.Wrappers; @@ -33,7 +34,7 @@ public virtual SharpHoundRPC.Result OpenSamServer(string computerNam return SharpHoundRPC.Result.Fail(result.SError); } - return SharpHoundRPC.Result.Ok(result.Value); + return SharpHoundRPC.Result.Ok(result.Value); } public IAsyncEnumerable GetLocalGroups(ResolvedSearchResult result) @@ -48,12 +49,17 @@ public IAsyncEnumerable GetLocalGroups(ResolvedSearchResult /// The objectsid of the computer in the domain /// The domain the computer belongs too /// Is the computer a domain controller + /// /// public async IAsyncEnumerable GetLocalGroups(string computerName, string computerObjectId, - string computerDomain, bool isDomainController) + string computerDomain, bool isDomainController, TimeSpan timeout = default) { + if (timeout == default) { + timeout = TimeSpan.FromMinutes(2); + } + //Open a handle to the server - var openServerResult = OpenSamServer(computerName); + var openServerResult = await Task.Run(() => OpenSamServer(computerName)).TimeoutAfter(timeout); if (openServerResult.IsFailed) { _log.LogTrace("OpenServer failed on {ComputerName}: {Error}", computerName, openServerResult.SError); @@ -71,9 +77,8 @@ await SendComputerStatus(new CSVComputerStatus //Try to get the machine sid for the computer if its not already cached SecurityIdentifier machineSid; - if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid)) - { - var getMachineSidResult = server.GetMachineSid(); + if (!Cache.GetMachineSid(computerObjectId, out var tempMachineSid)) { + var getMachineSidResult = await Task.Run(() => server.GetMachineSid()).TimeoutAfter(timeout); if (getMachineSidResult.IsFailed) { _log.LogTrace("GetMachineSid failed on {ComputerName}: {Error}", computerName, getMachineSidResult.SError); @@ -97,7 +102,7 @@ await SendComputerStatus(new CSVComputerStatus } //Get all available domains in the server - var getDomainsResult = server.GetDomains(); + var getDomainsResult = await Task.Run(() => server.GetDomains()).TimeoutAfter(timeout); if (getDomainsResult.IsFailed) { _log.LogTrace("GetDomains failed on {ComputerName}: {Error}", computerName, getDomainsResult.SError); @@ -118,7 +123,7 @@ await SendComputerStatus(new CSVComputerStatus continue; //Open a handle to the domain - var openDomainResult = server.OpenDomain(domainResult.Name); + var openDomainResult = await Task.Run(() => server.OpenDomain(domainResult.Name)).TimeoutAfter(timeout); if (openDomainResult.IsFailed) { _log.LogTrace("Failed to open domain {Domain} on {ComputerName}: {Error}", domainResult.Name, computerName, openDomainResult.SError); @@ -128,13 +133,16 @@ await SendComputerStatus(new CSVComputerStatus ComputerName = computerName, Status = openDomainResult.SError }); + if (openDomainResult.IsTimeout) { + yield break; + } continue; } var domain = openDomainResult.Value; //Open a handle to the available aliases - var getAliasesResult = domain.GetAliases(); + var getAliasesResult = await Task.Run(() => domain.GetAliases()).TimeoutAfter(timeout); if (getAliasesResult.IsFailed) { @@ -145,6 +153,10 @@ await SendComputerStatus(new CSVComputerStatus ComputerName = computerName, Status = getAliasesResult.SError }); + + if (getAliasesResult.IsTimeout) { + yield break; + } continue; } @@ -163,7 +175,7 @@ await SendComputerStatus(new CSVComputerStatus }; //Open a handle to the alias - var openAliasResult = domain.OpenAlias(alias.Rid); + var openAliasResult = await Task.Run(() => domain.OpenAlias(alias.Rid)).TimeoutAfter(timeout); if (openAliasResult.IsFailed) { _log.LogTrace("Failed to open alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error); @@ -176,12 +188,15 @@ await SendComputerStatus(new CSVComputerStatus ret.Collected = false; ret.FailureReason = $"SamOpenAliasInDomain failed with status {openAliasResult.SError}"; yield return ret; + if (openAliasResult.IsTimeout) { + yield break; + } continue; } var localGroup = openAliasResult.Value; //Call GetMembersInAlias to get raw group members - var getMembersResult = localGroup.GetMembers(); + var getMembersResult = await Task.Run(() => localGroup.GetMembers()).TimeoutAfter(timeout); if (getMembersResult.IsFailed) { _log.LogTrace("Failed to get members in alias {Alias} with RID {Rid} in domain {Domain} on computer {ComputerName}: {Error}", alias.Name, alias.Rid, domainResult.Name, computerName, openAliasResult.Error); @@ -194,6 +209,9 @@ await SendComputerStatus(new CSVComputerStatus ret.Collected = false; ret.FailureReason = $"SamGetMembersInAlias failed with status {getMembersResult.SError}"; yield return ret; + if (getMembersResult.IsTimeout) { + yield break; + } continue; } diff --git a/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs b/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs index 404427ee..b9b582fb 100644 --- a/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs +++ b/src/CommonLib/Processors/UserRightsAssignmentProcessor.cs @@ -1,9 +1,11 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Security.Principal; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using SharpHoundCommonLib.Enums; using SharpHoundCommonLib.OutputTypes; +using SharpHoundRPC; using SharpHoundRPC.Shared; using SharpHoundRPC.Wrappers; @@ -47,11 +49,15 @@ public IAsyncEnumerable GetUserRightsAssignments( /// /// Is the computer a domain controller /// + /// /// public async IAsyncEnumerable GetUserRightsAssignments(string computerName, - string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null) + string computerObjectId, string computerDomain, bool isDomainController, string[] desiredPrivileges = null, TimeSpan timeout = default) { - var policyOpenResult = OpenLSAPolicy(computerName); + if (timeout == default) { + timeout = TimeSpan.FromMinutes(2); + } + var policyOpenResult = await Task.Run(() => OpenLSAPolicy(computerName)).TimeoutAfter(timeout); if (!policyOpenResult.IsSuccess) { _log.LogDebug("LSAOpenPolicy failed on {ComputerName} with status {Status}", computerName, @@ -71,7 +77,7 @@ await SendComputerStatus(new CSVComputerStatus SecurityIdentifier machineSid; if (!Cache.GetMachineSid(computerObjectId, out var temp)) { - var getMachineSidResult = server.GetLocalDomainInformation(); + var getMachineSidResult = await Task.Run(() => server.GetLocalDomainInformation()).TimeoutAfter(timeout); if (getMachineSidResult.IsFailed) { _log.LogWarning("Failed to get machine sid for {Server}: {Status}. Abandoning URA collection", @@ -103,7 +109,7 @@ await SendComputerStatus(new CSVComputerStatus }; //Ask for all principals with the specified privilege. - var enumerateAccountsResult = server.GetResolvedPrincipalsWithPrivilege(privilege); + var enumerateAccountsResult = await Task.Run(() => server.GetResolvedPrincipalsWithPrivilege(privilege)).TimeoutAfter(timeout); if (enumerateAccountsResult.IsFailed) { _log.LogDebug( @@ -118,6 +124,9 @@ await SendComputerStatus(new CSVComputerStatus ret.FailureReason = $"LSAEnumerateAccountsWithUserRights returned {enumerateAccountsResult.SError}"; yield return ret; + if (enumerateAccountsResult.IsTimeout) { + yield break; + } continue; } diff --git a/src/SharpHoundRPC/Extensions.cs b/src/SharpHoundRPC/Extensions.cs index 50656bb1..98624d9b 100644 --- a/src/SharpHoundRPC/Extensions.cs +++ b/src/SharpHoundRPC/Extensions.cs @@ -1,5 +1,8 @@ using System; using System.Security.Principal; +using System.Threading; +using System.Threading.Tasks; +using SharpHoundRPC.NetAPINative; namespace SharpHoundRPC { @@ -32,5 +35,36 @@ public static byte[] GetBytes(this SecurityIdentifier identifier) identifier.GetBinaryForm(bytes, 0); return bytes; } + + public static async Task> TimeoutAfter(this Task> task, TimeSpan timeout) { + + using (var timeoutCancellationTokenSource = new CancellationTokenSource()) { + + var completedTask = await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token)); + if (completedTask == task) { + timeoutCancellationTokenSource.Cancel(); + return await task; // Very important in order to propagate exceptions + } + + var result = Result.Fail("Timeout"); + result.IsTimeout = true; + return result; + } + } + + public static async Task> TimeoutAfter(this Task> task, TimeSpan timeout) { + + using (var timeoutCancellationTokenSource = new CancellationTokenSource()) { + + var completedTask = await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token)); + if (completedTask == task) { + timeoutCancellationTokenSource.Cancel(); + return await task; // Very important in order to propagate exceptions + } + + var result = NetAPIResult.Fail("Timeout"); + return result; + } + } } } \ No newline at end of file diff --git a/src/SharpHoundRPC/Result.cs b/src/SharpHoundRPC/Result.cs index 800c2135..69ee0d39 100644 --- a/src/SharpHoundRPC/Result.cs +++ b/src/SharpHoundRPC/Result.cs @@ -7,6 +7,7 @@ public class Result public T Value { get; private set; } public string Error { get; private set; } public bool IsFailed => !IsSuccess; + public bool IsTimeout { get; set; } public static Result Ok(T value) {