Skip to content

Commit

Permalink
wip: more fixes identified during during testing
Browse files Browse the repository at this point in the history
  • Loading branch information
rvazarkar committed Jul 3, 2024
1 parent 9a3feaf commit 9c767f9
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 62 deletions.
13 changes: 7 additions & 6 deletions src/CommonLib/ConnectionPoolManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public ConnectionPoolManager(LDAPConfig config, ILogger log = null, PortScanner
public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) {
//I dont think this is possible, but at least account for it
if (!_pools.TryGetValue(connectionWrapper.PoolIdentifier, out var pool)) {
_log.LogWarning("Could not find pool for {Identifier}", connectionWrapper.PoolIdentifier);
connectionWrapper.Connection.Dispose();
return;
}
Expand All @@ -41,9 +42,9 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn
string identifier, bool globalCatalog) {
var resolved = ResolveIdentifier(identifier);

if (!_pools.TryGetValue(identifier, out var pool)) {
pool = new LdapConnectionPool(resolved, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(identifier, pool);
if (!_pools.TryGetValue(resolved, out var pool)) {
pool = new LdapConnectionPool(identifier, resolved, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(resolved, pool);
}

if (globalCatalog) {
Expand All @@ -56,9 +57,9 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn
string identifier, string server, bool globalCatalog) {
var resolved = ResolveIdentifier(identifier);

if (!_pools.TryGetValue(identifier, out var pool)) {
pool = new LdapConnectionPool(resolved, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(identifier, pool);
if (!_pools.TryGetValue(resolved, out var pool)) {
pool = new LdapConnectionPool(resolved, identifier, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(resolved, pool);
}

return await pool.GetConnectionForSpecificServerAsync(server, globalCatalog);
Expand Down
22 changes: 10 additions & 12 deletions src/CommonLib/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ public static bool GetLabel(this DirectoryEntry entry, out Label type) {
}

private static bool ResolveLabel(string objectIdentifier, string distinguishedName, string samAccountType, string[] objectClasses, int flags, out Label type) {
type = Label.Base;
if (objectIdentifier != null && WellKnownPrincipal.GetWellKnownPrincipal(objectIdentifier, out var principal)) {
type = principal.ObjectType;
return true;
Expand Down Expand Up @@ -474,28 +475,26 @@ private static bool ResolveLabel(string objectIdentifier, string distinguishedNa

if (objectClasses.Contains(GroupPolicyContainerClass, StringComparer.InvariantCultureIgnoreCase))
type = Label.GPO;
if (objectClasses.Contains(OrganizationalUnitClass, StringComparer.InvariantCultureIgnoreCase))
else if (objectClasses.Contains(OrganizationalUnitClass, StringComparer.InvariantCultureIgnoreCase))
type = Label.OU;
if (objectClasses.Contains(DomainClass, StringComparer.InvariantCultureIgnoreCase))
else if (objectClasses.Contains(DomainClass, StringComparer.InvariantCultureIgnoreCase))
type = Label.Domain;
if (objectClasses.Contains(ContainerClass, StringComparer.InvariantCultureIgnoreCase))
else if (objectClasses.Contains(ContainerClass, StringComparer.InvariantCultureIgnoreCase))
type = Label.Container;
if (objectClasses.Contains(ConfigurationClass, StringComparer.InvariantCultureIgnoreCase))
else if (objectClasses.Contains(ConfigurationClass, StringComparer.InvariantCultureIgnoreCase))
type = Label.Configuration;
if (objectClasses.Contains(PKICertificateTemplateClass, StringComparer.InvariantCultureIgnoreCase))
else if (objectClasses.Contains(PKICertificateTemplateClass, StringComparer.InvariantCultureIgnoreCase))
type = Label.CertTemplate;
if (objectClasses.Contains(PKIEnrollmentServiceClass, StringComparer.InvariantCultureIgnoreCase))
else if (objectClasses.Contains(PKIEnrollmentServiceClass, StringComparer.InvariantCultureIgnoreCase))
type = Label.EnterpriseCA;
if (objectClasses.Contains(CertificationAuthorityClass, StringComparer.InvariantCultureIgnoreCase)) {
else if (objectClasses.Contains(CertificationAuthorityClass, StringComparer.InvariantCultureIgnoreCase)) {
if (distinguishedName.Contains(DirectoryPaths.RootCALocation))
type = Label.RootCA;
if (distinguishedName.Contains(DirectoryPaths.AIACALocation))
type = Label.AIACA;
if (distinguishedName.Contains(DirectoryPaths.NTAuthStoreLocation))
type = Label.NTAuthStore;
}

if (objectClasses.Contains(OIDContainerClass, StringComparer.InvariantCultureIgnoreCase)) {
}else if (objectClasses.Contains(OIDContainerClass, StringComparer.InvariantCultureIgnoreCase)) {
if (distinguishedName.StartsWith(DirectoryPaths.OIDContainerLocation,
StringComparison.InvariantCultureIgnoreCase))
type = Label.Container;
Expand All @@ -505,8 +504,7 @@ private static bool ResolveLabel(string objectIdentifier, string distinguishedNa
}
}

type = Label.Base;
return false;
return type != Label.Base;
}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion src/CommonLib/LDAPConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public class LDAPConfig
public string Password { get; set; } = null;
public string Server { get; set; } = null;
public int Port { get; set; } = 0;
public int SSLPort { get; set; } = 0;
public bool ForceSSL { get; set; } = false;
public bool DisableSigning { get; set; } = false;
public bool DisableCertVerification { get; set; } = false;
Expand All @@ -16,7 +17,10 @@ public class LDAPConfig
//Returns the port for connecting to LDAP. Will always respect a user's overridden config over anything else
public int GetPort(bool ssl)
{
if (Port != 0)
if (ssl && SSLPort != 0) {
return SSLPort;
}
if (!ssl && Port != 0)
{
return Port;
}
Expand Down
31 changes: 19 additions & 12 deletions src/CommonLib/LdapConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@ namespace SharpHoundCommonLib {
public class LdapConnectionPool : IDisposable{
private readonly ConcurrentBag<LdapConnectionWrapper> _connections;
private readonly ConcurrentBag<LdapConnectionWrapper> _globalCatalogConnection;
private static readonly ConcurrentDictionary<string, Domain> DomainCache = new();
private readonly SemaphoreSlim _semaphore;
private readonly string _identifier;
private readonly string _poolIdentifier;
private readonly LDAPConfig _ldapConfig;
private readonly ILogger _log;
private readonly PortScanner _portScanner;
private readonly NativeMethods _nativeMethods;

public LdapConnectionPool(string identifier, LDAPConfig config, int maxConnections = 10, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) {
public LdapConnectionPool(string identifier, string poolIdentifier, LDAPConfig config, int maxConnections = 10, PortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) {
_connections = new ConcurrentBag<LdapConnectionWrapper>();
_globalCatalogConnection = new ConcurrentBag<LdapConnectionWrapper>();
_semaphore = new SemaphoreSlim(maxConnections, maxConnections);
_identifier = identifier;
_poolIdentifier = poolIdentifier;
_ldapConfig = config;
_log = log ?? Logging.LogProvider.CreateLogger("LdapConnectionPool");
_portScanner = scanner ?? new PortScanner();
Expand Down Expand Up @@ -81,6 +82,7 @@ public LdapConnectionPool(string identifier, LDAPConfig config, int maxConnectio
}

public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) {
_semaphore.Release();
if (!connectionFaulted) {
if (connectionWrapper.GlobalCatalog) {
_globalCatalogConnection.Add(connectionWrapper);
Expand All @@ -92,8 +94,6 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn
else {
connectionWrapper.Connection.Dispose();
}

_semaphore.Release();
}

public void Dispose() {
Expand All @@ -108,7 +108,7 @@ public void Dispose() {
}

if (CreateLdapConnection(_identifier.ToUpper().Trim(), globalCatalog, out var connectionWrapper)) {
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 1", _identifier);
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 1. SSL: {SSl}", _identifier, connectionWrapper.Connection.SessionOptions.SecureSocketLayer);
return (true, connectionWrapper, "");
}

Expand Down Expand Up @@ -194,7 +194,7 @@ private bool CreateLdapConnection(string target, bool globalCatalog,
out LdapConnectionWrapper connection) {
var baseConnection = CreateBaseConnection(target, true, globalCatalog);
if (TestLdapConnection(baseConnection, out var result)) {
connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _identifier);
connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _poolIdentifier);
return true;
}

Expand All @@ -212,7 +212,7 @@ private bool CreateLdapConnection(string target, bool globalCatalog,

baseConnection = CreateBaseConnection(target, false, globalCatalog);
if (TestLdapConnection(baseConnection, out result)) {
connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _identifier);
connection = new LdapConnectionWrapper(baseConnection, result.SearchResultEntry, globalCatalog, _poolIdentifier);
return true;
}

Expand All @@ -229,19 +229,26 @@ private bool CreateLdapConnection(string target, bool globalCatalog,

private LdapConnection CreateBaseConnection(string directoryIdentifier, bool ssl,
bool globalCatalog) {
_log.LogDebug("Creating connection for identifier {Identifier}", directoryIdentifier);
var port = globalCatalog ? _ldapConfig.GetGCPort(ssl) : _ldapConfig.GetPort(ssl);
var identifier = new LdapDirectoryIdentifier(directoryIdentifier, port, false, false);
var connection = new LdapConnection(identifier) { Timeout = new TimeSpan(0, 0, 5, 0) };

//These options are important!
connection.SessionOptions.ProtocolVersion = 3;
//Referral chasing does not work with paged searches
connection.SessionOptions.ReferralChasing = ReferralChasingOptions.None;
if (ssl) connection.SessionOptions.SecureSocketLayer = true;

connection.SessionOptions.Sealing = !_ldapConfig.DisableSigning;
connection.SessionOptions.Signing = !_ldapConfig.DisableSigning;


if (_ldapConfig.DisableSigning || ssl) {
connection.SessionOptions.Signing = false;
connection.SessionOptions.Sealing = false;
}
else {
connection.SessionOptions.Signing = true;
connection.SessionOptions.Sealing = true;
}

if (_ldapConfig.DisableCertVerification)
connection.SessionOptions.VerifyServerCertificate = (_, _) => true;

Expand Down
4 changes: 2 additions & 2 deletions src/CommonLib/LdapConnectionWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ public class LdapConnectionWrapper {
private string _schemaSearchBase;
private string _server;
private string Guid { get; set; }
public bool GlobalCatalog;
public string PoolIdentifier;
public readonly bool GlobalCatalog;
public readonly string PoolIdentifier;

public LdapConnectionWrapper(LdapConnection connection, ISearchResultEntry entry, bool globalCatalog,
string poolIdentifier) {
Expand Down
4 changes: 4 additions & 0 deletions src/CommonLib/LdapResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ protected LdapResult(T value, bool success, string error, string queryInfo, int
public new static LdapResult<T> Ok(T value) {
return new LdapResult<T>(value, true, string.Empty, null, 0);
}

public new static LdapResult<T> Fail() {
return new LdapResult<T>(default, false, string.Empty, null, 0);
}

public static LdapResult<T> Fail(string message, LdapQueryParameters queryInfo) {
return new LdapResult<T>(default, false, message, queryInfo.GetQueryInfo(), 0);
Expand Down
Loading

0 comments on commit 9c767f9

Please sign in to comment.