From 9c767f92ef33e0e34f5852886f7776b3bbe041e0 Mon Sep 17 00:00:00 2001 From: rvazarkar Date: Wed, 3 Jul 2024 16:55:52 -0400 Subject: [PATCH] wip: more fixes identified during during testing --- src/CommonLib/ConnectionPoolManager.cs | 13 ++++---- src/CommonLib/Extensions.cs | 22 ++++++------- src/CommonLib/LDAPConfig.cs | 6 +++- src/CommonLib/LdapConnectionPool.cs | 31 ++++++++++++------- src/CommonLib/LdapConnectionWrapper.cs | 4 +-- src/CommonLib/LdapResult.cs | 4 +++ src/CommonLib/LdapUtils.cs | 43 +++++++++----------------- src/CommonLib/Result.cs | 2 +- 8 files changed, 63 insertions(+), 62 deletions(-) diff --git a/src/CommonLib/ConnectionPoolManager.cs b/src/CommonLib/ConnectionPoolManager.cs index afc7d299..95fc9615 100644 --- a/src/CommonLib/ConnectionPoolManager.cs +++ b/src/CommonLib/ConnectionPoolManager.cs @@ -24,6 +24,7 @@ public ConnectionPoolManager(LDAPConfig config, ILogger log = null, PortScanner public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) { //I dont think this is possible, but at least account for it if (!_pools.TryGetValue(connectionWrapper.PoolIdentifier, out var pool)) { + _log.LogWarning("Could not find pool for {Identifier}", connectionWrapper.PoolIdentifier); connectionWrapper.Connection.Dispose(); return; } @@ -41,9 +42,9 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn string identifier, bool globalCatalog) { var resolved = ResolveIdentifier(identifier); - if (!_pools.TryGetValue(identifier, out var pool)) { - pool = new LdapConnectionPool(resolved, _ldapConfig,scanner: _portScanner); - _pools.TryAdd(identifier, pool); + if (!_pools.TryGetValue(resolved, out var pool)) { + pool = new LdapConnectionPool(identifier, resolved, _ldapConfig,scanner: _portScanner); + _pools.TryAdd(resolved, pool); } if (globalCatalog) { @@ -56,9 +57,9 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn string identifier, string server, bool globalCatalog) { var resolved = ResolveIdentifier(identifier); - if (!_pools.TryGetValue(identifier, out var pool)) { - pool = new LdapConnectionPool(resolved, _ldapConfig,scanner: _portScanner); - _pools.TryAdd(identifier, pool); + if (!_pools.TryGetValue(resolved, out var pool)) { + pool = new LdapConnectionPool(resolved, identifier, _ldapConfig,scanner: _portScanner); + _pools.TryAdd(resolved, pool); } return await pool.GetConnectionForSpecificServerAsync(server, globalCatalog); diff --git a/src/CommonLib/Extensions.cs b/src/CommonLib/Extensions.cs index ea4d58dc..869ef104 100644 --- a/src/CommonLib/Extensions.cs +++ b/src/CommonLib/Extensions.cs @@ -446,6 +446,7 @@ public static bool GetLabel(this DirectoryEntry entry, out Label type) { } private static bool ResolveLabel(string objectIdentifier, string distinguishedName, string samAccountType, string[] objectClasses, int flags, out Label type) { + type = Label.Base; if (objectIdentifier != null && WellKnownPrincipal.GetWellKnownPrincipal(objectIdentifier, out var principal)) { type = principal.ObjectType; return true; @@ -474,28 +475,26 @@ private static bool ResolveLabel(string objectIdentifier, string distinguishedNa if (objectClasses.Contains(GroupPolicyContainerClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.GPO; - if (objectClasses.Contains(OrganizationalUnitClass, StringComparer.InvariantCultureIgnoreCase)) + else if (objectClasses.Contains(OrganizationalUnitClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.OU; - if (objectClasses.Contains(DomainClass, StringComparer.InvariantCultureIgnoreCase)) + else if (objectClasses.Contains(DomainClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.Domain; - if (objectClasses.Contains(ContainerClass, StringComparer.InvariantCultureIgnoreCase)) + else if (objectClasses.Contains(ContainerClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.Container; - if (objectClasses.Contains(ConfigurationClass, StringComparer.InvariantCultureIgnoreCase)) + else if (objectClasses.Contains(ConfigurationClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.Configuration; - if (objectClasses.Contains(PKICertificateTemplateClass, StringComparer.InvariantCultureIgnoreCase)) + else if (objectClasses.Contains(PKICertificateTemplateClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.CertTemplate; - if (objectClasses.Contains(PKIEnrollmentServiceClass, StringComparer.InvariantCultureIgnoreCase)) + else if (objectClasses.Contains(PKIEnrollmentServiceClass, StringComparer.InvariantCultureIgnoreCase)) type = Label.EnterpriseCA; - if (objectClasses.Contains(CertificationAuthorityClass, StringComparer.InvariantCultureIgnoreCase)) { + else if (objectClasses.Contains(CertificationAuthorityClass, StringComparer.InvariantCultureIgnoreCase)) { if (distinguishedName.Contains(DirectoryPaths.RootCALocation)) type = Label.RootCA; if (distinguishedName.Contains(DirectoryPaths.AIACALocation)) type = Label.AIACA; if (distinguishedName.Contains(DirectoryPaths.NTAuthStoreLocation)) type = Label.NTAuthStore; - } - - if (objectClasses.Contains(OIDContainerClass, StringComparer.InvariantCultureIgnoreCase)) { + }else if (objectClasses.Contains(OIDContainerClass, StringComparer.InvariantCultureIgnoreCase)) { if (distinguishedName.StartsWith(DirectoryPaths.OIDContainerLocation, StringComparison.InvariantCultureIgnoreCase)) type = Label.Container; @@ -505,8 +504,7 @@ private static bool ResolveLabel(string objectIdentifier, string distinguishedNa } } - type = Label.Base; - return false; + return type != Label.Base; } /// diff --git a/src/CommonLib/LDAPConfig.cs b/src/CommonLib/LDAPConfig.cs index e8bb26e4..c4cdd9ff 100644 --- a/src/CommonLib/LDAPConfig.cs +++ b/src/CommonLib/LDAPConfig.cs @@ -8,6 +8,7 @@ public class LDAPConfig public string Password { get; set; } = null; public string Server { get; set; } = null; public int Port { get; set; } = 0; + public int SSLPort { get; set; } = 0; public bool ForceSSL { get; set; } = false; public bool DisableSigning { get; set; } = false; public bool DisableCertVerification { get; set; } = false; @@ -16,7 +17,10 @@ public class LDAPConfig //Returns the port for connecting to LDAP. Will always respect a user's overridden config over anything else public int GetPort(bool ssl) { - if (Port != 0) + if (ssl && SSLPort != 0) { + return SSLPort; + } + if (!ssl && Port != 0) { return Port; } diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index 1cd52345..a555cf57 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -16,19 +16,20 @@ namespace SharpHoundCommonLib { public class LdapConnectionPool : IDisposable{ private readonly ConcurrentBag _connections; private readonly ConcurrentBag _globalCatalogConnection; - private static readonly ConcurrentDictionary DomainCache = new(); private readonly SemaphoreSlim _semaphore; private readonly string _identifier; + private readonly string _poolIdentifier; private readonly LDAPConfig _ldapConfig; private readonly ILogger _log; private readonly PortScanner _portScanner; private readonly NativeMethods _nativeMethods; - public LdapConnectionPool(string identifier, LDAPConfig config, int maxConnections = 10, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { + public LdapConnectionPool(string identifier, string poolIdentifier, LDAPConfig config, int maxConnections = 10, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { _connections = new ConcurrentBag(); _globalCatalogConnection = new ConcurrentBag(); _semaphore = new SemaphoreSlim(maxConnections, maxConnections); _identifier = identifier; + _poolIdentifier = poolIdentifier; _ldapConfig = config; _log = log ?? Logging.LogProvider.CreateLogger("LdapConnectionPool"); _portScanner = scanner ?? new PortScanner(); @@ -81,6 +82,7 @@ public LdapConnectionPool(string identifier, LDAPConfig config, int maxConnectio } public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) { + _semaphore.Release(); if (!connectionFaulted) { if (connectionWrapper.GlobalCatalog) { _globalCatalogConnection.Add(connectionWrapper); @@ -92,8 +94,6 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn else { connectionWrapper.Connection.Dispose(); } - - _semaphore.Release(); } public void Dispose() { @@ -108,7 +108,7 @@ public void Dispose() { } if (CreateLdapConnection(_identifier.ToUpper().Trim(), globalCatalog, out var connectionWrapper)) { - _log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 1", _identifier); + _log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 1. SSL: {SSl}", _identifier, connectionWrapper.Connection.SessionOptions.SecureSocketLayer); return (true, connectionWrapper, ""); } @@ -194,7 +194,7 @@ private bool CreateLdapConnection(string target, bool globalCatalog, out LdapConnectionWrapper connection) { var baseConnection = CreateBaseConnection(target, true, globalCatalog); if (TestLdapConnection(baseConnection, out var result)) { - connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _identifier); + connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _poolIdentifier); return true; } @@ -212,7 +212,7 @@ private bool CreateLdapConnection(string target, bool globalCatalog, baseConnection = CreateBaseConnection(target, false, globalCatalog); if (TestLdapConnection(baseConnection, out result)) { - connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _identifier); + connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _poolIdentifier); return true; } @@ -229,19 +229,26 @@ private bool CreateLdapConnection(string target, bool globalCatalog, private LdapConnection CreateBaseConnection(string directoryIdentifier, bool ssl, bool globalCatalog) { + _log.LogDebug("Creating connection for identifier {Identifier}", directoryIdentifier); var port = globalCatalog ? _ldapConfig.GetGCPort(ssl) : _ldapConfig.GetPort(ssl); var identifier = new LdapDirectoryIdentifier(directoryIdentifier, port, false, false); var connection = new LdapConnection(identifier) { Timeout = new TimeSpan(0, 0, 5, 0) }; - + //These options are important! connection.SessionOptions.ProtocolVersion = 3; //Referral chasing does not work with paged searches connection.SessionOptions.ReferralChasing = ReferralChasingOptions.None; if (ssl) connection.SessionOptions.SecureSocketLayer = true; - - connection.SessionOptions.Sealing = !_ldapConfig.DisableSigning; - connection.SessionOptions.Signing = !_ldapConfig.DisableSigning; - + + if (_ldapConfig.DisableSigning || ssl) { + connection.SessionOptions.Signing = false; + connection.SessionOptions.Sealing = false; + } + else { + connection.SessionOptions.Signing = true; + connection.SessionOptions.Sealing = true; + } + if (_ldapConfig.DisableCertVerification) connection.SessionOptions.VerifyServerCertificate = (_, _) => true; diff --git a/src/CommonLib/LdapConnectionWrapper.cs b/src/CommonLib/LdapConnectionWrapper.cs index 12d521b6..0482603d 100644 --- a/src/CommonLib/LdapConnectionWrapper.cs +++ b/src/CommonLib/LdapConnectionWrapper.cs @@ -11,8 +11,8 @@ public class LdapConnectionWrapper { private string _schemaSearchBase; private string _server; private string Guid { get; set; } - public bool GlobalCatalog; - public string PoolIdentifier; + public readonly bool GlobalCatalog; + public readonly string PoolIdentifier; public LdapConnectionWrapper(LdapConnection connection, ISearchResultEntry entry, bool globalCatalog, string poolIdentifier) { diff --git a/src/CommonLib/LdapResult.cs b/src/CommonLib/LdapResult.cs index 7a951597..f565bbcf 100644 --- a/src/CommonLib/LdapResult.cs +++ b/src/CommonLib/LdapResult.cs @@ -14,6 +14,10 @@ protected LdapResult(T value, bool success, string error, string queryInfo, int public new static LdapResult Ok(T value) { return new LdapResult(value, true, string.Empty, null, 0); } + + public new static LdapResult Fail() { + return new LdapResult(default, false, string.Empty, null, 0); + } public static LdapResult Fail(string message, LdapQueryParameters queryInfo) { return new LdapResult(default, false, message, queryInfo.GetQueryInfo(), 0); diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index 23d0e296..0f235c1c 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -77,20 +77,14 @@ public LdapUtils() { _nativeMethods = new NativeMethods(); _portScanner = new PortScanner(); _log = Logging.LogProvider.CreateLogger("LDAPUtils"); - _connectionPool = new ConnectionPoolManager(_ldapConfig); + _connectionPool = new ConnectionPoolManager(_ldapConfig, _log); } public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, ILogger log = null) { _nativeMethods = nativeMethods ?? new NativeMethods(); _portScanner = scanner ?? new PortScanner(); _log = log ?? Logging.LogProvider.CreateLogger("LDAPUtils"); - _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); - } - - public void SetLDAPConfig(LDAPConfig config) { - _ldapConfig = config; - _connectionPool.Dispose(); - _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); + _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner:_portScanner); } public async IAsyncEnumerable> RangedRetrieval(string distinguishedName, @@ -316,7 +310,7 @@ await _connectionPool.GetLdapConnection(queryParameters.DomainName, } catch (Exception e) { tempResult = - LdapResult.Fail($"PagedQuery - Caught unrecoverable exception: {e.Message}", + LdapResult.Fail($"Query - Caught unrecoverable exception: {e.Message}", queryParameters); } @@ -337,12 +331,11 @@ await _connectionPool.GetLdapConnection(queryParameters.DomainName, break; } } - + + _connectionPool.ReleaseConnection(connectionWrapper); foreach (SearchResultEntry entry in response.Entries) { yield return LdapResult.Ok(new SearchResultEntryWrapper(entry, this)); } - - _connectionPool.ReleaseConnection(connectionWrapper); } public async IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, @@ -469,13 +462,13 @@ public async IAsyncEnumerable> PagedQuery(LdapQue continue; } - foreach (ISearchResultEntry entry in response.Entries) { + foreach (SearchResultEntry entry in response.Entries) { if (cancellationToken.IsCancellationRequested) { _connectionPool.ReleaseConnection(connectionWrapper); yield break; } - yield return LdapResult.Ok(entry); + yield return LdapResult.Ok(new SearchResultEntryWrapper(entry, this)); } if (pageResponse.Cookie.Length == 0 || response.Entries.Count == 0 || @@ -527,7 +520,7 @@ public async IAsyncEnumerable> PagedQuery(LdapQue DomainName = tempDomain, LDAPFilter = CommonFilters.SpecificSID(sid), Attributes = CommonProperties.TypeResolutionProps - }).FirstAsync(); + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { type = result.Value.GetLabel(); @@ -574,7 +567,7 @@ public async IAsyncEnumerable> PagedQuery(LdapQue DomainName = domain, LDAPFilter = CommonFilters.SpecificGUID(guid), Attributes = CommonProperties.TypeResolutionProps - }).FirstAsync(); + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { type = result.Value.GetLabel(); @@ -868,7 +861,7 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF Attributes = new[] { LDAPProperties.DistinguishedName }, GlobalCatalog = true, LDAPFilter = new LDAPFilter().AddDomains(CommonFilters.SpecificSID(domainSid)).GetFilter() - }).FirstAsync(); + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { return (true, Helpers.DistinguishedNameToDomain(result.Value.DistinguishedName)); @@ -880,7 +873,7 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF GlobalCatalog = true, LDAPFilter = new LDAPFilter().AddFilter("(objectclass=trusteddomain)", true) .AddFilter($"(securityidentifier={Helpers.ConvertSidToHexSid(domainSid)})", true).GetFilter() - }).FirstAsync(); + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { return (true, Helpers.DistinguishedNameToDomain(result.Value.DistinguishedName)); @@ -891,7 +884,7 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF Attributes = new[] { LDAPProperties.DistinguishedName }, LDAPFilter = new LDAPFilter().AddFilter("(objectclass=domaindns)", true) .AddFilter(CommonFilters.SpecificSID(domainSid), true).GetFilter() - }).FirstAsync(); + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { return (true, Helpers.DistinguishedNameToDomain(result.Value.DistinguishedName)); @@ -946,7 +939,7 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF DomainName = domainName, Attributes = new[] { LDAPProperties.ObjectSID }, LDAPFilter = new LDAPFilter().AddFilter(CommonFilters.DomainControllers, true).GetFilter() - }).FirstAsync(); + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { var sid = result.Value.GetSid(); @@ -1064,7 +1057,7 @@ public bool GetDomain(out Domain domain) { DomainName = domain, Attributes = CommonProperties.TypeResolutionProps, LDAPFilter = $"(samaccountname={name})" - }).FirstAsync(); + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { type = result.Value.GetLabel(); @@ -1227,13 +1220,7 @@ public bool GetDomain(out Domain domain) { NamingContext = NamingContext.Configuration, RelativeSearchBase = DirectoryPaths.CertTemplateLocation, LDAPFilter = filter.GetFilter(), - }).DefaultIfEmpty(null).FirstAsync(); - - if (result == null) { - _log.LogWarning("Could not find certificate template with {PropertyName}:{PropertyValue}", - propertyName, propertyName); - return (false, null); - } + }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (!result.IsSuccess) { _log.LogWarning( diff --git a/src/CommonLib/Result.cs b/src/CommonLib/Result.cs index 390bb5ac..c73cfec3 100644 --- a/src/CommonLib/Result.cs +++ b/src/CommonLib/Result.cs @@ -22,7 +22,7 @@ public static Result Ok(T value) { public class Result { public string Error { get; set; } - public bool IsSuccess => Error == null && Success; + public bool IsSuccess => string.IsNullOrWhiteSpace(Error) && Success; private bool Success { get; set; } protected Result(bool success, string error) {