diff --git a/src/CommonLib/DCConnectionCache.cs b/src/CommonLib/DCConnectionCache.cs new file mode 100644 index 00000000..fe1c4967 --- /dev/null +++ b/src/CommonLib/DCConnectionCache.cs @@ -0,0 +1,80 @@ +using System.Collections.Concurrent; +using System.DirectoryServices.Protocols; + +namespace SharpHoundCommonLib +{ + public class DCConnectionCache + { + private readonly ConcurrentDictionary _ldapConnectionCache; + + public DCConnectionCache() + { + _ldapConnectionCache = new ConcurrentDictionary(); + } + + public bool TryGet(string domainName, bool isGlobalCatalog, out LdapConnection connection) + { + var key = GetKey(domainName, isGlobalCatalog); + return _ldapConnectionCache.TryGetValue(key, out connection); + } + + public LdapConnection AddOrUpdate(string domainName, bool isGlobalCatalog, LdapConnection connection) + { + var cacheKey = GetKey(domainName, isGlobalCatalog); + return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) => + { + existingConnection.Dispose(); + return connection; + }); + } + + public LdapConnection TryAdd(string domainName, bool isGlobalCatalog, LdapConnection connection) + { + var cacheKey = GetKey(domainName, isGlobalCatalog); + return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) => + { + connection.Dispose(); + return existingConnection; + }); + } + + private LDAPConnectionCacheKey GetKey(string domainName, bool isGlobalCatalog) + { + return new LDAPConnectionCacheKey(domainName, isGlobalCatalog); + } + + private class LDAPConnectionCacheKey + { + public bool GlobalCatalog { get; } + public string Domain { get; } + public string Server { get; set; } + + public LDAPConnectionCacheKey(string domain, bool globalCatalog) + { + GlobalCatalog = globalCatalog; + Domain = domain; + } + + protected bool Equals(LDAPConnectionCacheKey other) + { + return GlobalCatalog == other.GlobalCatalog && Domain == other.Domain; + } + + public override bool Equals(object obj) + { + if (ReferenceEquals(null, obj)) return false; + if (ReferenceEquals(this, obj)) return true; + if (obj.GetType() != this.GetType()) return false; + return Equals((LDAPConnectionCacheKey)obj); + } + + public override int GetHashCode() + { + unchecked + { + return (GlobalCatalog.GetHashCode() * 397) ^ (Domain != null ? Domain.GetHashCode() : 0); + } + } + } + } +} \ No newline at end of file diff --git a/src/CommonLib/LDAPConnectionCacheKey.cs b/src/CommonLib/LDAPConnectionCacheKey.cs deleted file mode 100644 index f1f76464..00000000 --- a/src/CommonLib/LDAPConnectionCacheKey.cs +++ /dev/null @@ -1,36 +0,0 @@ -namespace SharpHoundCommonLib -{ - public class LDAPConnectionCacheKey - { - public bool GlobalCatalog { get; } - public string Domain { get; } - public string Server { get; set; } - - public LDAPConnectionCacheKey(string domain, bool globalCatalog) - { - GlobalCatalog = globalCatalog; - Domain = domain; - } - - protected bool Equals(LDAPConnectionCacheKey other) - { - return GlobalCatalog == other.GlobalCatalog && Domain == other.Domain; - } - - public override bool Equals(object obj) - { - if (ReferenceEquals(null, obj)) return false; - if (ReferenceEquals(this, obj)) return true; - if (obj.GetType() != this.GetType()) return false; - return Equals((LDAPConnectionCacheKey)obj); - } - - public override int GetHashCode() - { - unchecked - { - return (GlobalCatalog.GetHashCode() * 397) ^ (Domain != null ? Domain.GetHashCode() : 0); - } - } - } -} \ No newline at end of file diff --git a/src/CommonLib/LDAPUtilsNew.cs b/src/CommonLib/LDAPUtilsNew.cs index 8f8ccb9f..1586db8f 100644 --- a/src/CommonLib/LDAPUtilsNew.cs +++ b/src/CommonLib/LDAPUtilsNew.cs @@ -23,7 +23,7 @@ public class LDAPUtilsNew private LDAPConfig _ldapConfig = new(); private readonly ILogger _log; //This cache is indexed by domain sid - private readonly ConcurrentDictionary _ldapConnectionCache = new(); + private readonly DCConnectionCache _ldapConnectionCache; private readonly ConcurrentDictionary _domainCache = new(); private readonly string[] _translateNames = { "Administrator", "admin" }; private readonly PortScanner _portScanner; @@ -198,14 +198,14 @@ await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(false)))) private LdapConnection CheckCacheConnection(LdapConnection connection, string domainName, bool globalCatalog, bool forceCreateNewConnection) { - LDAPConnectionCacheKey cacheKey; + string cacheIdentifier; if (_ldapConfig.Server != null) { - cacheKey = new LDAPConnectionCacheKey(_ldapConfig.Server, globalCatalog); + cacheIdentifier = _ldapConfig.Server; } else { - if (!GetDomainSidFromDomainName(domainName, out var cacheIdentifier)) + if (!GetDomainSidFromDomainName(domainName, out cacheIdentifier)) { //This is kinda gross, but its another way to get the correct domain sid if (!connection.GetNamingContextSearchBase(NamingContexts.Default, out var searchBase) || !GetDomainSidFromConnection(connection, searchBase, out cacheIdentifier)) @@ -214,50 +214,36 @@ private LdapConnection CheckCacheConnection(LdapConnection connection, string do * If we get here, we couldn't resolve a domain sid, which is hella bad, but we also want to keep from creating a shitton of new connections * Cache using the domainname and pray it all works out */ - cacheIdentifier = domainName.ToUpper().Trim(); + cacheIdentifier = domainName; } } - - cacheKey = new LDAPConnectionCacheKey(cacheIdentifier, globalCatalog); } if (forceCreateNewConnection) { - return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) => - { - existingConnection.Dispose(); - return connection; - }); + return _ldapConnectionCache.AddOrUpdate(cacheIdentifier, globalCatalog, connection); } - return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) => - { - connection.Dispose(); - return existingConnection; - }); + return _ldapConnectionCache.TryAdd(cacheIdentifier, globalCatalog, connection); } private bool GetCachedConnection(string domain, bool globalCatalog, out LdapConnection connection) { - LDAPConnectionCacheKey cacheKey; //If server is set via our config, we'll always just use this as the cache key if (_ldapConfig.Server != null) { - cacheKey = new LDAPConnectionCacheKey(_ldapConfig.Server, globalCatalog); - return _ldapConnectionCache.TryGetValue(cacheKey, out connection); + return _ldapConnectionCache.TryGet(_ldapConfig.Server, globalCatalog, out connection); } if (GetDomainSidFromDomainName(domain, out var domainSid)) { - cacheKey = new LDAPConnectionCacheKey(domainSid, globalCatalog); - if (_ldapConnectionCache.TryGetValue(cacheKey, out connection)) + if (_ldapConnectionCache.TryGet(_ldapConfig.Server, globalCatalog, out connection)) { return true; } } - cacheKey = new LDAPConnectionCacheKey(domain.ToUpper().Trim(), globalCatalog); - return _ldapConnectionCache.TryGetValue(cacheKey, out connection); + return _ldapConnectionCache.TryGet(domain, globalCatalog, out connection); }