Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DC Connection Cache Breakout #129

Merged
merged 4 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/CommonLib/DCConnectionCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using System.Collections.Concurrent;
using System.DirectoryServices.Protocols;

namespace SharpHoundCommonLib
{
public class DCConnectionCache
{
private readonly ConcurrentDictionary<LDAPConnectionCacheKey, LdapConnection> _ldapConnectionCache;

public DCConnectionCache()
{
_ldapConnectionCache = new ConcurrentDictionary<LDAPConnectionCacheKey, LdapConnection>();
}

public bool TryGet(string key, bool isGlobalCatalog, out LdapConnection connection)
{
var cacheKey = GetKey(key, isGlobalCatalog);
return _ldapConnectionCache.TryGetValue(cacheKey, out connection);
}

public LdapConnection AddOrUpdate(string key, bool isGlobalCatalog, LdapConnection connection)
{
var cacheKey = GetKey(key, isGlobalCatalog);
return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
existingConnection.Dispose();
return connection;
});
}

public LdapConnection TryAdd(string key, bool isGlobalCatalog, LdapConnection connection)
{
var cacheKey = GetKey(key, isGlobalCatalog);
return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
connection.Dispose();
return existingConnection;
});
}

private LDAPConnectionCacheKey GetKey(string key, bool isGlobalCatalog)
{
return new LDAPConnectionCacheKey(key.ToUpper().Trim(), isGlobalCatalog);
}

private class LDAPConnectionCacheKey
{
public string Domain { get; }
public bool GlobalCatalog { get; }

public LDAPConnectionCacheKey(string domain, bool globalCatalog)
{
GlobalCatalog = globalCatalog;
Domain = domain;
}

protected bool Equals(LDAPConnectionCacheKey other)
{
return GlobalCatalog == other.GlobalCatalog && Domain == other.Domain;
}

public override bool Equals(object obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != this.GetType()) return false;
return Equals((LDAPConnectionCacheKey)obj);
}

public override int GetHashCode()
{
unchecked
{
return (GlobalCatalog.GetHashCode() * 397) ^ (Domain != null ? Domain.GetHashCode() : 0);
}
}
}
}
}
36 changes: 0 additions & 36 deletions src/CommonLib/LDAPConnectionCacheKey.cs

This file was deleted.

34 changes: 10 additions & 24 deletions src/CommonLib/LDAPUtilsNew.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class LDAPUtilsNew
private LDAPConfig _ldapConfig = new();
private readonly ILogger _log;
//This cache is indexed by domain sid
private readonly ConcurrentDictionary<LDAPConnectionCacheKey, LdapConnection> _ldapConnectionCache = new();
private readonly DCConnectionCache _ldapConnectionCache;
private readonly ConcurrentDictionary<string, Domain> _domainCache = new();
private readonly string[] _translateNames = { "Administrator", "admin" };
private readonly PortScanner _portScanner;
Expand Down Expand Up @@ -198,14 +198,14 @@ await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(false))))

private LdapConnection CheckCacheConnection(LdapConnection connection, string domainName, bool globalCatalog, bool forceCreateNewConnection)
{
LDAPConnectionCacheKey cacheKey;
string cacheIdentifier;
if (_ldapConfig.Server != null)
{
cacheKey = new LDAPConnectionCacheKey(_ldapConfig.Server, globalCatalog);
cacheIdentifier = _ldapConfig.Server;
}
else
{
if (!GetDomainSidFromDomainName(domainName, out var cacheIdentifier))
if (!GetDomainSidFromDomainName(domainName, out cacheIdentifier))
{
//This is kinda gross, but its another way to get the correct domain sid
if (!connection.GetNamingContextSearchBase(NamingContexts.Default, out var searchBase) || !GetDomainSidFromConnection(connection, searchBase, out cacheIdentifier))
Expand All @@ -214,50 +214,36 @@ private LdapConnection CheckCacheConnection(LdapConnection connection, string do
* If we get here, we couldn't resolve a domain sid, which is hella bad, but we also want to keep from creating a shitton of new connections
* Cache using the domainname and pray it all works out
*/
cacheIdentifier = domainName.ToUpper().Trim();
cacheIdentifier = domainName;
}
}

cacheKey = new LDAPConnectionCacheKey(cacheIdentifier, globalCatalog);
}

if (forceCreateNewConnection)
{
return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
existingConnection.Dispose();
return connection;
});
return _ldapConnectionCache.AddOrUpdate(cacheIdentifier, globalCatalog, connection);
}

return _ldapConnectionCache.AddOrUpdate(cacheKey, connection, (_, existingConnection) =>
{
connection.Dispose();
return existingConnection;
});
return _ldapConnectionCache.TryAdd(cacheIdentifier, globalCatalog, connection);
}

private bool GetCachedConnection(string domain, bool globalCatalog, out LdapConnection connection)
{
LDAPConnectionCacheKey cacheKey;
//If server is set via our config, we'll always just use this as the cache key
if (_ldapConfig.Server != null)
{
cacheKey = new LDAPConnectionCacheKey(_ldapConfig.Server, globalCatalog);
return _ldapConnectionCache.TryGetValue(cacheKey, out connection);
return _ldapConnectionCache.TryGet(_ldapConfig.Server, globalCatalog, out connection);
}

if (GetDomainSidFromDomainName(domain, out var domainSid))
{
cacheKey = new LDAPConnectionCacheKey(domainSid, globalCatalog);
if (_ldapConnectionCache.TryGetValue(cacheKey, out connection))
if (_ldapConnectionCache.TryGet(domainSid, globalCatalog, out connection))
{
return true;
}
}

cacheKey = new LDAPConnectionCacheKey(domain.ToUpper().Trim(), globalCatalog);
return _ldapConnectionCache.TryGetValue(cacheKey, out connection);
return _ldapConnectionCache.TryGet(domain, globalCatalog, out connection);
}


Expand Down
Loading