diff --git a/src/CommonLib/ConnectionPoolManager.cs b/src/CommonLib/ConnectionPoolManager.cs index f7f015f8..e6d5a6f1 100644 --- a/src/CommonLib/ConnectionPoolManager.cs +++ b/src/CommonLib/ConnectionPoolManager.cs @@ -1,13 +1,15 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.DirectoryServices; using System.Security.Principal; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using SharpHoundCommonLib.Processors; namespace SharpHoundCommonLib { - public class ConnectionPoolManager : IDisposable{ + internal class ConnectionPoolManager : IDisposable{ private readonly ConcurrentDictionary _pools = new(); private readonly LdapConfig _ldapConfig; private readonly string[] _translateNames = { "Administrator", "admin" }; @@ -21,6 +23,35 @@ public ConnectionPoolManager(LdapConfig config, ILogger log = null, PortScanner _portScanner = scanner ?? new PortScanner(); } + public IAsyncEnumerable> RangedRetrieval(string distinguishedName, + string attributeName, CancellationToken cancellationToken = new()) { + var domain = Helpers.DistinguishedNameToDomain(distinguishedName); + + if (!GetPool(domain, out var pool)) { + return new List> {Result.Fail("Failed to resolve a connection pool")}.ToAsyncEnumerable(); + } + + return pool.RangedRetrieval(distinguishedName, attributeName, cancellationToken); + } + + public IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, + CancellationToken cancellationToken = new()) { + if (!GetPool(queryParameters.DomainName, out var pool)) { + return new List> {LdapResult.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable(); + } + + return pool.PagedQuery(queryParameters, cancellationToken); + } + + public IAsyncEnumerable> Query(LdapQueryParameters queryParameters, + CancellationToken cancellationToken = new()) { + if (!GetPool(queryParameters.DomainName, out var pool)) { + return new List> {LdapResult.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable(); + } + + return pool.Query(queryParameters, cancellationToken); + } + public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) { if (connectionWrapper == null) { return; @@ -41,18 +72,27 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn return (success, message); } - public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection( - string identifier, bool globalCatalog) { + private bool GetPool(string identifier, out LdapConnectionPool pool) { if (identifier == null) { - return (false, default, "Provided a null identifier for the connection"); + pool = default; + return false; } - var resolved = ResolveIdentifier(identifier); - if (!_pools.TryGetValue(resolved, out var pool)) { + var resolved = ResolveIdentifier(identifier); + if (!_pools.TryGetValue(resolved, out pool)) { pool = new LdapConnectionPool(identifier, resolved, _ldapConfig,scanner: _portScanner); _pools.TryAdd(resolved, pool); } + return true; + } + + public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection( + string identifier, bool globalCatalog) { + if (!GetPool(identifier, out var pool)) { + return (false, default, $"Unable to resolve a pool for {identifier}"); + } + if (globalCatalog) { return await pool.GetGlobalCatalogConnectionAsync(); } @@ -61,11 +101,8 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn public async Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> GetLdapConnectionForServer( string identifier, string server, bool globalCatalog) { - var resolved = ResolveIdentifier(identifier); - - if (!_pools.TryGetValue(resolved, out var pool)) { - pool = new LdapConnectionPool(resolved, identifier, _ldapConfig,scanner: _portScanner); - _pools.TryAdd(resolved, pool); + if (!GetPool(identifier, out var pool)) { + return (false, default, $"Unable to resolve a pool for {identifier}"); } return await pool.GetConnectionForSpecificServerAsync(server, globalCatalog); diff --git a/src/CommonLib/Extensions.cs b/src/CommonLib/Extensions.cs index c7086a62..be102fd9 100644 --- a/src/CommonLib/Extensions.cs +++ b/src/CommonLib/Extensions.cs @@ -121,6 +121,47 @@ public async ValueTask MoveNextAsync() { public T Current => _current; } + + internal static IAsyncEnumerable ToAsyncEnumerable(this IEnumerable source) { + return source switch { + ICollection collection => new IAsyncEnumerableCollectionAdapter(collection), + _ => null + }; + } + + private sealed class IAsyncEnumerableCollectionAdapter : IAsyncEnumerable { + private readonly IAsyncEnumerator _enumerator; + + public IAsyncEnumerableCollectionAdapter(ICollection source) { + _enumerator = new IAsyncEnumeratorCollectionAdapter(source); + } + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken()) { + return _enumerator; + } + } + + private sealed class IAsyncEnumeratorCollectionAdapter : IAsyncEnumerator { + private readonly IEnumerable _source; + private IEnumerator _enumerator; + + public IAsyncEnumeratorCollectionAdapter(ICollection source) { + _source = source; + } + + public ValueTask DisposeAsync() { + _enumerator = null; + return new ValueTask(Task.CompletedTask); + } + + public ValueTask MoveNextAsync() { + if (_enumerator == null) { + _enumerator = _source.GetEnumerator(); + } + return new ValueTask(_enumerator.MoveNext()); + } + + public T Current => _enumerator.Current; + } public static string LdapValue(this SecurityIdentifier s) diff --git a/src/CommonLib/LdapConfig.cs b/src/CommonLib/LdapConfig.cs index e59bae71..69c9a17e 100644 --- a/src/CommonLib/LdapConfig.cs +++ b/src/CommonLib/LdapConfig.cs @@ -1,4 +1,5 @@ using System.DirectoryServices.Protocols; +using System.Text; namespace SharpHoundCommonLib { @@ -13,6 +14,7 @@ public class LdapConfig public bool DisableSigning { get; set; } = false; public bool DisableCertVerification { get; set; } = false; public AuthType AuthType { get; set; } = AuthType.Kerberos; + public int MaxConcurrentQueries { get; set; } = 15; //Returns the port for connecting to LDAP. Will always respect a user's overridden config over anything else public int GetPort(bool ssl) @@ -32,5 +34,23 @@ public int GetGCPort(bool ssl) { return ssl ? 3269 : 3268; } + + public override string ToString() { + var sb = new StringBuilder(); + sb.AppendLine($"Server: {Server}"); + sb.AppendLine($"Port: {Port}"); + sb.AppendLine($"SSLPort: {GetPort(true)}"); + sb.AppendLine($"ForceSSL: {GetPort(false)}"); + sb.AppendLine($"AuthType: {AuthType.ToString()}"); + sb.AppendLine($"MaxConcurrentQueries: {MaxConcurrentQueries}"); + if (!string.IsNullOrWhiteSpace(Username)) { + sb.AppendLine($"Username: {Username}"); + } + + if (!string.IsNullOrWhiteSpace(Password)) { + sb.AppendLine($"Password: {new string('*', Password.Length)}"); + } + return sb.ToString(); + } } } \ No newline at end of file diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index 69f44da7..a05d6407 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -1,8 +1,11 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.DirectoryServices.ActiveDirectory; using System.DirectoryServices.Protocols; +using System.Linq; using System.Net; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -13,7 +16,7 @@ using SharpHoundRPC.NetAPINative; namespace SharpHoundCommonLib { - public class LdapConnectionPool : IDisposable{ + internal class LdapConnectionPool : IDisposable{ private readonly ConcurrentBag _connections; private readonly ConcurrentBag _globalCatalogConnection; private readonly SemaphoreSlim _semaphore; @@ -23,11 +26,22 @@ public class LdapConnectionPool : IDisposable{ private readonly ILogger _log; private readonly PortScanner _portScanner; private readonly NativeMethods _nativeMethods; + private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2); + private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20); + private const int BackoffDelayMultiplier = 2; + private const int MaxRetries = 3; + private static readonly ConcurrentDictionary DCInfoCache = new(); - public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config, int maxConnections = 10, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { + public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { _connections = new ConcurrentBag(); _globalCatalogConnection = new ConcurrentBag(); - _semaphore = new SemaphoreSlim(maxConnections, maxConnections); + if (config.MaxConcurrentQueries > 0) { + _semaphore = new SemaphoreSlim(config.MaxConcurrentQueries, config.MaxConcurrentQueries); + } else { + //If MaxConcurrentQueries is 0, we'll just disable the semaphore entirely + _semaphore = null; + } + _identifier = identifier; _poolIdentifier = poolIdentifier; _ldapConfig = config; @@ -35,14 +49,525 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c _portScanner = scanner ?? new PortScanner(); _nativeMethods = nativeMethods ?? new NativeMethods(); } + + private async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(bool globalCatalog) { + if (globalCatalog) { + return await GetGlobalCatalogConnectionAsync(); + } + return await GetConnectionAsync(); + } + + public async IAsyncEnumerable> Query(LdapQueryParameters queryParameters, + [EnumeratorCancellation] CancellationToken cancellationToken = new()) { + var setupResult = await SetupLdapQuery(queryParameters); + + if (!setupResult.Success) { + _log.LogInformation("Query - Failure during query setup: {Reason}\n{Info}", setupResult.Message, + queryParameters.GetQueryInfo()); + yield break; + } + + var searchRequest = setupResult.SearchRequest; + var connectionWrapper = setupResult.ConnectionWrapper; + + if (cancellationToken.IsCancellationRequested) { + ReleaseConnection(connectionWrapper); + yield break; + } + + var queryRetryCount = 0; + var busyRetryCount = 0; + LdapResult tempResult = null; + var querySuccess = false; + SearchResponse response = null; + while (!cancellationToken.IsCancellationRequested) { + //Grab our semaphore here to take one of our query slots + if (_semaphore != null){ + await _semaphore.WaitAsync(cancellationToken); + } + try { + _log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo()); + response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + + if (response != null) { + querySuccess = true; + } else if (queryRetryCount == MaxRetries) { + tempResult = + LdapResult.Fail($"Failed to get a response after {MaxRetries} attempts", + queryParameters); + } else { + queryRetryCount++; + } + } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && + queryRetryCount < MaxRetries) { + /* + * A ServerDown exception indicates that our connection is no longer valid for one of many reasons. + * We'll want to release our connection back to the pool, but dispose it. We need a new connection, + * and because this is not a paged query, we can get this connection from anywhere. + * + * We use queryRetryCount here to prevent an infinite retry loop from occurring + * + * Release our connection in a faulted state since the connection is defunct. Attempt to get a new connection to any server in the domain + * since non-paged queries do not require same server connections + */ + queryRetryCount++; + ReleaseConnection(connectionWrapper, true); + + for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { + var backoffDelay = GetNextBackoff(retryCount); + await Task.Delay(backoffDelay, cancellationToken); + var (success, newConnectionWrapper, _) = + await GetLdapConnection(queryParameters.GlobalCatalog); + if (success) { + _log.LogDebug( + "Query - Recovered from ServerDown successfully, connection made to {NewServer}", + newConnectionWrapper.GetServer()); + connectionWrapper = newConnectionWrapper; + break; + } + + //If we hit our max retries for making a new connection, set tempResult so we can yield it after this logic + if (retryCount == MaxRetries - 1) { + _log.LogError("Query - Failed to get a new connection after ServerDown.\n{Info}", + queryParameters.GetQueryInfo()); + tempResult = + LdapResult.Fail( + "Query - Failed to get a new connection after ServerDown.", queryParameters); + } + } + } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { + /* + * If we get a busy error, we want to do an exponential backoff, but maintain the current connection + * The expectation is that given enough time, the server should stop being busy and service our query appropriately + */ + busyRetryCount++; + var backoffDelay = GetNextBackoff(busyRetryCount); + await Task.Delay(backoffDelay, cancellationToken); + } catch (LdapException le) { + /* + * This is our fallback catch. If our retry counts have been exhausted this will trigger and break us out of our loop + */ + tempResult = LdapResult.Fail( + $"Query - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", + queryParameters); + } catch (Exception e) { + /* + * Generic exception handling for unforeseen circumstances + */ + tempResult = + LdapResult.Fail($"Query - Caught unrecoverable exception: {e.Message}", + queryParameters); + } finally { + // Always release our semaphore to prevent deadlocks + _semaphore?.Release(); + } + + //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function + if (tempResult != null) { + if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { + ReleaseConnection(connectionWrapper, true); + } else { + ReleaseConnection(connectionWrapper); + } + + yield return tempResult; + yield break; + } + + //If we've successfully made our query, break out of the while loop + if (querySuccess) { + break; + } + } + + ReleaseConnection(connectionWrapper); + foreach (SearchResultEntry entry in response.Entries) { + yield return LdapResult.Ok(new SearchResultEntryWrapper(entry)); + } + } + + public async IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, + [EnumeratorCancellation] CancellationToken cancellationToken = new()) { + var setupResult = await SetupLdapQuery(queryParameters); + + if (!setupResult.Success) { + _log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.Message, + queryParameters.GetQueryInfo()); + yield break; + } + + var searchRequest = setupResult.SearchRequest; + var connectionWrapper = setupResult.ConnectionWrapper; + var serverName = setupResult.Server; + + if (serverName == null) { + _log.LogWarning("PagedQuery - Failed to get a server name for connection, retry not possible"); + } + + var pageControl = new PageResultRequestControl(500); + searchRequest.Controls.Add(pageControl); + + PageResultResponseControl pageResponse = null; + var busyRetryCount = 0; + var queryRetryCount = 0; + LdapResult tempResult = null; + + while (!cancellationToken.IsCancellationRequested) { + if (_semaphore != null){ + await _semaphore.WaitAsync(cancellationToken); + } + SearchResponse response = null; + try { + _log.LogTrace("Sending paged ldap request - {Info}", queryParameters.GetQueryInfo()); + response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + if (response != null) { + pageResponse = (PageResultResponseControl)response.Controls + .Where(x => x is PageResultResponseControl).DefaultIfEmpty(null).FirstOrDefault(); + queryRetryCount = 0; + } else if (queryRetryCount == MaxRetries) { + tempResult = LdapResult.Fail( + $"PagedQuery - Failed to get a response after {MaxRetries} attempts", + queryParameters); + } else { + queryRetryCount++; + } + } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) { + /* + * A ServerDown exception indicates that our connection is no longer valid for one of many reasons. + * We'll want to release our connection back to the pool, but dispose it. We need a new connection, + * and because this is not a paged query, we can get this connection from anywhere. + * + * We use queryRetryCount here to prevent an infinite retry loop from occurring + * + * Release our connection in a faulted state since the connection is defunct. + * Paged queries require a connection to be made to the same server which we started the paged query on + */ + if (serverName == null) { + _log.LogError( + "PagedQuery - Received server down exception without a known servername. Unable to generate new connection\n{Info}", + queryParameters.GetQueryInfo()); + ReleaseConnection(connectionWrapper, true); + yield break; + } + + ReleaseConnection(connectionWrapper, true); + for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { + var backoffDelay = GetNextBackoff(retryCount); + await Task.Delay(backoffDelay, cancellationToken); + var (success, ldapConnectionWrapperNew, _) = + await GetConnectionForSpecificServerAsync(serverName, queryParameters.GlobalCatalog); + + if (success) { + _log.LogDebug("PagedQuery - Recovered from ServerDown successfully"); + connectionWrapper = ldapConnectionWrapperNew; + break; + } + + if (retryCount == MaxRetries - 1) { + _log.LogError("PagedQuery - Failed to get a new connection after ServerDown.\n{Info}", + queryParameters.GetQueryInfo()); + tempResult = + LdapResult.Fail("Failed to get a new connection after serverdown", + queryParameters, le.ErrorCode); + } + } + } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { + /* + * If we get a busy error, we want to do an exponential backoff, but maintain the current connection + * The expectation is that given enough time, the server should stop being busy and service our query appropriately + */ + busyRetryCount++; + var backoffDelay = GetNextBackoff(busyRetryCount); + await Task.Delay(backoffDelay, cancellationToken); + } catch (LdapException le) { + tempResult = LdapResult.Fail( + $"PagedQuery - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", + queryParameters, le.ErrorCode); + } catch (Exception e) { + tempResult = + LdapResult.Fail($"PagedQuery - Caught unrecoverable exception: {e.Message}", + queryParameters); + } finally { + _semaphore?.Release(); + } + + if (tempResult != null) { + if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { + ReleaseConnection(connectionWrapper, true); + } else { + ReleaseConnection(connectionWrapper); + } + + yield return tempResult; + yield break; + } + + if (cancellationToken.IsCancellationRequested) { + ReleaseConnection(connectionWrapper); + yield break; + } + + //I'm not sure why this happens sometimes, but if we try the request again, it works sometimes, other times we get an exception + if (response == null || pageResponse == null) { + continue; + } + + foreach (SearchResultEntry entry in response.Entries) { + if (cancellationToken.IsCancellationRequested) { + ReleaseConnection(connectionWrapper); + yield break; + } + + yield return LdapResult.Ok(new SearchResultEntryWrapper(entry)); + } + + if (pageResponse.Cookie.Length == 0 || response.Entries.Count == 0 || + cancellationToken.IsCancellationRequested) { + ReleaseConnection(connectionWrapper); + yield break; + } + + pageControl.Cookie = pageResponse.Cookie; + } + } + + private async Task SetupLdapQuery(LdapQueryParameters queryParameters) { + var result = new LdapQuerySetupResult(); + var (success, connectionWrapper, message) = + await GetLdapConnection(queryParameters.GlobalCatalog); + if (!success) { + result.Success = false; + result.Message = $"Unable to create a connection: {message}"; + return result; + } + + //This should never happen as far as I know, so just checking for safety + if (connectionWrapper.Connection == null) { + result.Success = false; + result.Message = "Connection object is null"; + return result; + } + + if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) { + result.Success = false; + result.Message = "Failed to create search request"; + ReleaseConnection(connectionWrapper); + return result; + } + + result.Server = connectionWrapper.GetServer(); + result.Success = true; + result.SearchRequest = searchRequest; + result.ConnectionWrapper = connectionWrapper; + return result; + } + + public async IAsyncEnumerable> RangedRetrieval(string distinguishedName, + string attributeName, [EnumeratorCancellation] CancellationToken cancellationToken = new()) { + var domain = Helpers.DistinguishedNameToDomain(distinguishedName); + + var connectionResult = await GetConnectionAsync(); + if (!connectionResult.Success) { + yield return Result.Fail(connectionResult.Message); + yield break; + } + + var index = 0; + var step = 0; + + //Start by using * as our upper index, which will automatically give us the range size + var currentRange = $"{attributeName};range={index}-*"; + var complete = false; + + var queryParameters = new LdapQueryParameters { + DomainName = domain, + LDAPFilter = $"{attributeName}=*", + Attributes = new[] { currentRange }, + SearchScope = SearchScope.Base, + SearchBase = distinguishedName + }; + var connectionWrapper = connectionResult.ConnectionWrapper; + + if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) { + ReleaseConnection(connectionWrapper); + yield return Result.Fail("Failed to create search request"); + yield break; + } + + var queryRetryCount = 0; + var busyRetryCount = 0; + + LdapResult tempResult = null; + + while (!cancellationToken.IsCancellationRequested) { + SearchResponse response = null; + if (_semaphore != null){ + await _semaphore.WaitAsync(cancellationToken); + } + try { + response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { + busyRetryCount++; + var backoffDelay = GetNextBackoff(busyRetryCount); + await Task.Delay(backoffDelay, cancellationToken); + } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && + queryRetryCount < MaxRetries) { + queryRetryCount++; + ReleaseConnection(connectionWrapper, true); + for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { + var backoffDelay = GetNextBackoff(retryCount); + await Task.Delay(backoffDelay, cancellationToken); + var (success, newConnectionWrapper, message) = + await GetLdapConnection(false); + if (success) { + _log.LogDebug( + "RangedRetrieval - Recovered from ServerDown successfully, connection made to {NewServer}", + newConnectionWrapper.GetServer()); + connectionWrapper = newConnectionWrapper; + break; + } + + //If we hit our max retries for making a new connection, set tempResult so we can yield it after this logic + if (retryCount == MaxRetries - 1) { + _log.LogError( + "RangedRetrieval - Failed to get a new connection after ServerDown for path {Path}", + distinguishedName); + tempResult = + LdapResult.Fail( + "RangedRetrieval - Failed to get a new connection after ServerDown.", + queryParameters, le.ErrorCode); + } + } + } catch (LdapException le) { + tempResult = LdapResult.Fail( + $"Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", + queryParameters, le.ErrorCode); + } catch (Exception e) { + tempResult = + LdapResult.Fail($"Caught unrecoverable exception: {e.Message}", queryParameters); + } finally { + _semaphore?.Release(); + } + + //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function + //We handle connection release in the relevant exception blocks + if (tempResult != null) { + if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { + ReleaseConnection(connectionWrapper, true); + } else { + ReleaseConnection(connectionWrapper); + } + + yield return tempResult; + yield break; + } + + if (response?.Entries.Count == 1) { + var entry = response.Entries[0]; + //We dont know the name of our attribute, but there should only be one, so we're safe to just use a loop here + foreach (string attr in entry.Attributes.AttributeNames) { + currentRange = attr; + complete = currentRange.IndexOf("*", 0, StringComparison.OrdinalIgnoreCase) > 0; + step = entry.Attributes[currentRange].Count; + } + + foreach (string dn in entry.Attributes[currentRange].GetValues(typeof(string))) { + yield return Result.Ok(dn); + index++; + } + + if (complete) { + ReleaseConnection(connectionWrapper); + yield break; + } + + currentRange = $"{attributeName};range={index}-{index + step}"; + searchRequest.Attributes.Clear(); + searchRequest.Attributes.Add(currentRange); + } else { + //I dont know what can cause a RR to have multiple entries, but its nothing good. Break out + ReleaseConnection(connectionWrapper); + yield break; + } + } + + ReleaseConnection(connectionWrapper); + } + + private static TimeSpan GetNextBackoff(int retryCount) { + return TimeSpan.FromSeconds(Math.Min( + MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount), + MaxBackoffDelay.TotalSeconds)); + } + + private bool CreateSearchRequest(LdapQueryParameters queryParameters, + LdapConnectionWrapper connectionWrapper, out SearchRequest searchRequest) { + string basePath; + if (!string.IsNullOrWhiteSpace(queryParameters.SearchBase)) { + basePath = queryParameters.SearchBase; + } else if (!connectionWrapper.GetSearchBase(queryParameters.NamingContext, out basePath)) { + string tempPath; + if (CallDsGetDcName(queryParameters.DomainName, out var info) && info != null) { + tempPath = Helpers.DomainNameToDistinguishedName(info.Value.DomainName); + connectionWrapper.SaveContext(queryParameters.NamingContext, basePath); + } else if (LdapUtils.GetDomain(queryParameters.DomainName,_ldapConfig, out var domainObject)) { + tempPath = Helpers.DomainNameToDistinguishedName(domainObject.Name); + } else { + searchRequest = null; + return false; + } + + basePath = queryParameters.NamingContext switch { + NamingContext.Configuration => $"CN=Configuration,{tempPath}", + NamingContext.Schema => $"CN=Schema,CN=Configuration,{tempPath}", + NamingContext.Default => tempPath, + _ => throw new ArgumentOutOfRangeException() + }; + + connectionWrapper.SaveContext(queryParameters.NamingContext, basePath); + } + + if (string.IsNullOrWhiteSpace(queryParameters.SearchBase) && !string.IsNullOrWhiteSpace(queryParameters.RelativeSearchBase)) { + basePath = $"{queryParameters.RelativeSearchBase},{basePath}"; + } + + searchRequest = new SearchRequest(basePath, queryParameters.LDAPFilter, queryParameters.SearchScope, + queryParameters.Attributes); + searchRequest.Controls.Add(new SearchOptionsControl(SearchOption.DomainScope)); + if (queryParameters.IncludeDeleted) { + searchRequest.Controls.Add(new ShowDeletedControl()); + } + + if (queryParameters.IncludeSecurityDescriptor) { + searchRequest.Controls.Add(new SecurityDescriptorFlagControl { + SecurityMasks = SecurityMasks.Dacl | SecurityMasks.Owner + }); + } + + return true; + } + + private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControllerInfo? info) { + if (DCInfoCache.TryGetValue(domainName.ToUpper().Trim(), out info)) return info != null; + + var apiResult = _nativeMethods.CallDsGetDcName(null, domainName, + (uint)(NetAPIEnums.DSGETDCNAME_FLAGS.DS_FORCE_REDISCOVERY | + NetAPIEnums.DSGETDCNAME_FLAGS.DS_RETURN_DNS_NAME | + NetAPIEnums.DSGETDCNAME_FLAGS.DS_DIRECTORY_SERVICE_REQUIRED)); + + if (apiResult.IsFailed) { + DCInfoCache.TryAdd(domainName.ToUpper().Trim(), null); + return false; + } + + info = apiResult.Value; + return true; + } public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetConnectionAsync() { - await _semaphore.WaitAsync(); if (!_connections.TryTake(out var connectionWrapper)) { var (success, connection, message) = await CreateNewConnection(); if (!success) { - //If we didn't get a connection, immediately release the semaphore so we don't have hanging ones - _semaphore.Release(); return (false, null, message); } @@ -54,24 +579,14 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c public async Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> GetConnectionForSpecificServerAsync(string server, bool globalCatalog) { - await _semaphore.WaitAsync(); - - var result= CreateNewConnectionForServer(server, globalCatalog); - if (!result.Success) { - //If we didn't get a connection, immediately release the semaphore so we don't have hanging ones - _semaphore.Release(); - } - - return result; + return CreateNewConnectionForServer(server, globalCatalog); } public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetGlobalCatalogConnectionAsync() { - await _semaphore.WaitAsync(); if (!_globalCatalogConnection.TryTake(out var connectionWrapper)) { var (success, connection, message) = await CreateNewConnection(true); if (!success) { //If we didn't get a connection, immediately release the semaphore so we don't have hanging ones - _semaphore.Release(); return (false, null, message); } @@ -82,7 +597,6 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c } public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) { - _semaphore.Release(); if (!connectionFaulted) { if (connectionWrapper.GlobalCatalog) { _globalCatalogConnection.Add(connectionWrapper); diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index 57bdf4b9..6f4025d1 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -29,7 +29,6 @@ namespace SharpHoundCommonLib { public class LdapUtils : ILdapUtils { //This cache is indexed by domain sid - private readonly ConcurrentDictionary _dcInfoCache = new(); private static readonly ConcurrentDictionary DomainCache = new(); private static readonly ConcurrentDictionary DomainControllers = new(); @@ -90,385 +89,18 @@ public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); } - public async IAsyncEnumerable> RangedRetrieval(string distinguishedName, - string attributeName, [EnumeratorCancellation] CancellationToken cancellationToken = new()) { - var domain = Helpers.DistinguishedNameToDomain(distinguishedName); - - var connectionResult = await _connectionPool.GetLdapConnection(domain, false); - if (!connectionResult.Success) { - yield return Result.Fail(connectionResult.Message); - yield break; - } - - var index = 0; - var step = 0; - - //Start by using * as our upper index, which will automatically give us the range size - var currentRange = $"{attributeName};range={index}-*"; - var complete = false; - - var queryParameters = new LdapQueryParameters { - DomainName = domain, - LDAPFilter = $"{attributeName}=*", - Attributes = new[] { currentRange }, - SearchScope = SearchScope.Base, - SearchBase = distinguishedName - }; - var connectionWrapper = connectionResult.ConnectionWrapper; - - if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield return Result.Fail("Failed to create search request"); - yield break; - } - - var queryRetryCount = 0; - var busyRetryCount = 0; - - LdapResult tempResult = null; - - while (!cancellationToken.IsCancellationRequested) { - SearchResponse response = null; - try { - response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); - } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { - busyRetryCount++; - var backoffDelay = GetNextBackoff(busyRetryCount); - await Task.Delay(backoffDelay, cancellationToken); - } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && - queryRetryCount < MaxRetries) { - queryRetryCount++; - _connectionPool.ReleaseConnection(connectionWrapper, true); - for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { - var backoffDelay = GetNextBackoff(retryCount); - await Task.Delay(backoffDelay, cancellationToken); - var (success, newConnectionWrapper, message) = - await _connectionPool.GetLdapConnection(domain, - false); - if (success) { - _log.LogDebug( - "RangedRetrieval - Recovered from ServerDown successfully, connection made to {NewServer}", - newConnectionWrapper.GetServer()); - connectionWrapper = newConnectionWrapper; - break; - } - - //If we hit our max retries for making a new connection, set tempResult so we can yield it after this logic - if (retryCount == MaxRetries - 1) { - _log.LogError( - "RangedRetrieval - Failed to get a new connection after ServerDown for path {Path}", - distinguishedName); - tempResult = - LdapResult.Fail( - "RangedRetrieval - Failed to get a new connection after ServerDown.", - queryParameters, le.ErrorCode); - } - } - } catch (LdapException le) { - tempResult = LdapResult.Fail( - $"Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", - queryParameters, le.ErrorCode); - } catch (Exception e) { - tempResult = - LdapResult.Fail($"Caught unrecoverable exception: {e.Message}", queryParameters); - } - - //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function - //We handle connection release in the relevant exception blocks - if (tempResult != null) { - if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { - _connectionPool.ReleaseConnection(connectionWrapper, true); - } else { - _connectionPool.ReleaseConnection(connectionWrapper); - } - - yield return tempResult; - yield break; - } - - if (response?.Entries.Count == 1) { - var entry = response.Entries[0]; - //We dont know the name of our attribute, but there should only be one, so we're safe to just use a loop here - foreach (string attr in entry.Attributes.AttributeNames) { - currentRange = attr; - complete = currentRange.IndexOf("*", 0, StringComparison.OrdinalIgnoreCase) > 0; - step = entry.Attributes[currentRange].Count; - } - - foreach (string dn in entry.Attributes[currentRange].GetValues(typeof(string))) { - yield return Result.Ok(dn); - index++; - } - - if (complete) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } - - currentRange = $"{attributeName};range={index}-{index + step}"; - searchRequest.Attributes.Clear(); - searchRequest.Attributes.Add(currentRange); - } else { - //I dont know what can cause a RR to have multiple entries, but its nothing good. Break out - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } - } - - _connectionPool.ReleaseConnection(connectionWrapper); + public IAsyncEnumerable> RangedRetrieval(string distinguishedName, + string attributeName, CancellationToken cancellationToken = new()) { + return _connectionPool.RangedRetrieval(distinguishedName, attributeName, cancellationToken); } - public async IAsyncEnumerable> Query(LdapQueryParameters queryParameters, - [EnumeratorCancellation] CancellationToken cancellationToken = new()) { - var setupResult = await SetupLdapQuery(queryParameters); - - if (!setupResult.Success) { - _log.LogInformation("Query - Failure during query setup: {Reason}\n{Info}", setupResult.Message, - queryParameters.GetQueryInfo()); - yield break; - } - - var searchRequest = setupResult.SearchRequest; - var connectionWrapper = setupResult.ConnectionWrapper; - - if (cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } - - var queryRetryCount = 0; - var busyRetryCount = 0; - LdapResult tempResult = null; - var querySuccess = false; - SearchResponse response = null; - while (!cancellationToken.IsCancellationRequested) { - try { - _log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo()); - response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); - - if (response != null) { - querySuccess = true; - } else if (queryRetryCount == MaxRetries) { - tempResult = - LdapResult.Fail($"Failed to get a response after {MaxRetries} attempts", - queryParameters); - } else { - queryRetryCount++; - continue; - } - } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && - queryRetryCount < MaxRetries) { - /* - * A ServerDown exception indicates that our connection is no longer valid for one of many reasons. - * We'll want to release our connection back to the pool, but dispose it. We need a new connection, - * and because this is not a paged query, we can get this connection from anywhere. - */ - - //Increment our query retry count - queryRetryCount++; - _connectionPool.ReleaseConnection(connectionWrapper, true); - - for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { - var backoffDelay = GetNextBackoff(retryCount); - await Task.Delay(backoffDelay, cancellationToken); - var (success, newConnectionWrapper, message) = - await _connectionPool.GetLdapConnection(queryParameters.DomainName, - queryParameters.GlobalCatalog); - if (success) { - _log.LogDebug( - "Query - Recovered from ServerDown successfully, connection made to {NewServer}", - newConnectionWrapper.GetServer()); - connectionWrapper = newConnectionWrapper; - break; - } - - //If we hit our max retries for making a new connection, set tempResult so we can yield it after this logic - if (retryCount == MaxRetries - 1) { - _log.LogError("Query - Failed to get a new connection after ServerDown.\n{Info}", - queryParameters.GetQueryInfo()); - tempResult = - LdapResult.Fail( - "Query - Failed to get a new connection after ServerDown.", queryParameters); - } - } - } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { - /* - * If we get a busy error, we want to do an exponential backoff, but maintain the current connection - * The expectation is that given enough time, the server should stop being busy and service our query appropriately - */ - busyRetryCount++; - var backoffDelay = GetNextBackoff(busyRetryCount); - await Task.Delay(backoffDelay, cancellationToken); - } catch (LdapException le) { - tempResult = LdapResult.Fail( - $"Query - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", - queryParameters); - } catch (Exception e) { - tempResult = - LdapResult.Fail($"Query - Caught unrecoverable exception: {e.Message}", - queryParameters); - } - - //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function - if (tempResult != null) { - if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { - _connectionPool.ReleaseConnection(connectionWrapper, true); - } else { - _connectionPool.ReleaseConnection(connectionWrapper); - } - - yield return tempResult; - yield break; - } - - //If we've successfully made our query, break out of the while loop - if (querySuccess) { - break; - } - } - - _connectionPool.ReleaseConnection(connectionWrapper); - foreach (SearchResultEntry entry in response.Entries) { - yield return LdapResult.Ok(new SearchResultEntryWrapper(entry)); - } + public IAsyncEnumerable> Query(LdapQueryParameters queryParameters, + CancellationToken cancellationToken = new()) { + return _connectionPool.Query(queryParameters, cancellationToken); } - public async IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, - [EnumeratorCancellation] CancellationToken cancellationToken = new()) { - var setupResult = await SetupLdapQuery(queryParameters); - - if (!setupResult.Success) { - _log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.Message, - queryParameters.GetQueryInfo()); - yield break; - } - - var searchRequest = setupResult.SearchRequest; - var connectionWrapper = setupResult.ConnectionWrapper; - var serverName = setupResult.Server; - - if (serverName == null) { - _log.LogWarning("PagedQuery - Failed to get a server name for connection, retry not possible"); - } - - var pageControl = new PageResultRequestControl(500); - searchRequest.Controls.Add(pageControl); - - PageResultResponseControl pageResponse = null; - var busyRetryCount = 0; - var queryRetryCount = 0; - LdapResult tempResult = null; - - while (!cancellationToken.IsCancellationRequested) { - SearchResponse response = null; - try { - _log.LogTrace("Sending paged ldap request - {Info}", queryParameters.GetQueryInfo()); - response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); - if (response != null) { - pageResponse = (PageResultResponseControl)response.Controls - .Where(x => x is PageResultResponseControl).DefaultIfEmpty(null).FirstOrDefault(); - queryRetryCount = 0; - } else if (queryRetryCount == MaxRetries) { - tempResult = LdapResult.Fail( - $"PagedQuery - Failed to get a response after {MaxRetries} attempts", - queryParameters); - } else { - queryRetryCount++; - } - } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) { - /* - * If we dont have a servername, we're not going to be able to re-establish a connection here. Page cookies are only valid for the server they were generated on. Bail out. - */ - if (serverName == null) { - _log.LogError( - "PagedQuery - Received server down exception without a known servername. Unable to generate new connection\n{Info}", - queryParameters.GetQueryInfo()); - _connectionPool.ReleaseConnection(connectionWrapper, true); - yield break; - } - - /* - * Paged queries will not use the cached ldap connections, as the intention is to only have 1 or a couple of these queries running at once. - * The connection logic here is simplified accordingly - */ - _connectionPool.ReleaseConnection(connectionWrapper, true); - for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { - var backoffDelay = GetNextBackoff(retryCount); - await Task.Delay(backoffDelay, cancellationToken); - var (success, ldapConnectionWrapperNew, message) = - await _connectionPool.GetLdapConnectionForServer( - queryParameters.DomainName, serverName, queryParameters.GlobalCatalog); - - if (success) { - _log.LogDebug("PagedQuery - Recovered from ServerDown successfully"); - connectionWrapper = ldapConnectionWrapperNew; - break; - } - - if (retryCount == MaxRetries - 1) { - _log.LogError("PagedQuery - Failed to get a new connection after ServerDown.\n{Info}", - queryParameters.GetQueryInfo()); - tempResult = - LdapResult.Fail("Failed to get a new connection after serverdown", - queryParameters, le.ErrorCode); - } - } - } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { - /* - * If we get a busy error, we want to do an exponential backoff, but maintain the current connection - * The expectation is that given enough time, the server should stop being busy and service our query appropriately - */ - busyRetryCount++; - var backoffDelay = GetNextBackoff(busyRetryCount); - await Task.Delay(backoffDelay, cancellationToken); - } catch (LdapException le) { - tempResult = LdapResult.Fail( - $"PagedQuery - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", - queryParameters, le.ErrorCode); - } catch (Exception e) { - tempResult = - LdapResult.Fail($"PagedQuery - Caught unrecoverable exception: {e.Message}", - queryParameters); - } - - if (tempResult != null) { - if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { - _connectionPool.ReleaseConnection(connectionWrapper, true); - } else { - _connectionPool.ReleaseConnection(connectionWrapper); - } - - yield return tempResult; - yield break; - } - - if (cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } - - //I'm not sure why this happens sometimes, but if we try the request again, it works sometimes, other times we get an exception - if (response == null || pageResponse == null) { - continue; - } - - foreach (SearchResultEntry entry in response.Entries) { - if (cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } - - yield return LdapResult.Ok(new SearchResultEntryWrapper(entry)); - } - - if (pageResponse.Cookie.Length == 0 || response.Entries.Count == 0 || - cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } - - pageControl.Cookie = pageResponse.Cookie; - } + public IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, CancellationToken cancellationToken = new()) { + return _connectionPool.PagedQuery(queryParameters, cancellationToken); } public async Task<(bool Success, TypedPrincipal Principal)> ResolveIDAndType( @@ -680,110 +312,6 @@ private static TimeSpan GetNextBackoff(int retryCount) { MaxBackoffDelay.TotalSeconds)); } - private bool CreateSearchRequest(LdapQueryParameters queryParameters, - LdapConnectionWrapper connectionWrapper, out SearchRequest searchRequest) { - string basePath; - if (!string.IsNullOrWhiteSpace(queryParameters.SearchBase)) { - basePath = queryParameters.SearchBase; - } else if (!connectionWrapper.GetSearchBase(queryParameters.NamingContext, out basePath)) { - string tempPath; - if (CallDsGetDcName(queryParameters.DomainName, out var info) && info != null) { - tempPath = Helpers.DomainNameToDistinguishedName(info.Value.DomainName); - connectionWrapper.SaveContext(queryParameters.NamingContext, basePath); - } else if (GetDomain(queryParameters.DomainName, out var domainObject)) { - tempPath = Helpers.DomainNameToDistinguishedName(domainObject.Name); - } else { - searchRequest = null; - return false; - } - - basePath = queryParameters.NamingContext switch { - NamingContext.Configuration => $"CN=Configuration,{tempPath}", - NamingContext.Schema => $"CN=Schema,CN=Configuration,{tempPath}", - NamingContext.Default => tempPath, - _ => throw new ArgumentOutOfRangeException() - }; - - connectionWrapper.SaveContext(queryParameters.NamingContext, basePath); - } - - if (string.IsNullOrWhiteSpace(queryParameters.SearchBase) && !string.IsNullOrWhiteSpace(queryParameters.RelativeSearchBase)) { - basePath = $"{queryParameters.RelativeSearchBase},{basePath}"; - } - - searchRequest = new SearchRequest(basePath, queryParameters.LDAPFilter, queryParameters.SearchScope, - queryParameters.Attributes); - searchRequest.Controls.Add(new SearchOptionsControl(SearchOption.DomainScope)); - if (queryParameters.IncludeDeleted) { - searchRequest.Controls.Add(new ShowDeletedControl()); - } - - if (queryParameters.IncludeSecurityDescriptor) { - searchRequest.Controls.Add(new SecurityDescriptorFlagControl { - SecurityMasks = SecurityMasks.Dacl | SecurityMasks.Owner - }); - } - - return true; - } - - private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControllerInfo? info) { - if (_dcInfoCache.TryGetValue(domainName.ToUpper().Trim(), out info)) return info != null; - - var apiResult = _nativeMethods.CallDsGetDcName(null, domainName, - (uint)(NetAPIEnums.DSGETDCNAME_FLAGS.DS_FORCE_REDISCOVERY | - NetAPIEnums.DSGETDCNAME_FLAGS.DS_RETURN_DNS_NAME | - NetAPIEnums.DSGETDCNAME_FLAGS.DS_DIRECTORY_SERVICE_REQUIRED)); - - if (apiResult.IsFailed) { - _dcInfoCache.TryAdd(domainName.ToUpper().Trim(), null); - return false; - } - - info = apiResult.Value; - return true; - } - - private async Task SetupLdapQuery(LdapQueryParameters queryParameters) { - var result = new LdapQuerySetupResult(); - var (success, connectionWrapper, message) = - await _connectionPool.GetLdapConnection(queryParameters.DomainName, queryParameters.GlobalCatalog); - if (!success) { - result.Success = false; - result.Message = $"Unable to create a connection: {message}"; - return result; - } - - //This should never happen as far as I know, so just checking for safety - if (connectionWrapper.Connection == null) { - result.Success = false; - result.Message = "Connection object is null"; - return result; - } - - if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) { - result.Success = false; - result.Message = "Failed to create search request"; - _connectionPool.ReleaseConnection(connectionWrapper); - return result; - } - - result.Server = connectionWrapper.GetServer(); - result.Success = true; - result.SearchRequest = searchRequest; - result.ConnectionWrapper = connectionWrapper; - return result; - } - - private SearchRequest CreateSearchRequest(string distinguishedName, string ldapFilter, - SearchScope searchScope, - string[] attributes) { - var searchRequest = new SearchRequest(distinguishedName, ldapFilter, - searchScope, attributes); - searchRequest.Controls.Add(new SearchOptionsControl(SearchOption.DomainScope)); - return searchRequest; - } - public async Task<(bool Success, string DomainName)> GetDomainNameFromSid(string sid) { string domainSid; try { @@ -1431,6 +959,7 @@ await GetDomainSidFromDomainName(forestName) is (true, var forestDomainSid)) { public void SetLdapConfig(LdapConfig config) { _ldapConfig = config; + _log.LogInformation("New LDAP Config Set:\n {ConfigString}", config.ToString()); _connectionPool.Dispose(); _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); } diff --git a/src/CommonLib/Processors/ComputerSessionProcessor.cs b/src/CommonLib/Processors/ComputerSessionProcessor.cs index e0db075e..71034549 100644 --- a/src/CommonLib/Processors/ComputerSessionProcessor.cs +++ b/src/CommonLib/Processors/ComputerSessionProcessor.cs @@ -44,32 +44,54 @@ public ComputerSessionProcessor(ILdapUtils utils, string currentUserName = null, /// /// /// + /// /// public async Task ReadUserSessions(string computerName, string computerSid, - string computerDomain) { + string computerDomain, TimeSpan timeout = default) { + + if (timeout == default) { + timeout = TimeSpan.FromMinutes(2); + } var ret = new SessionAPIResult(); - NetAPIResult> result; _log.LogDebug("Running NetSessionEnum for {ObjectName}", computerName); - if (_doLocalAdminSessionEnum) { - // If we are authenticating using a local admin, we need to impersonate for this - using (new Impersonator(_localAdminUsername, ".", _localAdminPassword, - LogonType.LOGON32_LOGON_NEW_CREDENTIALS, LogonProvider.LOGON32_PROVIDER_WINNT50)) { - result = _nativeMethods.NetSessionEnum(computerName); - } + var apiTask = Task.Run(() => { + NetAPIResult> result; + if (_doLocalAdminSessionEnum) { + // If we are authenticating using a local admin, we need to impersonate for this + using (new Impersonator(_localAdminUsername, ".", _localAdminPassword, + LogonType.LOGON32_LOGON_NEW_CREDENTIALS, LogonProvider.LOGON32_PROVIDER_WINNT50)) { + result = _nativeMethods.NetSessionEnum(computerName); + } - if (result.IsFailed) { - // Fall back to default User - _log.LogDebug( - "NetSessionEnum failed on {ComputerName} with local admin credentials: {Status}. Fallback to default user.", - computerName, result.Status); + if (result.IsFailed) { + // Fall back to default User + _log.LogDebug( + "NetSessionEnum failed on {ComputerName} with local admin credentials: {Status}. Fallback to default user.", + computerName, result.Status); + result = _nativeMethods.NetSessionEnum(computerName); + } + } else { result = _nativeMethods.NetSessionEnum(computerName); } - } else { - result = _nativeMethods.NetSessionEnum(computerName); + + return result; + }); + + if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) { + await SendComputerStatus(new CSVComputerStatus { + Status = "Timeout", + Task = "NetSessionEnum", + ComputerName = computerName + }); + ret.Collected = false; + ret.FailureReason = "Timeout"; + return ret; } + var result = apiTask.Result; + if (result.IsFailed) { await SendComputerStatus(new CSVComputerStatus { Status = result.Status.ToString(), @@ -153,33 +175,54 @@ await SendComputerStatus(new CSVComputerStatus { /// /// /// + /// /// public async Task ReadUserSessionsPrivileged(string computerName, - string computerSamAccountName, string computerSid) { + string computerSamAccountName, string computerSid, TimeSpan timeout = default) { var ret = new SessionAPIResult(); - NetAPIResult> - result; + if (timeout == default) { + timeout = TimeSpan.FromMinutes(2); + } _log.LogDebug("Running NetWkstaUserEnum for {ObjectName}", computerName); - if (_doLocalAdminSessionEnum) { - // If we are authenticating using a local admin, we need to impersonate for this - using (new Impersonator(_localAdminUsername, ".", _localAdminPassword, - LogonType.LOGON32_LOGON_NEW_CREDENTIALS, LogonProvider.LOGON32_PROVIDER_WINNT50)) { - result = _nativeMethods.NetWkstaUserEnum(computerName); - } + var apiTask = Task.Run(() => { + NetAPIResult> + result; + if (_doLocalAdminSessionEnum) { + // If we are authenticating using a local admin, we need to impersonate for this + using (new Impersonator(_localAdminUsername, ".", _localAdminPassword, + LogonType.LOGON32_LOGON_NEW_CREDENTIALS, LogonProvider.LOGON32_PROVIDER_WINNT50)) { + result = _nativeMethods.NetWkstaUserEnum(computerName); + } - if (result.IsFailed) { - // Fall back to default User - _log.LogDebug( - "NetWkstaUserEnum failed on {ComputerName} with local admin credentials: {Status}. Fallback to default user.", - computerName, result.Status); + if (result.IsFailed) { + // Fall back to default User + _log.LogDebug( + "NetWkstaUserEnum failed on {ComputerName} with local admin credentials: {Status}. Fallback to default user.", + computerName, result.Status); + result = _nativeMethods.NetWkstaUserEnum(computerName); + } + } else { result = _nativeMethods.NetWkstaUserEnum(computerName); } - } else { - result = _nativeMethods.NetWkstaUserEnum(computerName); + + return result; + }); + + if (await Task.WhenAny(Task.Delay(timeout), apiTask) != apiTask) { + await SendComputerStatus(new CSVComputerStatus { + Status = "Timeout", + Task = "NetWkstaUserEnum", + ComputerName = computerName + }); + ret.Collected = false; + ret.FailureReason = "Timeout"; + return ret; } + var result = apiTask.Result; + if (result.IsFailed) { await SendComputerStatus(new CSVComputerStatus { Status = result.Status.ToString(), diff --git a/src/CommonLib/Processors/PortScanner.cs b/src/CommonLib/Processors/PortScanner.cs index 5f66ca02..e075d986 100644 --- a/src/CommonLib/Processors/PortScanner.cs +++ b/src/CommonLib/Processors/PortScanner.cs @@ -28,7 +28,7 @@ public PortScanner(ILogger log = null) /// /// Timeout in milliseconds /// True if port is open, otherwise false - public virtual async Task CheckPort(string hostname, int port = 445, int timeout = 500) + public virtual async Task CheckPort(string hostname, int port = 445, int timeout = 10000) { var key = new PingCacheKey {