From 6d5cd65563348c4194bc6aa1c19d87ef9e51a4bb Mon Sep 17 00:00:00 2001 From: rvazarkar Date: Tue, 18 Jun 2024 16:21:51 -0400 Subject: [PATCH] wip: implement connection pool --- src/CommonLib/LDAPUtilsNew.cs | 284 ++++++++++++---- src/CommonLib/LdapConnectionPool.cs | 376 ++++++++++++++++++++++ src/CommonLib/LdapConnectionWrapperNew.cs | 23 +- src/CommonLib/LdapQuerySetupResult.cs | 12 + 4 files changed, 632 insertions(+), 63 deletions(-) create mode 100644 src/CommonLib/LdapConnectionPool.cs create mode 100644 src/CommonLib/LdapQuerySetupResult.cs diff --git a/src/CommonLib/LDAPUtilsNew.cs b/src/CommonLib/LDAPUtilsNew.cs index 28d842bd..d8677da5 100644 --- a/src/CommonLib/LDAPUtilsNew.cs +++ b/src/CommonLib/LDAPUtilsNew.cs @@ -40,48 +40,179 @@ public class LDAPUtilsNew { private readonly object _lockObj = new(); private readonly ManualResetEvent _connectionResetEvent = new(false); - public async IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, + public async IAsyncEnumerable> Query(LdapQueryParameters queryParameters, [EnumeratorCancellation] CancellationToken cancellationToken = new()) { - //Always force create a new connection - var (success, connectionWrapper, message) = await GetLdapConnection(queryParameters.DomainName, - queryParameters.GlobalCatalog, true); - if (!success) { - _log.LogDebug("PagedQuery failure: unable to create a connection: {Reason}\n{Info}", message, + var setupResult = await SetupLdapQuery(queryParameters, true); + + if (!setupResult.Success) { + _log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.Message, queryParameters.GetQueryInfo()); - yield return new LdapResult { - Error = $"Unable to create a connection: {message}", - QueryInfo = queryParameters.GetQueryInfo() - }; yield break; } - //This should never happen as far as I know, so just checking for safety - if (connectionWrapper == null) { - _log.LogError("PagedQuery failure: ldap connection is null\n{Info}", queryParameters.GetQueryInfo()); - yield return new LdapResult { - Error = "Connection is null", - QueryInfo = queryParameters.GetQueryInfo() - }; + var searchRequest = setupResult.SearchRequest; + var connectionWrapper = setupResult.ConnectionWrapper; + var connection = connectionWrapper.Connection; + var serverName = setupResult.Server; + + if (serverName == null) { + _log.LogWarning("PagedQuery - Failed to get a server name for connection, retry not possible"); + } + + if (cancellationToken.IsCancellationRequested) { yield break; } - var connection = connectionWrapper.Connection; + var queryRetryCount = 0; + var busyRetryCount = 0; + LdapResult tempResult = null; + while (queryRetryCount < MaxRetries) { + SearchResponse response = null; + try { + _log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo()); + response = (SearchResponse)connection.SendRequest(searchRequest); + } + catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) { + /* + * If we dont have a servername, we're not going to be able to re-establish a connection here. Page cookies are only valid for the server they were generated on. Bail out. + */ + if (serverName == null) { + _log.LogError( + "Query - Received server down exception without a known servername. Unable to generate new connection\n{Info}", + queryParameters.GetQueryInfo()); + yield break; + } + + /*A ServerDown exception indicates that our connection is no longer valid for one of many reasons. + However, this function is generally called by multiple threads, so we need to be careful in recreating + the connection. Using a semaphore, we can ensure that only one thread is actually recreating the connection + while the other threads that hit the ServerDown exception simply wait. The initial caller will hold the semaphore + and do a backoff delay before trying to make a new connection which will replace the existing connection in the + _ldapConnections cache. Other threads will retrieve the new connection from the cache instead of making a new one + This minimizes overhead of new connections while still fixing our core problem.*/ + + //Increment our query retry count + queryRetryCount++; + + //Attempt to acquire a lock + if (Monitor.TryEnter(_lockObj)) { + //Signal the reset event to ensure no everyone else waits + _connectionResetEvent.Reset(); + try { + //Try up to MaxRetries time to make a new connection, ensuring we're not using the cache + for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { + var backoffDelay = GetNextBackoff(retryCount); + await Task.Delay(backoffDelay, cancellationToken); + var (success, newConnectionWrapper, _) = await GetLdapConnection(queryParameters, true); + if (success) { + newConnectionWrapper.CopyContexts(connectionWrapper); + connectionWrapper.Connection.Dispose(); + connectionWrapper = newConnectionWrapper; + break; + } + + if (retryCount == MaxRetries - 1) { + _log.LogError("Query - Failed to get a new connection after ServerDown.\n{Info}", + queryParameters.GetQueryInfo()); + yield break; + } + } + }finally{ + _connectionResetEvent.Set(); + Monitor.Exit(_lockObj); + } + } + else { + //If someone else is holding the reset event, we want to just wait and then pull the newly created connection out of the cache + //This event will be released after the first entrant thread is done making a new connection + //The thread.sleep is to prevent a potential, very unlikely race + Thread.Sleep(50); + _connectionResetEvent.WaitOne(); + + //At this point, our connection reset event should be tripped, and there should be a new connection on the cache + var (success, newConnectionWrapper, _) = await GetLdapConnection(queryParameters); + if (!success) { + _log.LogError("Query - Failed to recover from ServerDown error\n{Info}", queryParameters.GetQueryInfo()); + yield break; + } + + newConnectionWrapper.CopyContexts(connectionWrapper); + connectionWrapper = newConnectionWrapper; + connection = connectionWrapper.Connection; + } + } + catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { + /* + * If we get a busy error, we want to do an exponential backoff, but maintain the current connection + * The expectation is that given enough time, the server should stop being busy and service our query appropriately + */ + busyRetryCount++; + var backoffDelay = GetNextBackoff(busyRetryCount); + await Task.Delay(backoffDelay, cancellationToken); + } + catch (LdapException le) { + //No point in printing local exceptions because they're literally worthless + tempResult = new LdapResult() { + Error = + $"Query - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", + QueryInfo = queryParameters.GetQueryInfo() + }; + } + catch (Exception e) { + tempResult = new LdapResult { + Error = + $"PagedQuery - Caught unrecoverable exception: {e.Message}", + QueryInfo = queryParameters.GetQueryInfo() + }; + } - //Pull the server name from the connection for retry logic later - if (!connectionWrapper.GetServer(out var serverName)) { - _log.LogDebug("PagedQuery: Failed to get server value"); - serverName = null; + if (tempResult != null) { + yield return tempResult; + yield break; + } } + } - if (!CreateSearchRequest(queryParameters, ref connectionWrapper, out var searchRequest)) { - _log.LogError("PagedQuery failure: unable to resolve search base\n{Info}", queryParameters.GetQueryInfo()); - yield return new LdapResult { - Error = "Unable to create search request", - QueryInfo = queryParameters.GetQueryInfo() - }; + private bool SendSearchRequestWithRetryHandling(LdapConnectionWrapperNew connectionWrapper, + SearchRequest searchRequest, out SearchResponse searchResponse) { + for (var retryCount = 0; retryCount < MaxRetries; retryCount++) { + try { + searchResponse = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest); + return true; + } + catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) { + + } + } + } + + private bool SendPagedSearchRequestWithRetryHandling(LdapConnectionWrapperNew connectionWrapper, + SearchRequest searchRequest, out SearchResponse searchResponse) { + + + + + } + + public async IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, + [EnumeratorCancellation] CancellationToken cancellationToken = new()) { + var setupResult = await SetupLdapQuery(queryParameters, true); + + if (!setupResult.Success) { + _log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.Message, + queryParameters.GetQueryInfo()); yield break; } + var searchRequest = setupResult.SearchRequest; + var connectionWrapper = setupResult.ConnectionWrapper; + var connection = connectionWrapper.Connection; + var serverName = setupResult.Server; + + if (serverName == null) { + _log.LogWarning("PagedQuery - Failed to get a server name for connection, retry not possible"); + } + var pageControl = new PageResultRequestControl(500); searchRequest.Controls.Add(pageControl); @@ -188,17 +319,16 @@ public async IAsyncEnumerable> PagedQuery(LdapQue Value = entry }; } - + if (pageResponse.Cookie.Length == 0 || response.Entries.Count == 0 || cancellationToken.IsCancellationRequested) yield break; - + pageControl.Cookie = pageResponse.Cookie; } } - - private static TimeSpan GetNextBackoff(int retryCount) - { + + private static TimeSpan GetNextBackoff(int retryCount) { return TimeSpan.FromSeconds(Math.Min( MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount), MaxBackoffDelay.TotalSeconds)); @@ -303,6 +433,44 @@ private bool } } + private async Task SetupLdapQuery(LdapQueryParameters queryParameters, + bool forceNewConnection = false) { + var result = new LdapQuerySetupResult(); + var (success, connectionWrapper, message) = await GetLdapConnection(queryParameters, forceNewConnection); + if (!success) { + result.Success = false; + result.Message = $"Unable to create a connection: {message}"; + return result; + } + + //This should never happen as far as I know, so just checking for safety + if (connectionWrapper.Connection == null) { + result.Success = false; + result.Message = $"Connection object is null"; + return result; + } + + if (!CreateSearchRequest(queryParameters, ref connectionWrapper, out var searchRequest)) { + result.Success = false; + result.Message = "Failed to create search request"; + return result; + } + + if (GetServerFromConnection(connectionWrapper.Connection, out var server)) { + result.Server = server; + } + + result.Success = true; + result.SearchRequest = searchRequest; + result.ConnectionWrapper = connectionWrapper; + return result; + } + + private Task<(bool Success, LdapConnectionWrapperNew Connection, string Message )> GetLdapConnection( + LdapQueryParameters queryParameters, bool forceCreateNewConnection = false) { + return GetLdapConnection(queryParameters.DomainName, queryParameters.GlobalCatalog, forceCreateNewConnection); + } + private async Task<(bool Success, LdapConnectionWrapperNew Connection, string Message )> GetLdapConnection( string domainName, bool globalCatalog = false, bool forceCreateNewConnection = false) { @@ -442,7 +610,8 @@ private bool } } - private async Task<(bool success, LdapConnectionWrapperNew connection)> CreateLDAPConnectionWithPortCheck(string target, bool globalCatalog) { + private async Task<(bool success, LdapConnectionWrapperNew connection)> CreateLDAPConnectionWithPortCheck( + string target, bool globalCatalog) { if (globalCatalog) { if (await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(true)) || (!_ldapConfig.ForceSSL && await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(false)))) @@ -457,20 +626,18 @@ await _portScanner.CheckPort(target, _ldapConfig.GetPort(false)))) return (false, null); } - private LdapConnectionWrapperNew CheckCacheConnection(LdapConnectionWrapperNew connectionWrapper, string domainName, bool globalCatalog, bool forceCreateNewConnection) - { + private LdapConnectionWrapperNew CheckCacheConnection(LdapConnectionWrapperNew connectionWrapper, string domainName, + bool globalCatalog, bool forceCreateNewConnection) { string cacheIdentifier; - if (_ldapConfig.Server != null) - { + if (_ldapConfig.Server != null) { cacheIdentifier = _ldapConfig.Server; } - else - { - if (!GetDomainSidFromDomainName(domainName, out cacheIdentifier)) - { + else { + if (!GetDomainSidFromDomainName(domainName, out cacheIdentifier)) { //This is kinda gross, but its another way to get the correct domain sid - if (!connectionWrapper.Connection.GetNamingContextSearchBase(NamingContext.Default, out var searchBase) || !GetDomainSidFromConnection(connectionWrapper.Connection, searchBase, out cacheIdentifier)) - { + if (!connectionWrapper.Connection.GetNamingContextSearchBase(NamingContext.Default, + out var searchBase) || !GetDomainSidFromConnection(connectionWrapper.Connection, searchBase, + out cacheIdentifier)) { /* * If we get here, we couldn't resolve a domain sid, which is hella bad, but we also want to keep from creating a shitton of new connections * Cache using the domainname and pray it all works out @@ -479,27 +646,22 @@ private LdapConnectionWrapperNew CheckCacheConnection(LdapConnectionWrapperNew c } } } - - if (forceCreateNewConnection) - { + + if (forceCreateNewConnection) { return _ldapConnectionCache.AddOrUpdate(cacheIdentifier, globalCatalog, connectionWrapper); } return _ldapConnectionCache.TryAdd(cacheIdentifier, globalCatalog, connectionWrapper); } - - private bool GetCachedConnection(string domain, bool globalCatalog, out LdapConnectionWrapperNew connection) - { + + private bool GetCachedConnection(string domain, bool globalCatalog, out LdapConnectionWrapperNew connection) { //If server is set via our config, we'll always just use this as the cache key - if (_ldapConfig.Server != null) - { + if (_ldapConfig.Server != null) { return _ldapConnectionCache.TryGet(_ldapConfig.Server, globalCatalog, out connection); } - - if (GetDomainSidFromDomainName(domain, out var domainSid)) - { - if (_ldapConnectionCache.TryGet(domainSid, globalCatalog, out connection)) - { + + if (GetDomainSidFromDomainName(domain, out var domainSid)) { + if (_ldapConnectionCache.TryGet(domainSid, globalCatalog, out connection)) { return true; } } @@ -599,10 +761,8 @@ private LdapConnection CreateBaseConnection(string directoryIdentifier, bool ssl connection.SessionOptions.ReferralChasing = ReferralChasingOptions.None; if (ssl) connection.SessionOptions.SecureSocketLayer = true; - if (_ldapConfig.DisableSigning) { - connection.SessionOptions.Sealing = false; - connection.SessionOptions.Signing = false; - } + connection.SessionOptions.Sealing = !_ldapConfig.DisableSigning; + connection.SessionOptions.Signing = !_ldapConfig.DisableSigning; if (_ldapConfig.DisableCertVerification) connection.SessionOptions.VerifyServerCertificate = (_, _) => true; @@ -683,7 +843,7 @@ private bool TestLdapConnection(LdapConnection connection, string identifier, ou return true; } - private SearchRequest CreateSearchRequest(string distinguishedName, string ldapFilter, SearchScope searchScope, + public static SearchRequest CreateSearchRequest(string distinguishedName, string ldapFilter, SearchScope searchScope, string[] attributes) { var searchRequest = new SearchRequest(distinguishedName, ldapFilter, searchScope, attributes); diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs new file mode 100644 index 00000000..c85bdd17 --- /dev/null +++ b/src/CommonLib/LdapConnectionPool.cs @@ -0,0 +1,376 @@ +using System; +using System.Collections.Concurrent; +using System.DirectoryServices.ActiveDirectory; +using System.DirectoryServices.Protocols; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using SharpHoundCommonLib.Enums; +using SharpHoundCommonLib.Exceptions; +using SharpHoundCommonLib.LDAPQueries; +using SharpHoundCommonLib.Processors; +using SharpHoundRPC.NetAPINative; + +namespace SharpHoundCommonLib; + +public class LdapConnectionPool : IDisposable{ + private readonly ConcurrentBag _connections; + private readonly ConcurrentBag _globalCatalogConnection; + private static readonly ConcurrentDictionary DomainCache = new(); + private readonly SemaphoreSlim _semaphore; + private readonly string _identifier; + private readonly LDAPConfig _ldapConfig; + private readonly ILogger _log; + private readonly PortScanner _portScanner; + private readonly NativeMethods _nativeMethods; + + public LdapConnectionPool(string identifier, LDAPConfig config, int maxConnections = 10, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { + _connections = new ConcurrentBag(); + _globalCatalogConnection = new ConcurrentBag(); + _semaphore = new SemaphoreSlim(maxConnections, maxConnections); + _identifier = identifier; + _ldapConfig = config; + _log = log ?? Logging.LogProvider.CreateLogger("LdapConnectionPool"); + _portScanner = scanner ?? new PortScanner(); + _nativeMethods = nativeMethods ?? new NativeMethods(); + } + + public async Task<(bool Success, LdapConnectionWrapperNew connectionWrapper, string Message)> GetConnectionAsync() { + await _semaphore.WaitAsync(); + if (!_connections.TryTake(out var connectionWrapper)) { + var (success, connection, message) = await CreateNewConnection(); + if (!success) { + return (false, null, message); + } + + connectionWrapper = connection; + } + + return (true, connectionWrapper, null); + } + + public async Task<(bool Success, LdapConnectionWrapperNew connectionWrapper, string Message)> + GetConnectionForSpecificServerAsync(string server, bool globalCatalog) { + await _semaphore.WaitAsync(); + + return CreateNewConnectionForServer(server, globalCatalog); + } + + public async Task<(bool Success, LdapConnectionWrapperNew connectionWrapper, string Message)> GetGlobalCatalogConnectionAsync() { + await _semaphore.WaitAsync(); + if (!_globalCatalogConnection.TryTake(out var connectionWrapper)) { + var (success, connection, message) = await CreateNewConnection(true); + if (!success) { + return (false, null, message); + } + + connectionWrapper = connection; + } + + return (true, connectionWrapper, null); + } + + public void ReleaseConnection(LdapConnectionWrapperNew connectionWrapper, bool returnToPool = true) { + if (returnToPool) { + if (connectionWrapper.GlobalCatalog) { + _globalCatalogConnection.Add(connectionWrapper); + } + else { + _connections.Add(connectionWrapper); + } + } + else { + connectionWrapper.Connection.Dispose(); + } + + _semaphore.Release(); + } + + public void Dispose() { + while (_connections.TryTake(out var wrapper)) { + wrapper.Connection.Dispose(); + } + } + + private async Task<(bool Success, LdapConnectionWrapperNew Connection, string Message)> CreateNewConnection(bool globalCatalog = false) { + if (!string.IsNullOrWhiteSpace(_ldapConfig.Server)) { + return CreateNewConnectionForServer(_ldapConfig.Server, globalCatalog); + } + + if (CreateLdapConnection(_identifier.ToUpper().Trim(), globalCatalog, out var connectionWrapper)) { + _log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 1", _identifier); + return (true, connectionWrapper, ""); + } + + string tempDomainName; + + var dsGetDcNameResult = _nativeMethods.CallDsGetDcName(null, _identifier, + (uint)(NetAPIEnums.DSGETDCNAME_FLAGS.DS_FORCE_REDISCOVERY | + NetAPIEnums.DSGETDCNAME_FLAGS.DS_RETURN_DNS_NAME | + NetAPIEnums.DSGETDCNAME_FLAGS.DS_DIRECTORY_SERVICE_REQUIRED)); + if (dsGetDcNameResult.IsSuccess) { + tempDomainName = dsGetDcNameResult.Value.DomainName; + + if (!tempDomainName.Equals(_identifier, StringComparison.OrdinalIgnoreCase) && + CreateLdapConnection(tempDomainName, globalCatalog, out connectionWrapper)) { + _log.LogDebug( + "Successfully created ldap connection for domain: {Domain} using strategy 2 with name {NewName}", + _identifier, tempDomainName); + return (true, connectionWrapper, ""); + } + + var server = dsGetDcNameResult.Value.DomainControllerName.TrimStart('\\'); + + var result = + await CreateLDAPConnectionWithPortCheck(server, globalCatalog); + if (result.success) { + _log.LogDebug( + "Successfully created ldap connection for domain: {Domain} using strategy 3 to server {Server}", + _identifier, server); + return (true, result.connection, ""); + } + } + + if (!GetDomain(_identifier, out var domainObject) || domainObject.Name == null) { + //If we don't get a result here, we effectively have no other ways to resolve this domain, so we'll just have to exit out + _log.LogDebug( + "Could not get domain object from GetDomain, unable to create ldap connection for domain {Domain}", + _identifier); + return (false, null, "Unable to get domain object for further strategies"); + } + tempDomainName = domainObject.Name.ToUpper().Trim(); + + if (!tempDomainName.Equals(_identifier, StringComparison.OrdinalIgnoreCase) && + CreateLdapConnection(tempDomainName, globalCatalog, out connectionWrapper)) { + _log.LogDebug( + "Successfully created ldap connection for domain: {Domain} using strategy 4 with name {NewName}", + _identifier, tempDomainName); + return (true, connectionWrapper, ""); + } + + var primaryDomainController = domainObject.PdcRoleOwner.Name; + var portConnectionResult = + await CreateLDAPConnectionWithPortCheck(primaryDomainController, globalCatalog); + if (portConnectionResult.success) { + _log.LogDebug( + "Successfully created ldap connection for domain: {Domain} using strategy 5 with to pdc {Server}", + _identifier, primaryDomainController); + return (true, connectionWrapper, ""); + } + + foreach (DomainController dc in domainObject.DomainControllers) { + portConnectionResult = + await CreateLDAPConnectionWithPortCheck(primaryDomainController, globalCatalog); + if (portConnectionResult.success) { + _log.LogDebug( + "Successfully created ldap connection for domain: {Domain} using strategy 6 with to pdc {Server}", + _identifier, primaryDomainController); + return (true, connectionWrapper, ""); + } + } + + return (false, null, "All attempted connections failed"); + } + + private (bool Success, LdapConnectionWrapperNew Connection, string Message ) CreateNewConnectionForServer(string identifier, bool globalCatalog = false) { + if (CreateLdapConnection(identifier, globalCatalog, out var serverConnection)) { + return (true, serverConnection, ""); + } + + return (false, null, $"Failed to create ldap connection for {identifier}"); + } + + private bool CreateLdapConnection(string target, bool globalCatalog, + out LdapConnectionWrapperNew connection) { + var baseConnection = CreateBaseConnection(target, true, globalCatalog); + if (TestLdapConnection(baseConnection, out var result)) { + connection = new LdapConnectionWrapperNew(baseConnection, result.SearchResultEntry, globalCatalog, _identifier); + return true; + } + + try { + baseConnection.Dispose(); + } + catch { + //this is just in case + } + + if (_ldapConfig.ForceSSL) { + connection = null; + return false; + } + + baseConnection = CreateBaseConnection(target, false, globalCatalog); + if (TestLdapConnection(baseConnection, out result)) { + connection = new LdapConnectionWrapperNew(baseConnection, result.SearchResultEntry, globalCatalog, _identifier); + return true; + } + + try { + baseConnection.Dispose(); + } + catch { + //this is just in case + } + + connection = null; + return false; + } + + private LdapConnection CreateBaseConnection(string directoryIdentifier, bool ssl, + bool globalCatalog) { + var port = globalCatalog ? _ldapConfig.GetGCPort(ssl) : _ldapConfig.GetPort(ssl); + var identifier = new LdapDirectoryIdentifier(directoryIdentifier, port, false, false); + var connection = new LdapConnection(identifier) { Timeout = new TimeSpan(0, 0, 5, 0) }; + + //These options are important! + connection.SessionOptions.ProtocolVersion = 3; + //Referral chasing does not work with paged searches + connection.SessionOptions.ReferralChasing = ReferralChasingOptions.None; + if (ssl) connection.SessionOptions.SecureSocketLayer = true; + + connection.SessionOptions.Sealing = !_ldapConfig.DisableSigning; + connection.SessionOptions.Signing = !_ldapConfig.DisableSigning; + + if (_ldapConfig.DisableCertVerification) + connection.SessionOptions.VerifyServerCertificate = (_, _) => true; + + if (_ldapConfig.Username != null) { + var cred = new NetworkCredential(_ldapConfig.Username, _ldapConfig.Password); + connection.Credential = cred; + } + + connection.AuthType = _ldapConfig.AuthType; + + return connection; + } + + /// + /// Tests whether an LDAP connection is working + /// + /// The ldap connection object to test + /// The results fo the connection test + /// True if connection was successful, false otherwise + /// Something is wrong with the supplied credentials + /// + /// A connection "succeeded" but no data was returned. This can be related to + /// kerberos auth across trusts or just simply lack of permissions + /// + private bool TestLdapConnection(LdapConnection connection, out LdapConnectionTestResult testResult) { + testResult = new LdapConnectionTestResult(); + try { + //Attempt an initial bind. If this fails, likely auth is invalid, or its not a valid target + connection.Bind(); + } + catch (LdapException e) { + //TODO: Maybe look at this and find a better way? + if (e.ErrorCode is (int)LdapErrorCodes.InvalidCredentials or (int)ResultCode.InappropriateAuthentication) { + connection.Dispose(); + throw new LdapAuthenticationException(e); + } + + testResult.Message = e.Message; + testResult.ErrorCode = e.ErrorCode; + return false; + } + catch (Exception e) { + testResult.Message = e.Message; + return false; + } + + SearchResponse response; + try { + //Do an initial search request to get the rootDSE + //This ldap filter is equivalent to (objectclass=*) + var searchRequest = LDAPUtilsNew.CreateSearchRequest("", new LDAPFilter().AddAllObjects().GetFilter(), + SearchScope.Base, null); + + response = (SearchResponse)connection.SendRequest(searchRequest); + } + catch (LdapException e) { + /* + * If we can't send the initial search request, its unlikely any other search requests will work so we will immediately return false + */ + testResult.Message = e.Message; + testResult.ErrorCode = e.ErrorCode; + return false; + } + + if (response?.Entries == null || response.Entries.Count == 0) { + /* + * This can happen for one of two reasons, either we dont have permission to query AD or we're authenticating + * across external trusts with kerberos authentication without Forest Search Order properly configured. + * Either way, this connection isn't useful for us because we're not going to get data, so return false + */ + + connection.Dispose(); + throw new NoLdapDataException(); + } + + testResult.SearchResultEntry = new SearchResultEntryWrapper(response.Entries[0]); + testResult.Message = ""; + return true; + } + + private class LdapConnectionTestResult { + public string Message { get; set; } + public ISearchResultEntry SearchResultEntry { get; set; } + public int ErrorCode { get; set; } + } + + private async Task<(bool success, LdapConnectionWrapperNew connection)> CreateLDAPConnectionWithPortCheck( + string target, bool globalCatalog) { + if (globalCatalog) { + if (await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(true)) || (!_ldapConfig.ForceSSL && + await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(false)))) + return (CreateLdapConnection(target, true, out var connection), connection); + } + else { + if (await _portScanner.CheckPort(target, _ldapConfig.GetPort(true)) || (!_ldapConfig.ForceSSL && + await _portScanner.CheckPort(target, _ldapConfig.GetPort(false)))) + return (CreateLdapConnection(target, true, out var connection), connection); + } + + return (false, null); + } + + /// + /// Attempts to get the Domain object representing the target domain. If null is specified for the domain name, gets + /// the user's current domain + /// + /// + /// + /// + private bool GetDomain(string domainName, out Domain domain) { + var cacheKey = domainName; + if (DomainCache.TryGetValue(cacheKey, out domain)) return true; + + try { + DirectoryContext context; + if (_ldapConfig.Username != null) + context = domainName != null + ? new DirectoryContext(DirectoryContextType.Domain, domainName, _ldapConfig.Username, + _ldapConfig.Password) + : new DirectoryContext(DirectoryContextType.Domain, _ldapConfig.Username, + _ldapConfig.Password); + else + context = domainName != null + ? new DirectoryContext(DirectoryContextType.Domain, domainName) + : new DirectoryContext(DirectoryContextType.Domain); + + domain = Domain.GetDomain(context); + if (domain == null) return false; + DomainCache.TryAdd(cacheKey, domain); + return true; + } + catch (Exception e) { + _log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName); + return false; + } + } +} + +//TESTLAB +//TESTLAB.LOCAL +//PRIMARY.TESTLAB.LOCAL \ No newline at end of file diff --git a/src/CommonLib/LdapConnectionWrapperNew.cs b/src/CommonLib/LdapConnectionWrapperNew.cs index 8a87ebdd..1c3392da 100644 --- a/src/CommonLib/LdapConnectionWrapperNew.cs +++ b/src/CommonLib/LdapConnectionWrapperNew.cs @@ -12,12 +12,18 @@ public class LdapConnectionWrapperNew private string _configurationSearchBase; private string _schemaSearchBase; private string _server; + public string Guid { get; set; } private const string Unknown = "UNKNOWN"; + public bool GlobalCatalog; + public string PoolIdentifier; - public LdapConnectionWrapperNew(LdapConnection connection, ISearchResultEntry entry) + public LdapConnectionWrapperNew(LdapConnection connection, ISearchResultEntry entry, bool globalCatalog, string poolIdentifier) { Connection = connection; _searchResultEntry = entry; + Guid = new Guid().ToString(); + GlobalCatalog = globalCatalog; + PoolIdentifier = poolIdentifier; } public void CopyContexts(LdapConnectionWrapperNew other) { @@ -89,4 +95,19 @@ public void SaveContext(NamingContext context, string searchBase) throw new ArgumentOutOfRangeException(nameof(context), context, null); } } + + protected bool Equals(LdapConnectionWrapperNew other) { + return Guid == other.Guid; + } + + public override bool Equals(object obj) { + if (ReferenceEquals(null, obj)) return false; + if (ReferenceEquals(this, obj)) return true; + if (obj.GetType() != this.GetType()) return false; + return Equals((LdapConnectionWrapperNew)obj); + } + + public override int GetHashCode() { + return (Guid != null ? Guid.GetHashCode() : 0); + } } \ No newline at end of file diff --git a/src/CommonLib/LdapQuerySetupResult.cs b/src/CommonLib/LdapQuerySetupResult.cs new file mode 100644 index 00000000..4c8ed2cc --- /dev/null +++ b/src/CommonLib/LdapQuerySetupResult.cs @@ -0,0 +1,12 @@ +using System.DirectoryServices; +using System.DirectoryServices.Protocols; + +namespace SharpHoundCommonLib; + +public class LdapQuerySetupResult { + public LdapConnectionWrapperNew ConnectionWrapper { get; set; } + public SearchRequest SearchRequest { get; set; } + public string Server { get; set; } + public bool Success { get; set; } + public string Message { get; set; } +} \ No newline at end of file