Skip to content

Commit

Permalink
fix: only lock semaphore during active queries to prevent deadlocks
Browse files Browse the repository at this point in the history
  • Loading branch information
rvazarkar committed Aug 7, 2024
1 parent 1f5e96a commit 7c9f05f
Show file tree
Hide file tree
Showing 5 changed files with 609 additions and 497 deletions.
49 changes: 48 additions & 1 deletion src/CommonLib/ConnectionPoolManager.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.DirectoryServices;
using System.Runtime.CompilerServices;
using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Processors;

namespace SharpHoundCommonLib {
public class ConnectionPoolManager : IDisposable{
internal class ConnectionPoolManager : IDisposable{
private readonly ConcurrentDictionary<string, LdapConnectionPool> _pools = new();
private readonly LdapConfig _ldapConfig;
private readonly string[] _translateNames = { "Administrator", "admin" };
Expand All @@ -21,6 +24,35 @@ public ConnectionPoolManager(LdapConfig config, ILogger log = null, PortScanner
_portScanner = scanner ?? new PortScanner();
}

public IAsyncEnumerable<Result<string>> RangedRetrieval(string distinguishedName,
string attributeName, CancellationToken cancellationToken = new()) {
var domain = Helpers.DistinguishedNameToDomain(distinguishedName);

if (!GetPool(domain, out var pool)) {
return new List<Result<string>> {Result<string>.Fail("Failed to resolve a connection pool")}.ToAsyncEnumerable();
}

return pool.RangedRetrieval(distinguishedName, attributeName, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> PagedQuery(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
if (!GetPool(queryParameters.DomainName, out var pool)) {
return new List<LdapResult<IDirectoryObject>> {LdapResult<IDirectoryObject>.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable();
}

return pool.PagedQuery(queryParameters, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
if (!GetPool(queryParameters.DomainName, out var pool)) {
return new List<LdapResult<IDirectoryObject>> {LdapResult<IDirectoryObject>.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable();
}

return pool.Query(queryParameters, cancellationToken);
}

public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) {
if (connectionWrapper == null) {
return;
Expand All @@ -41,6 +73,21 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn
return (success, message);
}

private bool GetPool(string identifier, out LdapConnectionPool pool) {
if (identifier == null) {
pool = default;
return false;
}

var resolved = ResolveIdentifier(identifier);
if (!_pools.TryGetValue(resolved, out pool)) {
pool = new LdapConnectionPool(identifier, resolved, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(resolved, pool);
}

return true;
}

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(
string identifier, bool globalCatalog) {
if (identifier == null) {
Expand Down
41 changes: 41 additions & 0 deletions src/CommonLib/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,47 @@ public async ValueTask<bool> MoveNextAsync() {

public T Current => _current;
}

internal static IAsyncEnumerable<T> ToAsyncEnumerable<T>(this IEnumerable<T> source) {
return source switch {
ICollection<T> collection => new IAsyncEnumerableCollectionAdapter<T>(collection),
_ => null
};
}

private sealed class IAsyncEnumerableCollectionAdapter<T> : IAsyncEnumerable<T> {
private readonly IAsyncEnumerator<T> _enumerator;

public IAsyncEnumerableCollectionAdapter(ICollection<T> source) {
_enumerator = new IAsyncEnumeratorCollectionAdapter<T>(source);
}
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken()) {
return _enumerator;
}
}

private sealed class IAsyncEnumeratorCollectionAdapter<T> : IAsyncEnumerator<T> {
private readonly IEnumerable<T> _source;
private IEnumerator<T> _enumerator;

public IAsyncEnumeratorCollectionAdapter(ICollection<T> source) {
_source = source;
}

public ValueTask DisposeAsync() {
_enumerator = null;
return new ValueTask(Task.CompletedTask);
}

public ValueTask<bool> MoveNextAsync() {
if (_enumerator == null) {
_enumerator = _source.GetEnumerator();
}
return new ValueTask<bool>(_enumerator.MoveNext());
}

public T Current => _enumerator.Current;
}


public static string LdapValue(this SecurityIdentifier s)
Expand Down
13 changes: 13 additions & 0 deletions src/CommonLib/LdapConfig.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.DirectoryServices.Protocols;
using System.Text;

namespace SharpHoundCommonLib
{
Expand Down Expand Up @@ -32,5 +33,17 @@ public int GetGCPort(bool ssl)
{
return ssl ? 3269 : 3268;
}

public override string ToString() {
var sb = new StringBuilder();
sb.AppendLine($"Server: {Server}");
sb.AppendLine($"Port: {Port}");
sb.AppendLine($"SSLPort: {SSLPort}");
sb.AppendLine($"ForceSSL: {ForceSSL}");
sb.AppendLine($"AuthType: {AuthType.ToString()}");
sb.AppendLine($"Username: {Username}");
sb.AppendLine($"Password: {new string('*', Password.Length)}");
return sb.ToString();
}
}
}
Loading

0 comments on commit 7c9f05f

Please sign in to comment.