Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semaphore fix #151

Merged
merged 9 commits into from
Aug 8, 2024
59 changes: 48 additions & 11 deletions src/CommonLib/ConnectionPoolManager.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.DirectoryServices;
using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Processors;

namespace SharpHoundCommonLib {
public class ConnectionPoolManager : IDisposable{
internal class ConnectionPoolManager : IDisposable{
private readonly ConcurrentDictionary<string, LdapConnectionPool> _pools = new();
private readonly LdapConfig _ldapConfig;
private readonly string[] _translateNames = { "Administrator", "admin" };
Expand All @@ -21,6 +23,35 @@ public ConnectionPoolManager(LdapConfig config, ILogger log = null, PortScanner
_portScanner = scanner ?? new PortScanner();
}

public IAsyncEnumerable<Result<string>> RangedRetrieval(string distinguishedName,
string attributeName, CancellationToken cancellationToken = new()) {
var domain = Helpers.DistinguishedNameToDomain(distinguishedName);

if (!GetPool(domain, out var pool)) {
return new List<Result<string>> {Result<string>.Fail("Failed to resolve a connection pool")}.ToAsyncEnumerable();
}

return pool.RangedRetrieval(distinguishedName, attributeName, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> PagedQuery(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
if (!GetPool(queryParameters.DomainName, out var pool)) {
return new List<LdapResult<IDirectoryObject>> {LdapResult<IDirectoryObject>.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable();
}

return pool.PagedQuery(queryParameters, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
if (!GetPool(queryParameters.DomainName, out var pool)) {
return new List<LdapResult<IDirectoryObject>> {LdapResult<IDirectoryObject>.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable();
}

return pool.Query(queryParameters, cancellationToken);
}

public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) {
if (connectionWrapper == null) {
return;
Expand All @@ -41,18 +72,27 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn
return (success, message);
}

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(
string identifier, bool globalCatalog) {
private bool GetPool(string identifier, out LdapConnectionPool pool) {
if (identifier == null) {
return (false, default, "Provided a null identifier for the connection");
pool = default;
return false;
}
var resolved = ResolveIdentifier(identifier);

if (!_pools.TryGetValue(resolved, out var pool)) {
var resolved = ResolveIdentifier(identifier);
if (!_pools.TryGetValue(resolved, out pool)) {
pool = new LdapConnectionPool(identifier, resolved, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(resolved, pool);
}

return true;
}

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(
string identifier, bool globalCatalog) {
if (!GetPool(identifier, out var pool)) {
return (false, default, $"Unable to resolve a pool for {identifier}");
}

if (globalCatalog) {
return await pool.GetGlobalCatalogConnectionAsync();
}
Expand All @@ -61,11 +101,8 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn

public async Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> GetLdapConnectionForServer(
string identifier, string server, bool globalCatalog) {
var resolved = ResolveIdentifier(identifier);

if (!_pools.TryGetValue(resolved, out var pool)) {
pool = new LdapConnectionPool(resolved, identifier, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(resolved, pool);
if (!GetPool(identifier, out var pool)) {
return (false, default, $"Unable to resolve a pool for {identifier}");
}

return await pool.GetConnectionForSpecificServerAsync(server, globalCatalog);
Expand Down
41 changes: 41 additions & 0 deletions src/CommonLib/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,47 @@ public async ValueTask<bool> MoveNextAsync() {

public T Current => _current;
}

internal static IAsyncEnumerable<T> ToAsyncEnumerable<T>(this IEnumerable<T> source) {
return source switch {
ICollection<T> collection => new IAsyncEnumerableCollectionAdapter<T>(collection),
_ => null
};
}

private sealed class IAsyncEnumerableCollectionAdapter<T> : IAsyncEnumerable<T> {
private readonly IAsyncEnumerator<T> _enumerator;

public IAsyncEnumerableCollectionAdapter(ICollection<T> source) {
_enumerator = new IAsyncEnumeratorCollectionAdapter<T>(source);
}
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken()) {
return _enumerator;
}
}

private sealed class IAsyncEnumeratorCollectionAdapter<T> : IAsyncEnumerator<T> {
private readonly IEnumerable<T> _source;
private IEnumerator<T> _enumerator;

public IAsyncEnumeratorCollectionAdapter(ICollection<T> source) {
_source = source;
}

public ValueTask DisposeAsync() {
_enumerator = null;
return new ValueTask(Task.CompletedTask);
}

public ValueTask<bool> MoveNextAsync() {
if (_enumerator == null) {
_enumerator = _source.GetEnumerator();
}
return new ValueTask<bool>(_enumerator.MoveNext());
}

public T Current => _enumerator.Current;
}


public static string LdapValue(this SecurityIdentifier s)
Expand Down
20 changes: 20 additions & 0 deletions src/CommonLib/LdapConfig.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.DirectoryServices.Protocols;
using System.Text;

namespace SharpHoundCommonLib
{
Expand All @@ -13,6 +14,7 @@ public class LdapConfig
public bool DisableSigning { get; set; } = false;
public bool DisableCertVerification { get; set; } = false;
public AuthType AuthType { get; set; } = AuthType.Kerberos;
public int MaxConcurrentQueries { get; set; } = 15;

//Returns the port for connecting to LDAP. Will always respect a user's overridden config over anything else
public int GetPort(bool ssl)
Expand All @@ -32,5 +34,23 @@ public int GetGCPort(bool ssl)
{
return ssl ? 3269 : 3268;
}

public override string ToString() {
var sb = new StringBuilder();
sb.AppendLine($"Server: {Server}");
sb.AppendLine($"Port: {Port}");
sb.AppendLine($"SSLPort: {GetPort(true)}");
sb.AppendLine($"ForceSSL: {GetPort(false)}");
sb.AppendLine($"AuthType: {AuthType.ToString()}");
sb.AppendLine($"MaxConcurrentQueries: {MaxConcurrentQueries}");
if (!string.IsNullOrWhiteSpace(Username)) {
sb.AppendLine($"Username: {Username}");
}

if (!string.IsNullOrWhiteSpace(Password)) {
sb.AppendLine($"Password: {new string('*', Password.Length)}");
}
return sb.ToString();
}
}
}
Loading
Loading