From 1ebeebe313dccb034796216443ed6553052d495a Mon Sep 17 00:00:00 2001 From: rvazarkar Date: Tue, 27 Aug 2024 17:01:55 -0400 Subject: [PATCH 1/3] fix: temporarily disable semaphore until we can iron out deadlocks chore: some misc test fixes to fix flakiness chore: fix an async function that wasn't async --- src/CommonLib/ConnectionPoolManager.cs | 4 +- src/CommonLib/LdapConnectionPool.cs | 47 ++++-- src/CommonLib/LdapQueryParameters.cs | 13 +- .../Processors/ComputerAvailability.cs | 2 +- .../Processors/ComputerSessionProcessor.cs | 7 +- test/unit/ComputerSessionProcessorTest.cs | 136 +++++++----------- test/unit/LocalGroupProcessorTest.cs | 5 +- .../unit/UserRightsAssignmentProcessorTest.cs | 5 +- 8 files changed, 111 insertions(+), 108 deletions(-) diff --git a/src/CommonLib/ConnectionPoolManager.cs b/src/CommonLib/ConnectionPoolManager.cs index e6d5a6f1..5e001dec 100644 --- a/src/CommonLib/ConnectionPoolManager.cs +++ b/src/CommonLib/ConnectionPoolManager.cs @@ -99,13 +99,13 @@ private bool GetPool(string identifier, out LdapConnectionPool pool) { return await pool.GetConnectionAsync(); } - public async Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> GetLdapConnectionForServer( + public (bool Success, LdapConnectionWrapper connectionWrapper, string Message) GetLdapConnectionForServer( string identifier, string server, bool globalCatalog) { if (!GetPool(identifier, out var pool)) { return (false, default, $"Unable to resolve a pool for {identifier}"); } - return await pool.GetConnectionForSpecificServerAsync(server, globalCatalog); + return pool.GetConnectionForSpecificServerAsync(server, globalCatalog); } private string ResolveIdentifier(string identifier) { diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index a05d6407..490806cb 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -35,12 +35,13 @@ internal class LdapConnectionPool : IDisposable{ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { _connections = new ConcurrentBag(); _globalCatalogConnection = new ConcurrentBag(); - if (config.MaxConcurrentQueries > 0) { - _semaphore = new SemaphoreSlim(config.MaxConcurrentQueries, config.MaxConcurrentQueries); - } else { - //If MaxConcurrentQueries is 0, we'll just disable the semaphore entirely - _semaphore = null; - } + //TODO: Re-enable this once we track down the semaphore deadlock + // if (config.MaxConcurrentQueries > 0) { + // _semaphore = new SemaphoreSlim(config.MaxConcurrentQueries, config.MaxConcurrentQueries); + // } else { + // //If MaxConcurrentQueries is 0, we'll just disable the semaphore entirely + // _semaphore = null; + // } _identifier = identifier; _poolIdentifier = poolIdentifier; @@ -83,7 +84,9 @@ public async IAsyncEnumerable> Query(LdapQueryParam while (!cancellationToken.IsCancellationRequested) { //Grab our semaphore here to take one of our query slots if (_semaphore != null){ + _log.LogTrace("Query entering semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); await _semaphore.WaitAsync(cancellationToken); + _log.LogTrace("Query entered semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); } try { _log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo()); @@ -159,7 +162,11 @@ public async IAsyncEnumerable> Query(LdapQueryParam queryParameters); } finally { // Always release our semaphore to prevent deadlocks - _semaphore?.Release(); + if (_semaphore != null) { + _log.LogTrace("Query releasing semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); + _semaphore.Release(); + _log.LogTrace("Query released semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); + } } //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function @@ -214,7 +221,9 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery while (!cancellationToken.IsCancellationRequested) { if (_semaphore != null){ + _log.LogTrace("PagedQuery entering semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); await _semaphore.WaitAsync(cancellationToken); + _log.LogTrace("PagedQuery entered semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); } SearchResponse response = null; try { @@ -255,7 +264,7 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery var backoffDelay = GetNextBackoff(retryCount); await Task.Delay(backoffDelay, cancellationToken); var (success, ldapConnectionWrapperNew, _) = - await GetConnectionForSpecificServerAsync(serverName, queryParameters.GlobalCatalog); + GetConnectionForSpecificServerAsync(serverName, queryParameters.GlobalCatalog); if (success) { _log.LogDebug("PagedQuery - Recovered from ServerDown successfully"); @@ -288,7 +297,11 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery LdapResult.Fail($"PagedQuery - Caught unrecoverable exception: {e.Message}", queryParameters); } finally { - _semaphore?.Release(); + if (_semaphore != null) { + _log.LogTrace("PagedQuery releasing semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); + _semaphore.Release(); + _log.LogTrace("PagedQuery released semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); + } } if (tempResult != null) { @@ -402,7 +415,9 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish while (!cancellationToken.IsCancellationRequested) { SearchResponse response = null; if (_semaphore != null){ + _log.LogTrace("RangedRetrieval entering semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); await _semaphore.WaitAsync(cancellationToken); + _log.LogTrace("RangedRetrieval entered semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); } try { response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); @@ -446,7 +461,11 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish tempResult = LdapResult.Fail($"Caught unrecoverable exception: {e.Message}", queryParameters); } finally { - _semaphore?.Release(); + if (_semaphore != null) { + _log.LogTrace("RangedRetrieval releasing semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); + _semaphore.Release(); + _log.LogTrace("RangedRetrieval released semaphore with {Count} remaining for query {Info}", _semaphore.CurrentCount, queryParameters.GetQueryInfo()); + } } //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function @@ -471,13 +490,17 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish step = entry.Attributes[currentRange].Count; } + //Release our connection before we iterate + if (complete) { + ReleaseConnection(connectionWrapper); + } + foreach (string dn in entry.Attributes[currentRange].GetValues(typeof(string))) { yield return Result.Ok(dn); index++; } if (complete) { - ReleaseConnection(connectionWrapper); yield break; } @@ -577,7 +600,7 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll return (true, connectionWrapper, null); } - public async Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> + public (bool Success, LdapConnectionWrapper connectionWrapper, string Message) GetConnectionForSpecificServerAsync(string server, bool globalCatalog) { return CreateNewConnectionForServer(server, globalCatalog); } diff --git a/src/CommonLib/LdapQueryParameters.cs b/src/CommonLib/LdapQueryParameters.cs index e4f099d9..354bf730 100644 --- a/src/CommonLib/LdapQueryParameters.cs +++ b/src/CommonLib/LdapQueryParameters.cs @@ -1,10 +1,11 @@ using System; using System.DirectoryServices.Protocols; +using System.Threading; using SharpHoundCommonLib.Enums; namespace SharpHoundCommonLib { - public class LdapQueryParameters - { + public class LdapQueryParameters { + private static int _queryIDIndex; private string _searchBase; private string _relativeSearchBase; public string LDAPFilter { get; set; } @@ -14,6 +15,12 @@ public class LdapQueryParameters public bool GlobalCatalog { get; set; } public bool IncludeSecurityDescriptor { get; set; } = false; public bool IncludeDeleted { get; set; } = false; + private int QueryID { get; } + + public LdapQueryParameters() { + QueryID = _queryIDIndex; + Interlocked.Increment(ref _queryIDIndex); + } public string SearchBase { get => _searchBase; @@ -35,7 +42,7 @@ public string RelativeSearchBase { public string GetQueryInfo() { - return $"Query Information - Filter: {LDAPFilter}, Domain: {DomainName}, GlobalCatalog: {GlobalCatalog}, ADSPath: {SearchBase}"; + return $"Query Information - Filter: {LDAPFilter}, Domain: {DomainName}, GlobalCatalog: {GlobalCatalog}, ADSPath: {SearchBase}, ID: {QueryID}"; } } } \ No newline at end of file diff --git a/src/CommonLib/Processors/ComputerAvailability.cs b/src/CommonLib/Processors/ComputerAvailability.cs index e87103a5..736365ab 100644 --- a/src/CommonLib/Processors/ComputerAvailability.cs +++ b/src/CommonLib/Processors/ComputerAvailability.cs @@ -16,7 +16,7 @@ public class ComputerAvailability private readonly bool _skipPasswordCheck; private readonly bool _skipPortScan; - public ComputerAvailability(int timeout = 500, int computerExpiryDays = 60, bool skipPortScan = false, + public ComputerAvailability(int timeout = 10000, int computerExpiryDays = 60, bool skipPortScan = false, bool skipPasswordCheck = false, ILogger log = null) { _scanner = new PortScanner(); diff --git a/src/CommonLib/Processors/ComputerSessionProcessor.cs b/src/CommonLib/Processors/ComputerSessionProcessor.cs index 1fbb746f..2c9a4673 100644 --- a/src/CommonLib/Processors/ComputerSessionProcessor.cs +++ b/src/CommonLib/Processors/ComputerSessionProcessor.cs @@ -24,8 +24,9 @@ public class ComputerSessionProcessor { private readonly string _localAdminUsername; private readonly string _localAdminPassword; - public ComputerSessionProcessor(ILdapUtils utils, string currentUserName = null, - NativeMethods nativeMethods = null, ILogger log = null, bool doLocalAdminSessionEnum = false, + public ComputerSessionProcessor(ILdapUtils utils, + NativeMethods nativeMethods = null, ILogger log = null, string currentUserName = null, + bool doLocalAdminSessionEnum = false, string localAdminUsername = null, string localAdminPassword = null) { _utils = utils; _nativeMethods = nativeMethods ?? new NativeMethods(); @@ -92,7 +93,7 @@ await SendComputerStatus(new CSVComputerStatus { return ret; } - _log.LogTrace("NetSessionEnum succeeded on {ComputerName}", computerName); + _log.LogDebug("NetSessionEnum succeeded on {ComputerName}", computerName); await SendComputerStatus(new CSVComputerStatus { Status = CSVComputerStatus.StatusSuccess, Task = "NetSessionEnum", diff --git a/test/unit/ComputerSessionProcessorTest.cs b/test/unit/ComputerSessionProcessorTest.cs index db8111b3..de364896 100644 --- a/test/unit/ComputerSessionProcessorTest.cs +++ b/test/unit/ComputerSessionProcessorTest.cs @@ -13,16 +13,13 @@ using Xunit; using Xunit.Abstractions; -namespace CommonLibTest -{ - public class ComputerSessionProcessorTest : IDisposable - { +namespace CommonLibTest { + public class ComputerSessionProcessorTest : IDisposable { private readonly string _computerDomain; private readonly string _computerSid; private readonly ITestOutputHelper _testOutputHelper; - public ComputerSessionProcessorTest(ITestOutputHelper testOutputHelper) - { + public ComputerSessionProcessorTest(ITestOutputHelper testOutputHelper) { _testOutputHelper = testOutputHelper; _computerDomain = "TESTLAB.LOCAL"; _computerSid = "S-1-5-21-3130019616-2776909439-2417379446-1104"; @@ -30,172 +27,148 @@ public ComputerSessionProcessorTest(ITestOutputHelper testOutputHelper) #region IDispose Implementation - public void Dispose() - { + public void Dispose() { // Tear down (called once per test) } #endregion [Fact] - public async Task ComputerSessionProcessor_ReadUserSessions_FilteringWorks() - { + public async Task ComputerSessionProcessor_ReadUserSessions_FilteringWorks() { var mockNativeMethods = new Mock(); - var apiResult = new NetSessionEnumResults[] - { + var apiResult = new NetSessionEnumResults[] { new("dfm", "\\\\192.168.92.110"), new("admin", ""), new("admin", "\\\\192.168.92.110") }; mockNativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(apiResult); - var processor = new ComputerSessionProcessor(new MockLdapUtils(), "dfm", mockNativeMethods.Object); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), mockNativeMethods.Object,null, "dfm"); var result = await processor.ReadUserSessions("win10", _computerSid, _computerDomain); Assert.True(result.Collected); Assert.Empty(result.Results); } [Fact] - public async Task ComputerSessionProcessor_ReadUserSessions_ResolvesHost() - { + public async Task ComputerSessionProcessor_ReadUserSessions_ResolvesHost() { var mockNativeMethods = new Mock(); - var apiResult = new NetSessionEnumResults[] - { + var apiResult = new NetSessionEnumResults[] { new("admin", "\\\\192.168.1.1") }; mockNativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(apiResult); - var expected = new Session[] - { - new() - { + var expected = new Session[] { + new() { ComputerSID = "S-1-5-21-3130019616-2776909439-2417379446-1104", UserSID = "S-1-5-21-3130019616-2776909439-2417379446-2116" } }; - var processor = new ComputerSessionProcessor(new MockLdapUtils(), "dfm", mockNativeMethods.Object); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), mockNativeMethods.Object,null, "dfm"); var result = await processor.ReadUserSessions("win10", _computerSid, _computerDomain); Assert.True(result.Collected); Assert.Equal(expected, result.Results); } [Fact] - public async Task ComputerSessionProcessor_ReadUserSessions_ResolvesLocalHostEquivalent() - { + public async Task ComputerSessionProcessor_ReadUserSessions_ResolvesLocalHostEquivalent() { var mockNativeMethods = new Mock(); - var apiResult = new NetSessionEnumResults[] - { + var apiResult = new NetSessionEnumResults[] { new("admin", "\\\\127.0.0.1") }; mockNativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(apiResult); - var expected = new Session[] - { - new() - { + var expected = new Session[] { + new() { ComputerSID = _computerSid, UserSID = "S-1-5-21-3130019616-2776909439-2417379446-2116" } }; - var processor = new ComputerSessionProcessor(new MockLdapUtils(), "dfm", mockNativeMethods.Object); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), mockNativeMethods.Object,null, "dfm"); var result = await processor.ReadUserSessions("win10", _computerSid, _computerDomain); Assert.True(result.Collected); Assert.Equal(expected, result.Results); } [Fact] - public async Task ComputerSessionProcessor_ReadUserSessions_MultipleMatches_AddsAll() - { + public async Task ComputerSessionProcessor_ReadUserSessions_MultipleMatches_AddsAll() { var mockNativeMethods = new Mock(); - var apiResult = new NetSessionEnumResults[] - { + var apiResult = new NetSessionEnumResults[] { new("administrator", "\\\\127.0.0.1") }; mockNativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(apiResult); - var expected = new Session[] - { - new() - { + var expected = new Session[] { + new() { ComputerSID = _computerSid, UserSID = "S-1-5-21-3130019616-2776909439-2417379446-500" }, - new() - { + new() { ComputerSID = _computerSid, UserSID = "S-1-5-21-3084884204-958224920-2707782874-500" } }; - var processor = new ComputerSessionProcessor(new MockLdapUtils(), "dfm", mockNativeMethods.Object); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), mockNativeMethods.Object,null, "dfm"); var result = await processor.ReadUserSessions("win10", _computerSid, _computerDomain); Assert.True(result.Collected); Assert.Equal(expected, result.Results); } [Fact] - public async Task ComputerSessionProcessor_ReadUserSessions_NoGCMatch_TriesResolve() - { + public async Task ComputerSessionProcessor_ReadUserSessions_NoGCMatch_TriesResolve() { var mockNativeMethods = new Mock(); - var apiResult = new NetSessionEnumResults[] - { + var apiResult = new NetSessionEnumResults[] { new("test", "\\\\127.0.0.1") }; mockNativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(apiResult); - var expected = new Session[] - { - new() - { + var expected = new Session[] { + new() { ComputerSID = _computerSid, UserSID = "S-1-5-21-3130019616-2776909439-2417379446-1106" } }; - var processor = new ComputerSessionProcessor(new MockLdapUtils(), "dfm", mockNativeMethods.Object); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), mockNativeMethods.Object, null,"dfm"); var result = await processor.ReadUserSessions("win10", _computerSid, _computerDomain); Assert.True(result.Collected); Assert.Equal(expected, result.Results); } [Fact] - public async Task ComputerSessionProcessor_ReadUserSessions_ComputerAccessDenied_Handled() - { + public async Task ComputerSessionProcessor_ReadUserSessions_ComputerAccessDenied_Handled() { var mockNativeMethods = new Mock(); //mockNativeMethods.Setup(x => x.CallSamConnect(ref It.Ref.IsAny, out It.Ref.IsAny, It.IsAny(), ref It.Ref.IsAny)).Returns(NativeMethods.NtStatus.StatusAccessDenied); mockNativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())) .Returns(NetAPIEnums.NetAPIStatus.ErrorAccessDenied); - var processor = new ComputerSessionProcessor(new MockLdapUtils(), "dfm", mockNativeMethods.Object); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), mockNativeMethods.Object, null,"dfm"); var test = await processor.ReadUserSessions("test", "test", "test"); Assert.False(test.Collected); Assert.Equal(NetAPIEnums.NetAPIStatus.ErrorAccessDenied.ToString(), test.FailureReason); } [Fact] - public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_ComputerAccessDenied_ExceptionCaught() - { + public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_ComputerAccessDenied_ExceptionCaught() { var mockNativeMethods = new Mock(); //mockNativeMethods.Setup(x => x.CallSamConnect(ref It.Ref.IsAny, out It.Ref.IsAny, It.IsAny(), ref It.Ref.IsAny)).Returns(NativeMethods.NtStatus.StatusAccessDenied); mockNativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())) .Returns(NetAPIEnums.NetAPIStatus.ErrorAccessDenied); - var processor = new ComputerSessionProcessor(new MockLdapUtils(), "dfm", mockNativeMethods.Object); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), mockNativeMethods.Object, null,"dfm"); var test = await processor.ReadUserSessionsPrivileged("test", "test", "test"); Assert.False(test.Collected); Assert.Equal(NetAPIEnums.NetAPIStatus.ErrorAccessDenied.ToString(), test.FailureReason); } [Fact] - public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_FilteringWorks() - { + public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_FilteringWorks() { var mockNativeMethods = new Mock(); const string samAccountName = "WIN10"; //This is a sample response from a computer in a test environment. The duplicates are intentional - var apiResults = new NetWkstaUserEnumResults[] - { + var apiResults = new NetWkstaUserEnumResults[] { new("dfm", "TESTLAB"), new("Administrator", "PRIMARY"), new("Administrator", ""), @@ -209,21 +182,19 @@ public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_FilteringW }; mockNativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())).Returns(apiResults); - var expected = new Session[] - { - new() - { + var expected = new Session[] { + new() { ComputerSID = _computerSid, UserSID = "S-1-5-21-3130019616-2776909439-2417379446-1105" }, - new() - { + new() { ComputerSID = _computerSid, UserSID = "S-1-5-21-3130019616-2776909439-2417379446-500" } }; - var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods: mockNativeMethods.Object, currentUserName:"ADMINISTRATOR"); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods: mockNativeMethods.Object, + currentUserName: "ADMINISTRATOR"); var test = await processor.ReadUserSessionsPrivileged("WIN10.TESTLAB.LOCAL", samAccountName, _computerSid); Assert.True(test.Collected); _testOutputHelper.WriteLine(JsonConvert.SerializeObject(test.Results)); @@ -234,15 +205,14 @@ public async Task ComputerSessionProcessor_ReadUserSessionsPrivileged_FilteringW [Fact] public async Task ComputerSessionProcessor_TestTimeout() { var nativeMethods = new Mock(); - nativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Callback(() => { + nativeMethods.Setup(x => x.NetSessionEnum(It.IsAny())).Returns(() => { Task.Delay(1000).Wait(); - }).Returns(Array.Empty()); - var processor = new ComputerSessionProcessor(new MockLdapUtils(),"", nativeMethods.Object); + return Array.Empty(); + }); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,""); var receivedStatus = new List(); var machineDomainSid = $"{Consts.MockDomainSid}-1000"; - processor.ComputerStatusEvent += async status => { - receivedStatus.Add(status); - }; + processor.ComputerStatusEvent += async status => { receivedStatus.Add(status); }; var results = await processor.ReadUserSessions("primary.testlab.local", machineDomainSid, "testlab.local", TimeSpan.FromMilliseconds(1)); Assert.Empty(results.Results); @@ -250,21 +220,21 @@ public async Task ComputerSessionProcessor_TestTimeout() { var status = receivedStatus[0]; Assert.Equal("Timeout", status.Status); } - + [Fact] public async Task ComputerSessionProcessor_TestTimeoutPrivileged() { var nativeMethods = new Mock(); - nativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())).Callback(() => { + nativeMethods.Setup(x => x.NetWkstaUserEnum(It.IsAny())).Returns(() => { Task.Delay(1000).Wait(); - }).Returns(Array.Empty()); - var processor = new ComputerSessionProcessor(new MockLdapUtils(),"", nativeMethods.Object); + return Array.Empty(); + }); + var processor = new ComputerSessionProcessor(new MockLdapUtils(), nativeMethods.Object, null,""); var receivedStatus = new List(); var machineDomainSid = $"{Consts.MockDomainSid}-1000"; - processor.ComputerStatusEvent += async status => { - receivedStatus.Add(status); - }; - - var results = await processor.ReadUserSessionsPrivileged("primary.testlab.local", machineDomainSid, "testlab.local", + processor.ComputerStatusEvent += async status => { receivedStatus.Add(status); }; + + var results = await processor.ReadUserSessionsPrivileged("primary.testlab.local", machineDomainSid, + "testlab.local", TimeSpan.FromMilliseconds(1)); Assert.Empty(results.Results); Assert.Single(receivedStatus); diff --git a/test/unit/LocalGroupProcessorTest.cs b/test/unit/LocalGroupProcessorTest.cs index 460cb5a4..01d68d67 100644 --- a/test/unit/LocalGroupProcessorTest.cs +++ b/test/unit/LocalGroupProcessorTest.cs @@ -119,9 +119,10 @@ public async Task LocalGroupProcessor_TestTimeout() { var mockUtils = new Mock(); var mockProcessor = new Mock(mockUtils.Object, null); - mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Callback(() => { + mockProcessor.Setup(x => x.OpenSamServer(It.IsAny())).Returns(() => { Task.Delay(100).Wait(); - }).Returns(NtStatus.StatusAccessDenied); + return NtStatus.StatusAccessDenied; + }); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockDomainSid}-1000"; var receivedStatus = new List(); diff --git a/test/unit/UserRightsAssignmentProcessorTest.cs b/test/unit/UserRightsAssignmentProcessorTest.cs index 9b18481b..90a5a352 100644 --- a/test/unit/UserRightsAssignmentProcessorTest.cs +++ b/test/unit/UserRightsAssignmentProcessorTest.cs @@ -71,9 +71,10 @@ public async Task UserRightsAssignmentProcessor_TestDC() [Fact] public async Task UserRightsAssignmentProcessor_TestTimeout() { var mockProcessor = new Mock(new MockLdapUtils(), null); - mockProcessor.Setup(x => x.OpenLSAPolicy(It.IsAny())).Callback(() => { + mockProcessor.Setup(x => x.OpenLSAPolicy(It.IsAny())).Returns(()=> { Task.Delay(100).Wait(); - }).Returns(NtStatus.StatusAccessDenied); + return NtStatus.StatusAccessDenied; + }); var processor = mockProcessor.Object; var machineDomainSid = $"{Consts.MockDomainSid}-1000"; var receivedStatus = new List(); From fee9441f0c6bf3c0afdbfbb89461e99e029c9a76 Mon Sep 17 00:00:00 2001 From: rvazarkar Date: Thu, 29 Aug 2024 15:36:57 -0400 Subject: [PATCH 2/3] chore: bump version --- src/CommonLib/SharpHoundCommonLib.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CommonLib/SharpHoundCommonLib.csproj b/src/CommonLib/SharpHoundCommonLib.csproj index dae0d50b..d8e0b937 100644 --- a/src/CommonLib/SharpHoundCommonLib.csproj +++ b/src/CommonLib/SharpHoundCommonLib.csproj @@ -9,7 +9,7 @@ Common library for C# BloodHound enumeration tasks GPL-3.0-only https://github.com/BloodHoundAD/SharpHoundCommon - 4.0.5 + 4.0.6 SharpHoundCommonLib SharpHoundCommonLib From ba8fd2b6ec5a1170cf062f317a9ffad244c1b11e Mon Sep 17 00:00:00 2001 From: rvazarkar Date: Thu, 29 Aug 2024 15:46:05 -0400 Subject: [PATCH 3/3] chore: improve retry logging at debug level --- src/CommonLib/LdapConnectionPool.cs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index 490806cb..fe0cb5a7 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -114,6 +114,7 @@ public async IAsyncEnumerable> Query(LdapQueryParam * since non-paged queries do not require same server connections */ queryRetryCount++; + _log.LogDebug("Query - Attempting to recover from ServerDown for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), queryRetryCount); ReleaseConnection(connectionWrapper, true); for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { @@ -144,6 +145,7 @@ public async IAsyncEnumerable> Query(LdapQueryParam * The expectation is that given enough time, the server should stop being busy and service our query appropriately */ busyRetryCount++; + _log.LogDebug("Query - Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); var backoffDelay = GetNextBackoff(busyRetryCount); await Task.Delay(backoffDelay, cancellationToken); } catch (LdapException le) { @@ -258,7 +260,9 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery ReleaseConnection(connectionWrapper, true); yield break; } - + + _log.LogDebug("PagedQuery - Attempting to recover from ServerDown for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), queryRetryCount); + ReleaseConnection(connectionWrapper, true); for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { var backoffDelay = GetNextBackoff(retryCount); @@ -286,6 +290,7 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery * The expectation is that given enough time, the server should stop being busy and service our query appropriately */ busyRetryCount++; + _log.LogDebug("PagedQuery - Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); var backoffDelay = GetNextBackoff(busyRetryCount); await Task.Delay(backoffDelay, cancellationToken); } catch (LdapException le) { @@ -423,11 +428,13 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { busyRetryCount++; + _log.LogDebug("RangedRetrieval - Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); var backoffDelay = GetNextBackoff(busyRetryCount); await Task.Delay(backoffDelay, cancellationToken); } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && queryRetryCount < MaxRetries) { queryRetryCount++; + _log.LogDebug("RangedRetrieval - Attempting to recover from ServerDown for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), queryRetryCount); ReleaseConnection(connectionWrapper, true); for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { var backoffDelay = GetNextBackoff(retryCount);