Skip to content

Commit

Permalink
Breaking dc connection cache out of utils
Browse files Browse the repository at this point in the history
  • Loading branch information
definitelynotagoblin committed Jun 10, 2024
1 parent ee7d009 commit 1cef248
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 60 deletions.
80 changes: 80 additions & 0 deletions src/CommonLib/DCConnectionCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using System.Collections.Concurrent;
using System.DirectoryServices.Protocols;

namespace SharpHoundCommonLib
{
public class DCConnectionCache
{
private readonly ConcurrentDictionary<LDAPConnectionCacheKey, LdapConnection> _ldapConnectionCache;

public DCConnectionCache()
{
_ldapConnectionCache = new ConcurrentDictionary<LDAPConnectionCacheKey, LdapConnection>();
}

public bool TryGet(string domainName, bool isGlobalCatalog, out LdapConnection connection)
{
var key = GetKey(domainName, isGlobalCatalog);
return _ldapConnectionCache.TryGetValue(key, out connection);
}

public LdapConnection AddOrUpdate(string domainName, bool isGlobalCatalog, LdapConnection connection)
{
var cacheKey = GetKey(domainName, isGlobalCatalog);
return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
existingConnection.Dispose();
return connection;
});
}

public LdapConnection TryAdd(string domainName, bool isGlobalCatalog, LdapConnection connection)
{
var cacheKey = GetKey(domainName, isGlobalCatalog);
return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
connection.Dispose();
return existingConnection;
});
}

private LDAPConnectionCacheKey GetKey(string domainName, bool isGlobalCatalog)
{
return new LDAPConnectionCacheKey(domainName, isGlobalCatalog);
}

private class LDAPConnectionCacheKey
{
public bool GlobalCatalog { get; }
public string Domain { get; }
public string Server { get; set; }

public LDAPConnectionCacheKey(string domain, bool globalCatalog)
{
GlobalCatalog = globalCatalog;
Domain = domain;
}

protected bool Equals(LDAPConnectionCacheKey other)
{
return GlobalCatalog == other.GlobalCatalog && Domain == other.Domain;
}

public override bool Equals(object obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != this.GetType()) return false;
return Equals((LDAPConnectionCacheKey)obj);
}

public override int GetHashCode()
{
unchecked
{
return (GlobalCatalog.GetHashCode() * 397) ^ (Domain != null ? Domain.GetHashCode() : 0);
}
}
}
}
}
36 changes: 0 additions & 36 deletions src/CommonLib/LDAPConnectionCacheKey.cs

This file was deleted.

34 changes: 10 additions & 24 deletions src/CommonLib/LDAPUtilsNew.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class LDAPUtilsNew
private LDAPConfig _ldapConfig = new();
private readonly ILogger _log;
//This cache is indexed by domain sid
private readonly ConcurrentDictionary<LDAPConnectionCacheKey, LdapConnection> _ldapConnectionCache = new();
private readonly DCConnectionCache _ldapConnectionCache;
private readonly ConcurrentDictionary<string, Domain> _domainCache = new();
private readonly string[] _translateNames = { "Administrator", "admin" };
private readonly PortScanner _portScanner;
Expand Down Expand Up @@ -198,14 +198,14 @@ await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(false))))

private LdapConnection CheckCacheConnection(LdapConnection connection, string domainName, bool globalCatalog, bool forceCreateNewConnection)
{
LDAPConnectionCacheKey cacheKey;
string cacheIdentifier;
if (_ldapConfig.Server != null)
{
cacheKey = new LDAPConnectionCacheKey(_ldapConfig.Server, globalCatalog);
cacheIdentifier = _ldapConfig.Server;
}
else
{
if (!GetDomainSidFromDomainName(domainName, out var cacheIdentifier))
if (!GetDomainSidFromDomainName(domainName, out cacheIdentifier))
{
//This is kinda gross, but its another way to get the correct domain sid
if (!connection.GetNamingContextSearchBase(NamingContexts.Default, out var searchBase) || !GetDomainSidFromConnection(connection, searchBase, out cacheIdentifier))
Expand All @@ -214,50 +214,36 @@ private LdapConnection CheckCacheConnection(LdapConnection connection, string do
* If we get here, we couldn't resolve a domain sid, which is hella bad, but we also want to keep from creating a shitton of new connections
* Cache using the domainname and pray it all works out
*/
cacheIdentifier = domainName.ToUpper().Trim();
cacheIdentifier = domainName;
}
}

cacheKey = new LDAPConnectionCacheKey(cacheIdentifier, globalCatalog);
}

if (forceCreateNewConnection)
{
return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
existingConnection.Dispose();
return connection;
});
return _ldapConnectionCache.AddOrUpdate(cacheIdentifier, globalCatalog, connection);
}

return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
connection.Dispose();
return existingConnection;
});
return _ldapConnectionCache.TryAdd(cacheIdentifier, globalCatalog, connection);
}

private bool GetCachedConnection(string domain, bool globalCatalog, out LdapConnection connection)
{
LDAPConnectionCacheKey cacheKey;
//If server is set via our config, we'll always just use this as the cache key
if (_ldapConfig.Server != null)
{
cacheKey = new LDAPConnectionCacheKey(_ldapConfig.Server, globalCatalog);
return _ldapConnectionCache.TryGetValue(cacheKey, out connection);
return _ldapConnectionCache.TryGet(_ldapConfig.Server, globalCatalog, out connection);
}

if (GetDomainSidFromDomainName(domain, out var domainSid))
{
cacheKey = new LDAPConnectionCacheKey(domainSid, globalCatalog);
if (_ldapConnectionCache.TryGetValue(cacheKey, out connection))
if (_ldapConnectionCache.TryGet(_ldapConfig.Server, globalCatalog, out connection))
{
return true;
}
}

cacheKey = new LDAPConnectionCacheKey(domain.ToUpper().Trim(), globalCatalog);
return _ldapConnectionCache.TryGetValue(cacheKey, out connection);
return _ldapConnectionCache.TryGet(domain, globalCatalog, out connection);
}


Expand Down

0 comments on commit 1cef248

Please sign in to comment.