diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index c8461f61..cc92b616 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -24,8 +24,10 @@ using SearchScope = System.DirectoryServices.Protocols.SearchScope; using SecurityMasks = System.DirectoryServices.Protocols.SecurityMasks; -namespace SharpHoundCommonLib { - public class LdapUtils : ILdapUtils { +namespace SharpHoundCommonLib +{ + public class LdapUtils : ILdapUtils + { //This cache is indexed by domain sid private readonly ConcurrentDictionary _dcInfoCache = new(); private static readonly ConcurrentDictionary DomainCache = new(); @@ -68,435 +70,382 @@ private static readonly ConcurrentDictionary 0x00, 0x01 }; - private class ResolvedWellKnownPrincipal { + private class ResolvedWellKnownPrincipal + { public string DomainName { get; set; } public string WkpId { get; set; } } - public LdapUtils() { + public LdapUtils() + { _nativeMethods = new NativeMethods(); _portScanner = new PortScanner(); _log = Logging.LogProvider.CreateLogger("LDAPUtils"); _connectionPool = new ConnectionPoolManager(_ldapConfig, _log); } - public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, ILogger log = null) { + public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, ILogger log = null) + { _nativeMethods = nativeMethods ?? new NativeMethods(); _portScanner = scanner ?? new PortScanner(); _log = log ?? Logging.LogProvider.CreateLogger("LDAPUtils"); - _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner:_portScanner); + _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); } - public async IAsyncEnumerable> RangedRetrieval(string distinguishedName, - string attributeName, [EnumeratorCancellation] CancellationToken cancellationToken = new()) { + public async IAsyncEnumerable> RangedRetrieval( + string distinguishedName, + string attributeName, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { var domain = Helpers.DistinguishedNameToDomain(distinguishedName); - - var connectionResult = await _connectionPool.GetLdapConnection(domain, false); - if (!connectionResult.Success) { + var connectionResult = await _connectionPool.GetLdapConnection(domain, globalCatalog: false); + if (!connectionResult.Success) + { yield return Result.Fail(connectionResult.Message); yield break; } - var index = 0; - var step = 0; + var connectionWrapper = connectionResult.ConnectionWrapper; + var queryParameters = CreateQueryParameters(domain, attributeName, distinguishedName); + if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) + { + _connectionPool.ReleaseConnection(connectionWrapper); + yield return Result.Fail("Failed to create search request"); + yield break; + } + + await foreach (var result in ExecuteRangedRetrieval(connectionWrapper, searchRequest, queryParameters, cancellationToken)) + { + yield return result; + } - //Start by using * as our upper index, which will automatically give us the range size - var currentRange = $"{attributeName};range={index}-*"; - var complete = false; + _connectionPool.ReleaseConnection(connectionWrapper); + } - var queryParameters = new LdapQueryParameters { + private LdapQueryParameters CreateQueryParameters(string domain, string attributeName, string distinguishedName) + { + return new LdapQueryParameters + { DomainName = domain, LDAPFilter = $"{attributeName}=*", - Attributes = new[] { currentRange }, + //Start by using * as our upper index, which will automatically give us the range size + Attributes = new[] { $"{attributeName};range=0-*" }, SearchScope = SearchScope.Base, SearchBase = distinguishedName }; - var connectionWrapper = connectionResult.ConnectionWrapper; + } - if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield return Result.Fail("Failed to create search request"); - yield break; - } + private async IAsyncEnumerable> ExecuteRangedRetrieval( + LdapConnectionWrapper connectionWrapper, + SearchRequest searchRequest, + LdapQueryParameters queryParameters, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = 0; + int step = 0; + bool complete = false; - var queryRetryCount = 0; - var busyRetryCount = 0; + while (!cancellationToken.IsCancellationRequested && !complete) + { + var response = await ExecuteSearchRequest(connectionWrapper, searchRequest, queryParameters, cancellationToken); + if (!response.IsSuccess) + { + yield return Result.Fail(response.Error); + yield break; + } - LdapResult tempResult = null; + var entry = response.Value.Entries[0]; + complete = UpdateRangeInfo(entry, ref index, ref searchRequest); - while (!cancellationToken.IsCancellationRequested) { - SearchResponse response = null; - try { - response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + foreach (string dn in entry.Attributes[searchRequest.Attributes[0]].GetValues(typeof(string))) + { + yield return Result.Ok(dn); } - catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { - busyRetryCount++; - var backoffDelay = GetNextBackoff(busyRetryCount); - await Task.Delay(backoffDelay, cancellationToken); + } + } + + private async Task> ExecuteSearchRequest( + LdapConnectionWrapper connectionWrapper, + SearchRequest searchRequest, + LdapQueryParameters queryParameters, + CancellationToken cancellationToken) + { + int busyRetryCount = 0; + int queryRetryCount = 0; + + while (!cancellationToken.IsCancellationRequested) + { + try + { + var response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + return Result.Ok(response); + } + catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) + { + await HandleBusyException(++busyRetryCount); } - catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && - queryRetryCount < MaxRetries) { + catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && queryRetryCount < MaxRetries) + { queryRetryCount++; - _connectionPool.ReleaseConnection(connectionWrapper, true); - for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { - var backoffDelay = GetNextBackoff(retryCount); - await Task.Delay(backoffDelay, cancellationToken); - var (success, newConnectionWrapper, message) = - await _connectionPool.GetLdapConnection(domain, - false); - if (success) { - _log.LogDebug( - "RangedRetrieval - Recovered from ServerDown successfully, connection made to {NewServer}", - newConnectionWrapper.GetServer()); - connectionWrapper = newConnectionWrapper; - break; - } - - //If we hit our max retries for making a new connection, set tempResult so we can yield it after this logic - if (retryCount == MaxRetries - 1) { - _log.LogError( - "RangedRetrieval - Failed to get a new connection after ServerDown for path {Path}", - distinguishedName); - tempResult = - LdapResult.Fail( - "RangedRetrieval - Failed to get a new connection after ServerDown.", queryParameters, le.ErrorCode); - } + var result = await HandleServerDownException(connectionWrapper, queryParameters.DomainName); + if (!result.IsSuccess) + { + return Result.Fail(result.Error); } + connectionWrapper = result.Value; } - catch (LdapException le) { - tempResult = LdapResult.Fail( - $"Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", queryParameters, le.ErrorCode); + catch (LdapException le) + { + return Result.Fail($"Unrecoverable LDAP exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})"); } - catch (Exception e) { - tempResult = - LdapResult.Fail($"Caught unrecoverable exception: {e.Message}", queryParameters); + catch (Exception e) + { + return Result.Fail($"Unrecoverable exception: {e.Message}"); } + } - //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function - //We handle connection release in the relevant exception blocks - if (tempResult != null) { - if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { - _connectionPool.ReleaseConnection(connectionWrapper, true); - } - else { - _connectionPool.ReleaseConnection(connectionWrapper); - } - yield return tempResult; - yield break; - } + return Result.Fail("Cancellation requested, exiting."); + } - if (response?.Entries.Count == 1) { - var entry = response.Entries[0]; - //We dont know the name of our attribute, but there should only be one, so we're safe to just use a loop here - foreach (string attr in entry.Attributes.AttributeNames) { - currentRange = attr; - complete = currentRange.IndexOf("*", 0, StringComparison.OrdinalIgnoreCase) > 0; - step = entry.Attributes[currentRange].Count; - } + private async Task HandleBusyException(int retryCount) + { + var backoffDelay = GetNextBackoff(retryCount); + await Task.Delay(backoffDelay); + } - foreach (string dn in entry.Attributes[currentRange].GetValues(typeof(string))) { - yield return Result.Ok(dn); - index++; - } + private async Task> HandleServerDownException(LdapConnectionWrapper oldConnection, string domain) + { + _connectionPool.ReleaseConnection(oldConnection, connectionFaulted: true); - if (complete) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } + for (int retryCount = 0; retryCount < MaxRetries; retryCount++) + { + var backoffDelay = GetNextBackoff(retryCount); + await Task.Delay(backoffDelay); - currentRange = $"{attributeName};range={index}-{index + step}"; - searchRequest.Attributes.Clear(); - searchRequest.Attributes.Add(currentRange); - } - else { - //I dont know what can cause a RR to have multiple entries, but its nothing good. Break out - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; + var (success, newConnectionWrapper, message) = await _connectionPool.GetLdapConnection(domain, false); + if (success) + { + _log.LogDebug("RangedRetrieval - Recovered from ServerDown successfully, connection made to {NewServer}", newConnectionWrapper.GetServer()); + return Result.Ok(newConnectionWrapper); } } - - _connectionPool.ReleaseConnection(connectionWrapper); + + _log.LogError("RangedRetrieval - Failed to get a new connection after ServerDown for domain {Domain}", domain); + return Result.Fail("Failed to get a new connection after ServerDown."); } - public async IAsyncEnumerable> Query(LdapQueryParameters queryParameters, - [EnumeratorCancellation] CancellationToken cancellationToken = new()) { - var setupResult = await SetupLdapQuery(queryParameters); + private bool UpdateRangeInfo(SearchResultEntry entry, ref int index, ref SearchRequest searchRequest) + { + string currentRange = entry.Attributes.AttributeNames.First(); + bool complete = currentRange.IndexOf("*", 0, StringComparison.OrdinalIgnoreCase) > 0; + int step = entry.Attributes[currentRange].Count; - if (!setupResult.Success) { - _log.LogInformation("Query - Failure during query setup: {Reason}\n{Info}", setupResult.Message, - queryParameters.GetQueryInfo()); - yield break; + index += step; + if (!complete) + { + string newRange = $"{currentRange.Split(';')[0]};range={index}-{index + step}"; + searchRequest.Attributes.Clear(); + searchRequest.Attributes.Add(newRange); } - var searchRequest = setupResult.SearchRequest; - var connectionWrapper = setupResult.ConnectionWrapper; + return complete; + } - if (cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); + public async IAsyncEnumerable> Query( + LdapQueryParameters queryParameters, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var setupResult = await SetupLdapQuery(queryParameters); + if (!setupResult.Success) + { + _log.LogInformation("Query - Failure during query setup: {Reason}\n{Info}", setupResult.SearchRequest, queryParameters.GetQueryInfo()); yield break; } - var queryRetryCount = 0; - var busyRetryCount = 0; - LdapResult tempResult = null; - var querySuccess = false; - SearchResponse response = null; - while (!cancellationToken.IsCancellationRequested) { - try { - _log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo()); - response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); - - if (response != null) { - querySuccess = true; - } - else if (queryRetryCount == MaxRetries) { - tempResult = - LdapResult.Fail($"Failed to get a response after {MaxRetries} attempts", - queryParameters); - } - else { - queryRetryCount++; - continue; - } - } - catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && - queryRetryCount < MaxRetries) { - /* - * A ServerDown exception indicates that our connection is no longer valid for one of many reasons. - * We'll want to release our connection back to the pool, but dispose it. We need a new connection, - * and because this is not a paged query, we can get this connection from anywhere. - */ - - //Increment our query retry count - queryRetryCount++; - _connectionPool.ReleaseConnection(connectionWrapper, true); - - for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { - var backoffDelay = GetNextBackoff(retryCount); - await Task.Delay(backoffDelay, cancellationToken); - var (success, newConnectionWrapper, message) = - await _connectionPool.GetLdapConnection(queryParameters.DomainName, - queryParameters.GlobalCatalog); - if (success) { - _log.LogDebug("Query - Recovered from ServerDown successfully, connection made to {NewServer}", - newConnectionWrapper.GetServer()); - connectionWrapper = newConnectionWrapper; - break; - } + if (cancellationToken.IsCancellationRequested) + { + _connectionPool.ReleaseConnection(setupResult.ConnectionWrapper); + yield break; + } - //If we hit our max retries for making a new connection, set tempResult so we can yield it after this logic - if (retryCount == MaxRetries - 1) { - _log.LogError("Query - Failed to get a new connection after ServerDown.\n{Info}", - queryParameters.GetQueryInfo()); - tempResult = - LdapResult.Fail( - "Query - Failed to get a new connection after ServerDown.", queryParameters); - } - } - } - catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { - /* - * If we get a busy error, we want to do an exponential backoff, but maintain the current connection - * The expectation is that given enough time, the server should stop being busy and service our query appropriately - */ - busyRetryCount++; - var backoffDelay = GetNextBackoff(busyRetryCount); - await Task.Delay(backoffDelay, cancellationToken); - } - catch (LdapException le) { - tempResult = LdapResult.Fail( - $"Query - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", - queryParameters); - } - catch (Exception e) { - tempResult = - LdapResult.Fail($"Query - Caught unrecoverable exception: {e.Message}", - queryParameters); - } - - //If we have a tempResult set it means we hit an error we couldn't recover from, so yield that result and then break out of the function - if (tempResult != null) { - if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { - _connectionPool.ReleaseConnection(connectionWrapper, true); - } - else { - _connectionPool.ReleaseConnection(connectionWrapper); - } - yield return tempResult; - yield break; - } + var queryResult = await ExecuteQuery(setupResult.SearchRequest, setupResult.ConnectionWrapper, queryParameters, cancellationToken); + _connectionPool.ReleaseConnection(setupResult.ConnectionWrapper, queryResult.ErrorCode == (int)LdapErrorCodes.ServerDown); - //If we've successfully made our query, break out of the while loop - if (querySuccess) { - break; - } + if (!queryResult.Success) + { + yield return LdapResult.Fail(queryResult.ErrorMessage, queryParameters, queryResult.ErrorCode); + yield break; } - - _connectionPool.ReleaseConnection(connectionWrapper); - foreach (SearchResultEntry entry in response.Entries) { + + foreach (SearchResultEntry entry in queryResult.Response.Entries) + { yield return LdapResult.Ok(new SearchResultEntryWrapper(entry, this)); } } - public async IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, - [EnumeratorCancellation] CancellationToken cancellationToken = new()) { + public async IAsyncEnumerable> PagedQuery( + LdapQueryParameters queryParameters, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { var setupResult = await SetupLdapQuery(queryParameters); - - if (!setupResult.Success) { - _log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.Message, - queryParameters.GetQueryInfo()); + if (!setupResult.Success) + { + _log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.SearchRequest, queryParameters.GetQueryInfo()); yield break; } - var searchRequest = setupResult.SearchRequest; - var connectionWrapper = setupResult.ConnectionWrapper; - var serverName = setupResult.Server; - - if (serverName == null) { - _log.LogWarning("PagedQuery - Failed to get a server name for connection, retry not possible"); - } - var pageControl = new PageResultRequestControl(500); - searchRequest.Controls.Add(pageControl); + setupResult.SearchRequest.Controls.Add(pageControl); - PageResultResponseControl pageResponse = null; - var busyRetryCount = 0; - var queryRetryCount = 0; - LdapResult tempResult = null; + while (!cancellationToken.IsCancellationRequested) + { + var queryResult = await ExecutePagedQuery(setupResult.SearchRequest, setupResult.ConnectionWrapper, setupResult.Server, queryParameters, cancellationToken); - while (!cancellationToken.IsCancellationRequested) { - SearchResponse response = null; - try { - _log.LogTrace("Sending paged ldap request - {Info}", queryParameters.GetQueryInfo()); - response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); - if (response != null) { - pageResponse = (PageResultResponseControl)response.Controls - .Where(x => x is PageResultResponseControl).DefaultIfEmpty(null).FirstOrDefault(); - queryRetryCount = 0; - } - else if (queryRetryCount == MaxRetries) { - tempResult = LdapResult.Fail( - $"PagedQuery - Failed to get a response after {MaxRetries} attempts", - queryParameters); - } - else { - queryRetryCount++; - } + if (!queryResult.Success) + { + _connectionPool.ReleaseConnection(setupResult.ConnectionWrapper, queryResult.ErrorCode == (int)LdapErrorCodes.ServerDown); + yield return LdapResult.Fail(queryResult.ErrorMessage, queryParameters, queryResult.ErrorCode); + yield break; } - catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) { - /* - * If we dont have a servername, we're not going to be able to re-establish a connection here. Page cookies are only valid for the server they were generated on. Bail out. - */ - if (serverName == null) { - _log.LogError( - "PagedQuery - Received server down exception without a known servername. Unable to generate new connection\n{Info}", - queryParameters.GetQueryInfo()); - _connectionPool.ReleaseConnection(connectionWrapper, true); + + foreach (SearchResultEntry entry in queryResult.Response.Entries) + { + if (cancellationToken.IsCancellationRequested) + { + _connectionPool.ReleaseConnection(setupResult.ConnectionWrapper); yield break; } - /* - * Paged queries will not use the cached ldap connections, as the intention is to only have 1 or a couple of these queries running at once. - * The connection logic here is simplified accordingly - */ - _connectionPool.ReleaseConnection(connectionWrapper, true); - for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { - var backoffDelay = GetNextBackoff(retryCount); - await Task.Delay(backoffDelay, cancellationToken); - var (success, ldapConnectionWrapperNew, message) = await _connectionPool.GetLdapConnectionForServer( - queryParameters.DomainName, serverName, queryParameters.GlobalCatalog); - - if (success) { - _log.LogDebug("PagedQuery - Recovered from ServerDown successfully"); - connectionWrapper = ldapConnectionWrapperNew; - break; - } + yield return LdapResult.Ok(new SearchResultEntryWrapper(entry, this)); + } - if (retryCount == MaxRetries - 1) { - _log.LogError("PagedQuery - Failed to get a new connection after ServerDown.\n{Info}", - queryParameters.GetQueryInfo()); - tempResult = - LdapResult.Fail("Failed to get a new connection after serverdown", - queryParameters, le.ErrorCode); - } - } + var pageResponse = (PageResultResponseControl)queryResult.Response.Controls + .FirstOrDefault(x => x is PageResultResponseControl); + + if (pageResponse?.Cookie.Length == 0 || queryResult.Response.Entries.Count == 0) + { + _connectionPool.ReleaseConnection(setupResult.ConnectionWrapper); + yield break; } - catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { - /* - * If we get a busy error, we want to do an exponential backoff, but maintain the current connection - * The expectation is that given enough time, the server should stop being busy and service our query appropriately - */ - busyRetryCount++; - var backoffDelay = GetNextBackoff(busyRetryCount); - await Task.Delay(backoffDelay, cancellationToken); - } - catch (LdapException le) { - tempResult = LdapResult.Fail( - $"PagedQuery - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", - queryParameters, le.ErrorCode); - } - catch (Exception e) { - tempResult = - LdapResult.Fail($"PagedQuery - Caught unrecoverable exception: {e.Message}", - queryParameters); - } - - if (tempResult != null) { - if (tempResult.ErrorCode == (int)LdapErrorCodes.ServerDown) { - _connectionPool.ReleaseConnection(connectionWrapper, true); + + pageControl.Cookie = pageResponse.Cookie; + } + + _connectionPool.ReleaseConnection(setupResult.ConnectionWrapper); + } + + private async Task ExecuteQuery(SearchRequest searchRequest, LdapConnectionWrapper connectionWrapper, LdapQueryParameters queryParameters, CancellationToken cancellationToken) + { + int queryRetryCount = 0; + int busyRetryCount = 0; + + while (!cancellationToken.IsCancellationRequested) + { + try + { + _log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo()); + var response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + return new QueryResult { Success = true, Response = response }; + } + catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && queryRetryCount < MaxRetries) + { + var newConnection = await HandleServerDownException(connectionWrapper, null); + if (newConnection != null) + { + connectionWrapper = newConnection.Value; } - else { - _connectionPool.ReleaseConnection(connectionWrapper); + else + { + return new QueryResult { Success = false, ErrorMessage = "Failed to get a new connection after ServerDown.", ErrorCode = le.ErrorCode }; } - yield return tempResult; - yield break; } - - if (cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; + catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) + { + await HandleBusyException(++busyRetryCount); } - - //I'm not sure why this happens sometimes, but if we try the request again, it works sometimes, other times we get an exception - if (response == null || pageResponse == null) { - continue; + catch (LdapException le) + { + return new QueryResult { Success = false, ErrorMessage = $"Query - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", ErrorCode = le.ErrorCode }; + } + catch (Exception e) + { + return new QueryResult { Success = false, ErrorMessage = $"Query - Caught unrecoverable exception: {e.Message}" }; } + } - foreach (SearchResultEntry entry in response.Entries) { - if (cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; - } + return new QueryResult { Success = false, ErrorMessage = "Query cancelled" }; + } - yield return LdapResult.Ok(new SearchResultEntryWrapper(entry, this)); + private async Task ExecutePagedQuery(SearchRequest searchRequest, LdapConnectionWrapper connectionWrapper, string serverName, LdapQueryParameters queryParameters, CancellationToken cancellationToken) + { + int busyRetryCount = 0; + int queryRetryCount = 0; + + while (!cancellationToken.IsCancellationRequested) + { + try + { + _log.LogTrace("Sending paged ldap request - {Info}", queryParameters.GetQueryInfo()); + var response = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + return new QueryResult { Success = true, Response = response }; } + catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) + { + if (string.IsNullOrEmpty(serverName)) + { + _log.LogError("PagedQuery - Received server down exception without a known servername. Unable to generate new connection\n{Info}", queryParameters.GetQueryInfo()); + return new QueryResult { Success = false, ErrorMessage = "ServerDown exception without known server name", ErrorCode = le.ErrorCode }; + } - if (pageResponse.Cookie.Length == 0 || response.Entries.Count == 0 || - cancellationToken.IsCancellationRequested) { - _connectionPool.ReleaseConnection(connectionWrapper); - yield break; + var newConnection = await HandleServerDownException(connectionWrapper, serverName); + if (newConnection != null) + { + connectionWrapper = newConnection.Value; + } + else + { + return new QueryResult { Success = false, ErrorMessage = "Failed to get a new connection after ServerDown.", ErrorCode = le.ErrorCode }; + } + } + catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) + { + await HandleBusyException(++busyRetryCount); + } + catch (LdapException le) + { + return new QueryResult { Success = false, ErrorMessage = $"PagedQuery - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", ErrorCode = le.ErrorCode }; + } + catch (Exception e) + { + return new QueryResult { Success = false, ErrorMessage = $"PagedQuery - Caught unrecoverable exception: {e.Message}" }; } - - pageControl.Cookie = pageResponse.Cookie; } + + return new QueryResult { Success = false, ErrorMessage = "PagedQuery cancelled" }; } + public async Task<(bool Success, TypedPrincipal Principal)> ResolveIDAndType(SecurityIdentifier securityIdentifier, - string objectDomain) { + string objectDomain) + { return await ResolveIDAndType(securityIdentifier.Value, objectDomain); } public async Task<(bool Success, TypedPrincipal Principal)> - ResolveIDAndType(string identifier, string objectDomain) { - if (identifier.Contains("0ACNF")) { + ResolveIDAndType(string identifier, string objectDomain) + { + if (identifier.Contains("0ACNF")) + { return (false, new TypedPrincipal(identifier, Label.Base)); } - if (await GetWellKnownPrincipal(identifier, objectDomain) is (true, var principal)) { + if (await GetWellKnownPrincipal(identifier, objectDomain) is (true, var principal)) + { return (true, principal); } - if (identifier.StartsWith("S-")) { + if (identifier.StartsWith("S-")) + { var result = await LookupSidType(identifier, objectDomain); return (result.Success, new TypedPrincipal(identifier, result.Type)); } @@ -505,52 +454,65 @@ public async IAsyncEnumerable> PagedQuery(LdapQue return (success, new TypedPrincipal(identifier, type)); } - private async Task<(bool Success, Label Type)> LookupSidType(string sid, string domain) { - if (Cache.GetIDType(sid, out var type)) { + private async Task<(bool Success, Label Type)> LookupSidType(string sid, string domain) + { + if (Cache.GetIDType(sid, out var type)) + { return (true, type); } var tempDomain = domain; - if (await GetDomainNameFromSid(sid) is (true, var domainName)) { + if (await GetDomainNameFromSid(sid) is (true, var domainName)) + { tempDomain = domainName; } - var result = await Query(new LdapQueryParameters() { + var result = await Query(new LdapQueryParameters() + { DomainName = tempDomain, LDAPFilter = CommonFilters.SpecificSID(sid), Attributes = CommonProperties.TypeResolutionProps }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { type = result.Value.GetLabel(); Cache.AddType(sid, type); return (true, type); } - try { + try + { var entry = new DirectoryEntry($"LDAP://"); - if (entry.GetLabel(out type)) { + if (entry.GetLabel(out type)) + { Cache.AddType(sid, type); return (true, type); } } - catch { + catch + { //pass } - using (var ctx = new PrincipalContext(ContextType.Domain)) { - try { + using (var ctx = new PrincipalContext(ContextType.Domain)) + { + try + { var principal = Principal.FindByIdentity(ctx, IdentityType.Sid, sid); - if (principal != null) { + if (principal != null) + { var entry = (DirectoryEntry)principal.GetUnderlyingObject(); - if (entry.GetLabel(out type)) { + if (entry.GetLabel(out type)) + { Cache.AddType(sid, type); return (true, type); } } } - catch { + catch + { //pass } } @@ -558,46 +520,58 @@ public async IAsyncEnumerable> PagedQuery(LdapQue return (false, Label.Base); } - private async Task<(bool Success, Label type)> LookupGuidType(string guid, string domain) { - if (Cache.GetIDType(guid, out var type)) { + private async Task<(bool Success, Label type)> LookupGuidType(string guid, string domain) + { + if (Cache.GetIDType(guid, out var type)) + { return (true, type); } - var result = await Query(new LdapQueryParameters() { + var result = await Query(new LdapQueryParameters() + { DomainName = domain, LDAPFilter = CommonFilters.SpecificGUID(guid), Attributes = CommonProperties.TypeResolutionProps }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { type = result.Value.GetLabel(); Cache.AddType(guid, type); return (true, type); } - try { + try + { var entry = new DirectoryEntry($"LDAP://"); - if (entry.GetLabel(out type)) { + if (entry.GetLabel(out type)) + { Cache.AddType(guid, type); return (true, type); } } - catch { + catch + { //pass } - using (var ctx = new PrincipalContext(ContextType.Domain)) { - try { + using (var ctx = new PrincipalContext(ContextType.Domain)) + { + try + { var principal = Principal.FindByIdentity(ctx, IdentityType.Guid, guid); - if (principal != null) { + if (principal != null) + { var entry = (DirectoryEntry)principal.GetUnderlyingObject(); - if (entry.GetLabel(out type)) { + if (entry.GetLabel(out type)) + { Cache.AddType(guid, type); return (true, type); } } } - catch { + catch + { //pass } } @@ -606,15 +580,18 @@ public async IAsyncEnumerable> PagedQuery(LdapQue } public async Task<(bool Success, TypedPrincipal WellKnownPrincipal)> GetWellKnownPrincipal( - string securityIdentifier, string objectDomain) { - if (!WellKnownPrincipal.GetWellKnownPrincipal(securityIdentifier, out var wellKnownPrincipal)) { + string securityIdentifier, string objectDomain) + { + if (!WellKnownPrincipal.GetWellKnownPrincipal(securityIdentifier, out var wellKnownPrincipal)) + { return (false, null); } var (newIdentifier, newDomain) = await GetWellKnownPrincipalObjectIdentifier(securityIdentifier, objectDomain); wellKnownPrincipal.ObjectIdentifier = newIdentifier; - SeenWellKnownPrincipals.TryAdd(wellKnownPrincipal.ObjectIdentifier, new ResolvedWellKnownPrincipal { + SeenWellKnownPrincipals.TryAdd(wellKnownPrincipal.ObjectIdentifier, new ResolvedWellKnownPrincipal + { DomainName = newDomain, WkpId = securityIdentifier }); @@ -623,20 +600,24 @@ public async IAsyncEnumerable> PagedQuery(LdapQue } private async Task<(string ObjectID, string Domain)> GetWellKnownPrincipalObjectIdentifier( - string securityIdentifier, string domain) { + string securityIdentifier, string domain) + { if (!WellKnownPrincipal.GetWellKnownPrincipal(securityIdentifier, out _)) return (securityIdentifier, string.Empty); - if (!securityIdentifier.Equals("S-1-5-9", StringComparison.OrdinalIgnoreCase)) { + if (!securityIdentifier.Equals("S-1-5-9", StringComparison.OrdinalIgnoreCase)) + { var tempDomain = domain; - if (GetDomain(tempDomain, out var domainObject) && domainObject.Name != null) { + if (GetDomain(tempDomain, out var domainObject) && domainObject.Name != null) + { tempDomain = domainObject.Name; } return ($"{tempDomain}-{securityIdentifier}".ToUpper(), tempDomain); } - if (await GetForest(domain) is (true, var forest)) { + if (await GetForest(domain) is (true, var forest)) + { return ($"{forest}-{securityIdentifier}".ToUpper(), forest); } @@ -644,24 +625,30 @@ public async IAsyncEnumerable> PagedQuery(LdapQue return ($"UNKNOWN-{securityIdentifier}", "UNKNOWN"); } - private async Task<(bool Success, string ForestName)> GetForest(string domain) { - if (DomainToForestCache.TryGetValue(domain, out var cachedForest)) { + private async Task<(bool Success, string ForestName)> GetForest(string domain) + { + if (DomainToForestCache.TryGetValue(domain, out var cachedForest)) + { return (true, cachedForest); } - if (GetDomain(domain, out var domainObject)) { - try { + if (GetDomain(domain, out var domainObject)) + { + try + { var forestName = domainObject.Forest.Name.ToUpper(); DomainToForestCache.TryAdd(domain, forestName); return (true, forestName); } - catch { + catch + { //pass } } var (success, forest) = await GetForestFromLdap(domain); - if (success) { + if (success) + { DomainToForestCache.TryAdd(domain, forest); return (true, forest); } @@ -669,8 +656,10 @@ public async IAsyncEnumerable> PagedQuery(LdapQue return (false, null); } - private async Task<(bool Success, string ForestName)> GetForestFromLdap(string domain) { - var queryParameters = new LdapQueryParameters { + private async Task<(bool Success, string ForestName)> GetForestFromLdap(string domain) + { + var queryParameters = new LdapQueryParameters + { Attributes = new[] { LDAPProperties.RootDomainNamingContext }, SearchScope = SearchScope.Base, DomainName = domain, @@ -678,9 +667,11 @@ public async IAsyncEnumerable> PagedQuery(LdapQue }; var result = await Query(queryParameters).FirstAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { var rdn = result.Value.GetProperty(LDAPProperties.RootDomainNamingContext); - if (!string.IsNullOrEmpty(rdn)) { + if (!string.IsNullOrEmpty(rdn)) + { return (true, Helpers.DistinguishedNameToDomain(rdn).ToUpper()); } } @@ -688,33 +679,41 @@ public async IAsyncEnumerable> PagedQuery(LdapQue return (false, null); } - private static TimeSpan GetNextBackoff(int retryCount) { + private static TimeSpan GetNextBackoff(int retryCount) + { return TimeSpan.FromSeconds(Math.Min( MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount), MaxBackoffDelay.TotalSeconds)); } private bool CreateSearchRequest(LdapQueryParameters queryParameters, - LdapConnectionWrapper connectionWrapper, out SearchRequest searchRequest) { + LdapConnectionWrapper connectionWrapper, out SearchRequest searchRequest) + { string basePath; - if (!string.IsNullOrWhiteSpace(queryParameters.SearchBase)) { + if (!string.IsNullOrWhiteSpace(queryParameters.SearchBase)) + { basePath = queryParameters.SearchBase; } - else if (!connectionWrapper.GetSearchBase(queryParameters.NamingContext, out basePath)) { + else if (!connectionWrapper.GetSearchBase(queryParameters.NamingContext, out basePath)) + { string tempPath; - if (CallDsGetDcName(queryParameters.DomainName, out var info) && info != null) { + if (CallDsGetDcName(queryParameters.DomainName, out var info) && info != null) + { tempPath = Helpers.DomainNameToDistinguishedName(info.Value.DomainName); connectionWrapper.SaveContext(queryParameters.NamingContext, basePath); } - else if (GetDomain(queryParameters.DomainName, out var domainObject)) { + else if (GetDomain(queryParameters.DomainName, out var domainObject)) + { tempPath = Helpers.DomainNameToDistinguishedName(domainObject.Name); } - else { + else + { searchRequest = null; return false; } - basePath = queryParameters.NamingContext switch { + basePath = queryParameters.NamingContext switch + { NamingContext.Configuration => $"CN=Configuration,{tempPath}", NamingContext.Schema => $"CN=Schema,CN=Configuration,{tempPath}", NamingContext.Default => tempPath, @@ -723,7 +722,8 @@ private bool CreateSearchRequest(LdapQueryParameters queryParameters, connectionWrapper.SaveContext(queryParameters.NamingContext, basePath); - if (!string.IsNullOrWhiteSpace(queryParameters.RelativeSearchBase)) { + if (!string.IsNullOrWhiteSpace(queryParameters.RelativeSearchBase)) + { basePath = $"{queryParameters.RelativeSearchBase},{basePath}"; } } @@ -731,12 +731,15 @@ private bool CreateSearchRequest(LdapQueryParameters queryParameters, searchRequest = new SearchRequest(basePath, queryParameters.LDAPFilter, queryParameters.SearchScope, queryParameters.Attributes); searchRequest.Controls.Add(new SearchOptionsControl(SearchOption.DomainScope)); - if (queryParameters.IncludeDeleted) { + if (queryParameters.IncludeDeleted) + { searchRequest.Controls.Add(new ShowDeletedControl()); } - if (queryParameters.IncludeSecurityDescriptor) { - searchRequest.Controls.Add(new SecurityDescriptorFlagControl { + if (queryParameters.IncludeSecurityDescriptor) + { + searchRequest.Controls.Add(new SecurityDescriptorFlagControl + { SecurityMasks = SecurityMasks.Dacl | SecurityMasks.Owner }); } @@ -744,7 +747,8 @@ private bool CreateSearchRequest(LdapQueryParameters queryParameters, return true; } - private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControllerInfo? info) { + private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControllerInfo? info) + { if (_dcInfoCache.TryGetValue(domainName.ToUpper().Trim(), out info)) return info != null; var apiResult = _nativeMethods.CallDsGetDcName(null, domainName, @@ -752,7 +756,8 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll NetAPIEnums.DSGETDCNAME_FLAGS.DS_RETURN_DNS_NAME | NetAPIEnums.DSGETDCNAME_FLAGS.DS_DIRECTORY_SERVICE_REQUIRED)); - if (apiResult.IsFailed) { + if (apiResult.IsFailed) + { _dcInfoCache.TryAdd(domainName.ToUpper().Trim(), null); return false; } @@ -761,24 +766,28 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll return true; } - private async Task SetupLdapQuery(LdapQueryParameters queryParameters) { + private async Task SetupLdapQuery(LdapQueryParameters queryParameters) + { var result = new LdapQuerySetupResult(); var (success, connectionWrapper, message) = await _connectionPool.GetLdapConnection(queryParameters.DomainName, queryParameters.GlobalCatalog); - if (!success) { + if (!success) + { result.Success = false; result.Message = $"Unable to create a connection: {message}"; return result; } //This should never happen as far as I know, so just checking for safety - if (connectionWrapper.Connection == null) { + if (connectionWrapper.Connection == null) + { result.Success = false; result.Message = "Connection object is null"; return result; } - if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) { + if (!CreateSearchRequest(queryParameters, connectionWrapper, out var searchRequest)) + { result.Success = false; result.Message = "Failed to create search request"; _connectionPool.ReleaseConnection(connectionWrapper); @@ -794,61 +803,76 @@ private async Task SetupLdapQuery(LdapQueryParameters quer private SearchRequest CreateSearchRequest(string distinguishedName, string ldapFilter, SearchScope searchScope, - string[] attributes) { + string[] attributes) + { var searchRequest = new SearchRequest(distinguishedName, ldapFilter, searchScope, attributes); searchRequest.Controls.Add(new SearchOptionsControl(SearchOption.DomainScope)); return searchRequest; } - public async Task<(bool Success, string DomainName)> GetDomainNameFromSid(string sid) { + public async Task<(bool Success, string DomainName)> GetDomainNameFromSid(string sid) + { string domainSid; - try { + try + { domainSid = new SecurityIdentifier(sid).AccountDomainSid?.Value.ToUpper(); } - catch { + catch + { var match = _sidRegex.Match(sid); domainSid = match.Success ? match.Groups[1].Value : null; } - if (domainSid == null) { + if (domainSid == null) + { return (false, ""); } - if (Cache.GetDomainSidMapping(domainSid, out var domain)) { + if (Cache.GetDomainSidMapping(domainSid, out var domain)) + { return (true, domain); } - try { + try + { var entry = new DirectoryEntry($"LDAP://"); entry.RefreshCache(new[] { LDAPProperties.DistinguishedName }); var dn = entry.GetProperty(LDAPProperties.DistinguishedName); - if (!string.IsNullOrWhiteSpace(dn)) { + if (!string.IsNullOrWhiteSpace(dn)) + { Cache.AddDomainSidMapping(domainSid, Helpers.DistinguishedNameToDomain(dn)); return (true, Helpers.DistinguishedNameToDomain(dn)); } } - catch { + catch + { //pass } - if (await ConvertDomainSidToDomainNameFromLdap(sid) is (true, var domainName)) { + if (await ConvertDomainSidToDomainNameFromLdap(sid) is (true, var domainName)) + { Cache.AddDomainSidMapping(domainSid, domainName); return (true, domainName); } - using (var ctx = new PrincipalContext(ContextType.Domain)) { - try { + using (var ctx = new PrincipalContext(ContextType.Domain)) + { + try + { var principal = Principal.FindByIdentity(ctx, IdentityType.Sid, sid); - if (principal != null) { + if (principal != null) + { var dn = principal.DistinguishedName; - if (!string.IsNullOrWhiteSpace(dn)) { + if (!string.IsNullOrWhiteSpace(dn)) + { Cache.AddDomainSidMapping(domainSid, Helpers.DistinguishedNameToDomain(dn)); return (true, Helpers.DistinguishedNameToDomain(dn)); } } } - catch { + catch + { //pass } } @@ -856,23 +880,28 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF return (false, string.Empty); } - private async Task<(bool Success, string DomainName)> ConvertDomainSidToDomainNameFromLdap(string domainSid) { - if (!GetDomain(out var domain) || domain?.Name == null) { + private async Task<(bool Success, string DomainName)> ConvertDomainSidToDomainNameFromLdap(string domainSid) + { + if (!GetDomain(out var domain) || domain?.Name == null) + { return (false, string.Empty); } - var result = await Query(new LdapQueryParameters { + var result = await Query(new LdapQueryParameters + { DomainName = domain.Name, Attributes = new[] { LDAPProperties.DistinguishedName }, GlobalCatalog = true, LDAPFilter = new LDAPFilter().AddDomains(CommonFilters.SpecificSID(domainSid)).GetFilter() }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { return (true, Helpers.DistinguishedNameToDomain(result.Value.DistinguishedName)); } - result = await Query(new LdapQueryParameters { + result = await Query(new LdapQueryParameters + { DomainName = domain.Name, Attributes = new[] { LDAPProperties.DistinguishedName }, GlobalCatalog = true, @@ -880,75 +909,90 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF .AddFilter($"(securityidentifier={Helpers.ConvertSidToHexSid(domainSid)})", true).GetFilter() }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { return (true, Helpers.DistinguishedNameToDomain(result.Value.DistinguishedName)); } - result = await Query(new LdapQueryParameters { + result = await Query(new LdapQueryParameters + { DomainName = domain.Name, Attributes = new[] { LDAPProperties.DistinguishedName }, LDAPFilter = new LDAPFilter().AddFilter("(objectclass=domaindns)", true) .AddFilter(CommonFilters.SpecificSID(domainSid), true).GetFilter() }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { return (true, Helpers.DistinguishedNameToDomain(result.Value.DistinguishedName)); } return (false, string.Empty); } - public async Task<(bool Success, string DomainSid)> GetDomainSidFromDomainName(string domainName) { + public async Task<(bool Success, string DomainSid)> GetDomainSidFromDomainName(string domainName) + { if (Cache.GetDomainSidMapping(domainName, out var domainSid)) return (true, domainSid); - try { + try + { var entry = new DirectoryEntry($"LDAP://{domainName}"); //Force load objectsid into the object cache entry.RefreshCache(new[] { "objectSid" }); var sid = entry.GetSid(); - if (sid != null) { + if (sid != null) + { Cache.AddDomainSidMapping(domainName, sid); domainSid = sid; return (true, domainSid); } } - catch { + catch + { //we expect this to fail sometimes } if (GetDomain(domainName, out var domainObject)) - try { + try + { domainSid = domainObject.GetDirectoryEntry().GetSid(); - if (domainSid != null) { + if (domainSid != null) + { Cache.AddDomainSidMapping(domainName, domainSid); return (true, domainSid); } } - catch { + catch + { //we expect this to fail sometimes (not sure why, but better safe than sorry) } foreach (var name in _translateNames) - try { + try + { var account = new NTAccount(domainName, name); var sid = (SecurityIdentifier)account.Translate(typeof(SecurityIdentifier)); domainSid = sid.AccountDomainSid.ToString(); Cache.AddDomainSidMapping(domainName, domainSid); return (true, domainSid); } - catch { + catch + { //We expect this to fail if the username doesn't exist in the domain } - var result = await Query(new LdapQueryParameters() { + var result = await Query(new LdapQueryParameters() + { DomainName = domainName, Attributes = new[] { LDAPProperties.ObjectSID }, LDAPFilter = new LDAPFilter().AddFilter(CommonFilters.DomainControllers, true).GetFilter() }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { var sid = result.Value.GetSid(); - if (!string.IsNullOrEmpty(sid)) { + if (!string.IsNullOrEmpty(sid)) + { domainSid = new SecurityIdentifier(sid).AccountDomainSid.Value; Cache.AddDomainSidMapping(domainName, domainSid); return (true, domainSid); @@ -965,11 +1009,13 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF /// /// /// - public bool GetDomain(string domainName, out Domain domain) { + public bool GetDomain(string domainName, out Domain domain) + { var cacheKey = domainName ?? _nullCacheKey; if (DomainCache.TryGetValue(cacheKey, out domain)) return true; - try { + try + { DirectoryContext context; if (_ldapConfig.Username != null) context = domainName != null @@ -987,16 +1033,19 @@ public bool GetDomain(string domainName, out Domain domain) { DomainCache.TryAdd(cacheKey, domain); return true; } - catch (Exception e) { + catch (Exception e) + { _log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName); return false; } } - public static bool GetDomain(string domainName, LDAPConfig ldapConfig, out Domain domain) { + public static bool GetDomain(string domainName, LDAPConfig ldapConfig, out Domain domain) + { if (DomainCache.TryGetValue(domainName, out domain)) return true; - try { + try + { DirectoryContext context; if (ldapConfig.Username != null) context = domainName != null @@ -1014,7 +1063,8 @@ public static bool GetDomain(string domainName, LDAPConfig ldapConfig, out Domai DomainCache.TryAdd(domainName, domain); return true; } - catch (Exception e) { + catch (Exception e) + { Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName, e.Message); return false; } @@ -1027,11 +1077,13 @@ public static bool GetDomain(string domainName, LDAPConfig ldapConfig, out Domai /// /// /// - public bool GetDomain(out Domain domain) { + public bool GetDomain(out Domain domain) + { var cacheKey = _nullCacheKey; if (DomainCache.TryGetValue(cacheKey, out domain)) return true; - try { + try + { var context = _ldapConfig.Username != null ? new DirectoryContext(DirectoryContextType.Domain, _ldapConfig.Username, _ldapConfig.Password) @@ -1041,34 +1093,41 @@ public bool GetDomain(out Domain domain) { DomainCache.TryAdd(cacheKey, domain); return true; } - catch (Exception e) { + catch (Exception e) + { _log.LogDebug(e, "GetDomain call failed for blank domain"); return false; } } - public async Task<(bool Success, TypedPrincipal Principal)> ResolveAccountName(string name, string domain) { - if (string.IsNullOrWhiteSpace(name)) { + public async Task<(bool Success, TypedPrincipal Principal)> ResolveAccountName(string name, string domain) + { + if (string.IsNullOrWhiteSpace(name)) + { return (false, null); } if (Cache.GetPrefixedValue(name, domain, out var id) && Cache.GetIDType(id, out var type)) - return (true, new TypedPrincipal { + return (true, new TypedPrincipal + { ObjectIdentifier = id, ObjectType = type }); - var result = await Query(new LdapQueryParameters() { + var result = await Query(new LdapQueryParameters() + { DomainName = domain, Attributes = CommonProperties.TypeResolutionProps, LDAPFilter = $"(samaccountname={name})" }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (result.IsSuccess) { + if (result.IsSuccess) + { type = result.Value.GetLabel(); id = result.Value.GetObjectIdentifier(); - if (!string.IsNullOrWhiteSpace(id)) { + if (!string.IsNullOrWhiteSpace(id)) + { Cache.AddPrefixedValue(name, domain, id); Cache.AddType(id, type); } @@ -1080,10 +1139,12 @@ public bool GetDomain(out Domain domain) { return (false, null); } - public async Task<(bool Success, string SecurityIdentifier)> ResolveHostToSid(string host, string domain) { + public async Task<(bool Success, string SecurityIdentifier)> ResolveHostToSid(string host, string domain) + { //Remove SPN prefixes from the host name so we're working with a clean name var strippedHost = Helpers.StripServicePrincipalName(host).ToUpper().TrimEnd('$'); - if (string.IsNullOrEmpty(strippedHost)) { + if (string.IsNullOrEmpty(strippedHost)) + { return (false, string.Empty); } @@ -1091,17 +1152,21 @@ public bool GetDomain(out Domain domain) { //Immediately start with NetWekstaGetInfo as its our most reliable indicator if successful var workstationInfo = await GetWorkstationInfo(strippedHost); - if (workstationInfo.HasValue) { + if (workstationInfo.HasValue) + { var tempName = workstationInfo.Value.ComputerName; var tempDomain = workstationInfo.Value.LanGroup; - if (string.IsNullOrWhiteSpace(tempDomain)) { + if (string.IsNullOrWhiteSpace(tempDomain)) + { tempDomain = domain; } - if (!string.IsNullOrWhiteSpace(tempName)) { + if (!string.IsNullOrWhiteSpace(tempName)) + { tempName = $"{tempName}$".ToUpper(); - if (await ResolveAccountName(tempName, tempDomain) is (true, var principal)) { + if (await ResolveAccountName(tempName, tempDomain) is (true, var principal)) + { _hostResolutionMap.TryAdd(strippedHost, principal.ObjectIdentifier); return (true, principal.ObjectIdentifier); } @@ -1109,10 +1174,13 @@ public bool GetDomain(out Domain domain) { } //Try some socket magic to get the NETBIOS name - if (RequestNETBIOSNameFromComputer(strippedHost, domain, out var netBiosName)) { - if (!string.IsNullOrWhiteSpace(netBiosName)) { + if (RequestNETBIOSNameFromComputer(strippedHost, domain, out var netBiosName)) + { + if (!string.IsNullOrWhiteSpace(netBiosName)) + { var result = await ResolveAccountName($"{netBiosName}$", domain); - if (result.Success) { + if (result.Success) + { _hostResolutionMap.TryAdd(strippedHost, result.Principal.ObjectIdentifier); return (true, result.Principal.ObjectIdentifier); } @@ -1120,52 +1188,62 @@ public bool GetDomain(out Domain domain) { } //Start by handling non-IP address names - if (!IPAddress.TryParse(strippedHost, out _)) { + if (!IPAddress.TryParse(strippedHost, out _)) + { //PRIMARY.TESTLAB.LOCAL - if (strippedHost.Contains(".")) { + if (strippedHost.Contains(".")) + { var split = strippedHost.Split('.'); var name = split[0]; var result = await ResolveAccountName($"{name}$", domain); - if (result.Success) { + if (result.Success) + { _hostResolutionMap.TryAdd(strippedHost, result.Principal.ObjectIdentifier); return (true, result.Principal.ObjectIdentifier); } var tempDomain = string.Join(".", split.Skip(1).ToArray()); result = await ResolveAccountName($"{name}$", tempDomain); - if (result.Success) { + if (result.Success) + { _hostResolutionMap.TryAdd(strippedHost, result.Principal.ObjectIdentifier); return (true, result.Principal.ObjectIdentifier); } } - else { + else + { //Format: WIN10 (probably a netbios name) var result = await ResolveAccountName($"{strippedHost}$", domain); - if (result.Success) { + if (result.Success) + { _hostResolutionMap.TryAdd(strippedHost, result.Principal.ObjectIdentifier); return (true, result.Principal.ObjectIdentifier); } } } - try { + try + { var resolvedHostname = (await Dns.GetHostEntryAsync(strippedHost)).HostName; var split = resolvedHostname.Split('.'); var name = split[0]; var result = await ResolveAccountName($"{name}$", domain); - if (result.Success) { + if (result.Success) + { _hostResolutionMap.TryAdd(strippedHost, result.Principal.ObjectIdentifier); return (true, result.Principal.ObjectIdentifier); } var tempDomain = string.Join(".", split.Skip(1).ToArray()); result = await ResolveAccountName($"{name}$", tempDomain); - if (result.Success) { + if (result.Success) + { _hostResolutionMap.TryAdd(strippedHost, result.Principal.ObjectIdentifier); return (true, result.Principal.ObjectIdentifier); } } - catch { + catch + { //pass } @@ -1177,7 +1255,8 @@ public bool GetDomain(out Domain domain) { /// /// /// - private async Task GetWorkstationInfo(string hostname) { + private async Task GetWorkstationInfo(string hostname) + { if (!await _portScanner.CheckPort(hostname)) return null; @@ -1187,26 +1266,33 @@ public bool GetDomain(out Domain domain) { return null; } - public async Task<(bool Success, string[] Sids)> GetGlobalCatalogMatches(string name, string domain) { - if (Cache.GetGCCache(name, out var matches)) { + public async Task<(bool Success, string[] Sids)> GetGlobalCatalogMatches(string name, string domain) + { + if (Cache.GetGCCache(name, out var matches)) + { return (true, matches); } var sids = new List(); - await foreach (var result in Query(new LdapQueryParameters { - DomainName = domain, - Attributes = new[] { LDAPProperties.ObjectSID }, - GlobalCatalog = true, - LDAPFilter = new LDAPFilter().AddUsers($"(samaccountname={name})").GetFilter() - })) { - if (result.IsSuccess) { + await foreach (var result in Query(new LdapQueryParameters + { + DomainName = domain, + Attributes = new[] { LDAPProperties.ObjectSID }, + GlobalCatalog = true, + LDAPFilter = new LDAPFilter().AddUsers($"(samaccountname={name})").GetFilter() + })) + { + if (result.IsSuccess) + { var sid = result.Value.GetSid(); - if (!string.IsNullOrWhiteSpace(sid)) { + if (!string.IsNullOrWhiteSpace(sid)) + { sids.Add(sid); } } - else { + else + { return (false, Array.Empty()); } } @@ -1216,9 +1302,11 @@ public bool GetDomain(out Domain domain) { } public async Task<(bool Success, TypedPrincipal Principal)> ResolveCertTemplateByProperty(string propertyValue, - string propertyName, string domainName) { + string propertyName, string domainName) + { var filter = new LDAPFilter().AddCertificateTemplates().AddFilter($"({propertyName}={propertyValue})", true); - var result = await Query(new LdapQueryParameters { + var result = await Query(new LdapQueryParameters + { DomainName = domainName, Attributes = CommonProperties.TypeResolutionProps, SearchScope = SearchScope.OneLevel, @@ -1227,7 +1315,8 @@ public bool GetDomain(out Domain domain) { LDAPFilter = filter.GetFilter(), }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); - if (!result.IsSuccess) { + if (!result.IsSuccess) + { _log.LogWarning( "Could not find certificate template with {PropertyName}:{PropertyValue}: {Error}", propertyName, propertyName, result.Error); @@ -1245,10 +1334,12 @@ public bool GetDomain(out Domain domain) { /// /// /// - private static bool RequestNETBIOSNameFromComputer(string server, string domain, out string netbios) { + private static bool RequestNETBIOSNameFromComputer(string server, string domain, out string netbios) + { var receiveBuffer = new byte[1024]; var requestSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); - try { + try + { //Set receive timeout to 1 second requestSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReceiveTimeout, 1000); EndPoint remoteEndpoint; @@ -1258,7 +1349,8 @@ private static bool RequestNETBIOSNameFromComputer(string server, string domain, remoteEndpoint = new IPEndPoint(parsedAddress, 137); else //If its not an IP, we're going to try and resolve it from DNS - try { + try + { IPAddress address; if (server.Contains(".")) address = Dns @@ -1266,14 +1358,16 @@ private static bool RequestNETBIOSNameFromComputer(string server, string domain, else address = Dns.GetHostAddresses($"{server}.{domain}")[0]; - if (address == null) { + if (address == null) + { netbios = null; return false; } remoteEndpoint = new IPEndPoint(address, 137); } - catch { + catch + { //Failed to resolve an IP, so return null netbios = null; return false; @@ -1282,10 +1376,12 @@ private static bool RequestNETBIOSNameFromComputer(string server, string domain, var originEndpoint = new IPEndPoint(IPAddress.Any, 0); requestSocket.Bind(originEndpoint); - try { + try + { requestSocket.SendTo(NameRequest, remoteEndpoint); var receivedByteCount = requestSocket.ReceiveFrom(receiveBuffer, ref remoteEndpoint); - if (receivedByteCount >= 90) { + if (receivedByteCount >= 90) + { netbios = new ASCIIEncoding().GetString(receiveBuffer, 57, 16).Trim('\0', ' '); return true; } @@ -1293,12 +1389,14 @@ private static bool RequestNETBIOSNameFromComputer(string server, string domain, netbios = null; return false; } - catch (SocketException) { + catch (SocketException) + { netbios = null; return false; } } - finally { + finally + { //Make sure we close the socket if its open requestSocket.Close(); } @@ -1308,22 +1406,27 @@ private static bool RequestNETBIOSNameFromComputer(string server, string domain, /// Created for testing purposes /// /// - public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() { + public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() + { return new ActiveDirectorySecurityDescriptor(new ActiveDirectorySecurity()); } public async Task<(bool Success, TypedPrincipal Principal)> ConvertLocalWellKnownPrincipal(SecurityIdentifier sid, - string computerDomainSid, string computerDomain) { + string computerDomainSid, string computerDomain) + { if (!WellKnownPrincipal.GetWellKnownPrincipal(sid.Value, out var common)) return (false, null); //The everyone and auth users principals are special and will be converted to the domain equivalent - if (sid.Value is "S-1-1-0" or "S-1-5-11") { + if (sid.Value is "S-1-1-0" or "S-1-5-11") + { return await GetWellKnownPrincipal(sid.Value, computerDomain); } //Use the computer object id + the RID of the sid we looked up to create our new principal - var principal = new TypedPrincipal { + var principal = new TypedPrincipal + { ObjectIdentifier = $"{computerDomainSid}-{sid.Rid()}", - ObjectType = common.ObjectType switch { + ObjectType = common.ObjectType switch + { Label.User => Label.LocalUser, Label.Group => Label.LocalGroup, _ => common.ObjectType @@ -1333,11 +1436,13 @@ public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() { return (true, principal); } - public async Task IsDomainController(string computerObjectId, string domainName) { + public async Task IsDomainController(string computerObjectId, string domainName) + { var resDomain = await GetDomainNameFromSid(domainName) is (false, var tempDomain) ? tempDomain : domainName; var filter = new LDAPFilter().AddFilter(CommonFilters.SpecificSID(computerObjectId), true) .AddFilter(CommonFilters.DomainControllers, true); - var result = await Query(new LdapQueryParameters() { + var result = await Query(new LdapQueryParameters() + { DomainName = resDomain, Attributes = CommonProperties.ObjectID, LDAPFilter = filter.GetFilter(), @@ -1345,13 +1450,16 @@ public async Task IsDomainController(string computerObjectId, string domai return result is { IsSuccess: true }; } - public async Task<(bool Success, TypedPrincipal Principal)> ResolveDistinguishedName(string distinguishedName) { - if (_distinguishedNameCache.TryGetValue(distinguishedName, out var principal)) { + public async Task<(bool Success, TypedPrincipal Principal)> ResolveDistinguishedName(string distinguishedName) + { + if (_distinguishedNameCache.TryGetValue(distinguishedName, out var principal)) + { return (true, principal); } var domain = Helpers.DistinguishedNameToDomain(distinguishedName); - var result = await Query(new LdapQueryParameters { + var result = await Query(new LdapQueryParameters + { DomainName = domain, Attributes = CommonProperties.TypeResolutionProps, SearchBase = distinguishedName, @@ -1359,14 +1467,17 @@ public async Task IsDomainController(string computerObjectId, string domai LDAPFilter = new LDAPFilter().AddAllObjects().GetFilter() }).DefaultIfEmpty(null).FirstOrDefaultAsync(); - if (result is { IsSuccess: true }) { + if (result is { IsSuccess: true }) + { var entry = result.Value; var id = entry.GetObjectIdentifier(); - if (id == null) { + if (id == null) + { return (false, default); } - if (await GetWellKnownPrincipal(id, domain) is (true, var wellKnownPrincipal)) { + if (await GetWellKnownPrincipal(id, domain) is (true, var wellKnownPrincipal)) + { _distinguishedNameCache.TryAdd(distinguishedName, wellKnownPrincipal); return (true, wellKnownPrincipal); } @@ -1377,28 +1488,33 @@ public async Task IsDomainController(string computerObjectId, string domai return (true, principal); } - using (var ctx = new PrincipalContext(ContextType.Domain)) { - try { + using (var ctx = new PrincipalContext(ContextType.Domain)) + { + try + { var lookupPrincipal = Principal.FindByIdentity(ctx, IdentityType.DistinguishedName, distinguishedName); if (lookupPrincipal != null && - ((DirectoryEntry)lookupPrincipal.GetUnderlyingObject()).GetTypedPrincipal(out principal)) { + ((DirectoryEntry)lookupPrincipal.GetUnderlyingObject()).GetTypedPrincipal(out principal)) + { return (true, principal); } return (false, default); } - catch { + catch + { return (false, default); } } } - + public void AddDomainController(string domainControllerSID) { DomainControllers.TryAdd(domainControllerSID, new byte()); } - public async IAsyncEnumerable GetWellKnownPrincipalOutput() { + public async IAsyncEnumerable GetWellKnownPrincipalOutput() + { foreach (var wkp in SeenWellKnownPrincipals) { WellKnownPrincipal.GetWellKnownPrincipal(wkp.Value.WkpId, out var principal); @@ -1416,71 +1532,87 @@ public async IAsyncEnumerable GetWellKnownPrincipalOutput() { }; output.Properties.Add("name", $"{principal.ObjectIdentifier}@{wkp.Value.DomainName}".ToUpper()); - if (await GetDomainSidFromDomainName(wkp.Value.DomainName) is (true, var sid)) { - output.Properties.Add("domainsid", sid); + if (await GetDomainSidFromDomainName(wkp.Value.DomainName) is (true, var sid)) + { + output.Properties.Add("domainsid", sid); } - + output.Properties.Add("domain", wkp.Value.DomainName.ToUpper()); output.ObjectIdentifier = wkp.Key; yield return output; } } - public void SetLdapConfig(LDAPConfig config) { + public void SetLdapConfig(LDAPConfig config) + { _ldapConfig = config; _connectionPool.Dispose(); _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); } - public Task<(bool Success, string Message)> TestLdapConnection(string domain) { + public Task<(bool Success, string Message)> TestLdapConnection(string domain) + { return _connectionPool.TestDomainConnection(domain, false); } - public async Task<(bool Success, string Path)> GetNamingContextPath(string domain, NamingContext context) { - if (await _connectionPool.GetLdapConnection(domain, false) is (true, var wrapper, _)) { + public async Task<(bool Success, string Path)> GetNamingContextPath(string domain, NamingContext context) + { + if (await _connectionPool.GetLdapConnection(domain, false) is (true, var wrapper, _)) + { _connectionPool.ReleaseConnection(wrapper); - if (wrapper.GetSearchBase(context, out var searchBase)) { + if (wrapper.GetSearchBase(context, out var searchBase)) + { return (true, searchBase); } } - - var property = context switch { + + var property = context switch + { NamingContext.Default => LDAPProperties.DefaultNamingContext, NamingContext.Configuration => LDAPProperties.ConfigurationNamingContext, NamingContext.Schema => LDAPProperties.SchemaNamingContext, _ => throw new ArgumentOutOfRangeException(nameof(context), context, null) }; - try { + try + { var entry = CreateDirectoryEntry($"LDAP://{domain}/RootDSE"); entry.RefreshCache(new[] { property }); var searchBase = entry.GetProperty(property); - if (!string.IsNullOrWhiteSpace(searchBase)) { + if (!string.IsNullOrWhiteSpace(searchBase)) + { return (true, searchBase); } } - catch { + catch + { //pass } - if (GetDomain(domain, out var domainObj)) { - try { + if (GetDomain(domain, out var domainObj)) + { + try + { var entry = domainObj.GetDirectoryEntry(); entry.RefreshCache(new[] { property }); var searchBase = entry.GetProperty(property); - if (!string.IsNullOrWhiteSpace(searchBase)) { + if (!string.IsNullOrWhiteSpace(searchBase)) + { return (true, searchBase); } } - catch { + catch + { //pass } var name = domainObj.Name; - if (!string.IsNullOrWhiteSpace(name)) { + if (!string.IsNullOrWhiteSpace(name)) + { var tempPath = Helpers.DomainNameToDistinguishedName(name); - - var searchBase = context switch { + + var searchBase = context switch + { NamingContext.Configuration => $"CN=Configuration,{tempPath}", NamingContext.Schema => $"CN=Schema,CN=Configuration,{tempPath}", NamingContext.Default => tempPath, @@ -1494,24 +1626,29 @@ public void SetLdapConfig(LDAPConfig config) { return (false, default); } - private DirectoryEntry CreateDirectoryEntry(string path) { - if (_ldapConfig.Username != null) { + private DirectoryEntry CreateDirectoryEntry(string path) + { + if (_ldapConfig.Username != null) + { return new DirectoryEntry(path, _ldapConfig.Username, _ldapConfig.Password); } return new DirectoryEntry(path); } - public void Dispose() { + public void Dispose() + { _connectionPool?.Dispose(); } - - internal static bool ResolveLabel(string objectIdentifier, string distinguishedName, string samAccountType, string[] objectClasses, int flags, out Label type) { + + internal static bool ResolveLabel(string objectIdentifier, string distinguishedName, string samAccountType, string[] objectClasses, int flags, out Label type) + { type = Label.Base; - if (objectIdentifier != null && WellKnownPrincipal.GetWellKnownPrincipal(objectIdentifier, out var principal)) { + if (objectIdentifier != null && WellKnownPrincipal.GetWellKnownPrincipal(objectIdentifier, out var principal)) + { type = principal.ObjectType; return true; } - + //Override GMSA/MSA account to treat them as users for the graph if (objectClasses != null && (objectClasses.Contains(MSAClass, StringComparer.OrdinalIgnoreCase) || objectClasses.Contains(GMSAClass, StringComparer.OrdinalIgnoreCase))) @@ -1520,19 +1657,22 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN return true; } - if (samAccountType != null) { + if (samAccountType != null) + { var objectType = Helpers.SamAccountTypeToType(samAccountType); - if (objectType != Label.Base) { + if (objectType != Label.Base) + { type = objectType; return true; } } - if (objectClasses == null) { + if (objectClasses == null) + { type = Label.Base; return false; } - + if (objectClasses.Contains(GroupPolicyContainerClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.GPO; else if (objectClasses.Contains(OrganizationalUnitClass, StringComparer.InvariantCultureIgnoreCase)) @@ -1547,14 +1687,17 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN type = Label.CertTemplate; else if (objectClasses.Contains(PKIEnrollmentServiceClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.EnterpriseCA; - else if (objectClasses.Contains(CertificationAuthorityClass, StringComparer.InvariantCultureIgnoreCase)) { + else if (objectClasses.Contains(CertificationAuthorityClass, StringComparer.InvariantCultureIgnoreCase)) + { if (distinguishedName.Contains(DirectoryPaths.RootCALocation)) type = Label.RootCA; if (distinguishedName.Contains(DirectoryPaths.AIACALocation)) type = Label.AIACA; if (distinguishedName.Contains(DirectoryPaths.NTAuthStoreLocation)) type = Label.NTAuthStore; - }else if (objectClasses.Contains(OIDContainerClass, StringComparer.InvariantCultureIgnoreCase)) { + } + else if (objectClasses.Contains(OIDContainerClass, StringComparer.InvariantCultureIgnoreCase)) + { if (distinguishedName.StartsWith(DirectoryPaths.OIDContainerLocation, StringComparison.InvariantCultureIgnoreCase)) type = Label.Container; @@ -1567,6 +1710,14 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN return type != Label.Base; } + private class QueryResult + { + public bool Success { get; set; } + public SearchResponse Response { get; set; } + public string ErrorMessage { get; set; } + public int ErrorCode { get; set; } + } + private const string GroupPolicyContainerClass = "groupPolicyContainer"; private const string OrganizationalUnitClass = "organizationalUnit"; private const string DomainClass = "domain";