Skip to content

Commit

Permalink
Merge branch 'v4' into add-props
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasBK authored Aug 13, 2024
2 parents c9245f0 + d8e6757 commit 58f47ba
Show file tree
Hide file tree
Showing 7 changed files with 725 additions and 541 deletions.
59 changes: 48 additions & 11 deletions src/CommonLib/ConnectionPoolManager.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.DirectoryServices;
using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Processors;

namespace SharpHoundCommonLib {
public class ConnectionPoolManager : IDisposable{
internal class ConnectionPoolManager : IDisposable{
private readonly ConcurrentDictionary<string, LdapConnectionPool> _pools = new();
private readonly LdapConfig _ldapConfig;
private readonly string[] _translateNames = { "Administrator", "admin" };
Expand All @@ -21,6 +23,35 @@ public ConnectionPoolManager(LdapConfig config, ILogger log = null, PortScanner
_portScanner = scanner ?? new PortScanner();
}

public IAsyncEnumerable<Result<string>> RangedRetrieval(string distinguishedName,
string attributeName, CancellationToken cancellationToken = new()) {
var domain = Helpers.DistinguishedNameToDomain(distinguishedName);

if (!GetPool(domain, out var pool)) {
return new List<Result<string>> {Result<string>.Fail("Failed to resolve a connection pool")}.ToAsyncEnumerable();
}

return pool.RangedRetrieval(distinguishedName, attributeName, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> PagedQuery(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
if (!GetPool(queryParameters.DomainName, out var pool)) {
return new List<LdapResult<IDirectoryObject>> {LdapResult<IDirectoryObject>.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable();
}

return pool.PagedQuery(queryParameters, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
if (!GetPool(queryParameters.DomainName, out var pool)) {
return new List<LdapResult<IDirectoryObject>> {LdapResult<IDirectoryObject>.Fail("Failed to resolve a connection pool", queryParameters)}.ToAsyncEnumerable();
}

return pool.Query(queryParameters, cancellationToken);
}

public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false) {
if (connectionWrapper == null) {
return;
Expand All @@ -41,18 +72,27 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn
return (success, message);
}

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(
string identifier, bool globalCatalog) {
private bool GetPool(string identifier, out LdapConnectionPool pool) {
if (identifier == null) {
return (false, default, "Provided a null identifier for the connection");
pool = default;
return false;
}
var resolved = ResolveIdentifier(identifier);

if (!_pools.TryGetValue(resolved, out var pool)) {
var resolved = ResolveIdentifier(identifier);
if (!_pools.TryGetValue(resolved, out pool)) {
pool = new LdapConnectionPool(identifier, resolved, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(resolved, pool);
}

return true;
}

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(
string identifier, bool globalCatalog) {
if (!GetPool(identifier, out var pool)) {
return (false, default, $"Unable to resolve a pool for {identifier}");
}

if (globalCatalog) {
return await pool.GetGlobalCatalogConnectionAsync();
}
Expand All @@ -61,11 +101,8 @@ public void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool conn

public async Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> GetLdapConnectionForServer(
string identifier, string server, bool globalCatalog) {
var resolved = ResolveIdentifier(identifier);

if (!_pools.TryGetValue(resolved, out var pool)) {
pool = new LdapConnectionPool(resolved, identifier, _ldapConfig,scanner: _portScanner);
_pools.TryAdd(resolved, pool);
if (!GetPool(identifier, out var pool)) {
return (false, default, $"Unable to resolve a pool for {identifier}");
}

return await pool.GetConnectionForSpecificServerAsync(server, globalCatalog);
Expand Down
41 changes: 41 additions & 0 deletions src/CommonLib/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,47 @@ public async ValueTask<bool> MoveNextAsync() {

public T Current => _current;
}

internal static IAsyncEnumerable<T> ToAsyncEnumerable<T>(this IEnumerable<T> source) {
return source switch {
ICollection<T> collection => new IAsyncEnumerableCollectionAdapter<T>(collection),
_ => null
};
}

private sealed class IAsyncEnumerableCollectionAdapter<T> : IAsyncEnumerable<T> {
private readonly IAsyncEnumerator<T> _enumerator;

public IAsyncEnumerableCollectionAdapter(ICollection<T> source) {
_enumerator = new IAsyncEnumeratorCollectionAdapter<T>(source);
}
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken()) {
return _enumerator;
}
}

private sealed class IAsyncEnumeratorCollectionAdapter<T> : IAsyncEnumerator<T> {
private readonly IEnumerable<T> _source;
private IEnumerator<T> _enumerator;

public IAsyncEnumeratorCollectionAdapter(ICollection<T> source) {
_source = source;
}

public ValueTask DisposeAsync() {
_enumerator = null;
return new ValueTask(Task.CompletedTask);
}

public ValueTask<bool> MoveNextAsync() {
if (_enumerator == null) {
_enumerator = _source.GetEnumerator();
}
return new ValueTask<bool>(_enumerator.MoveNext());
}

public T Current => _enumerator.Current;
}


public static string LdapValue(this SecurityIdentifier s)
Expand Down
20 changes: 20 additions & 0 deletions src/CommonLib/LdapConfig.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.DirectoryServices.Protocols;
using System.Text;

namespace SharpHoundCommonLib
{
Expand All @@ -13,6 +14,7 @@ public class LdapConfig
public bool DisableSigning { get; set; } = false;
public bool DisableCertVerification { get; set; } = false;
public AuthType AuthType { get; set; } = AuthType.Kerberos;
public int MaxConcurrentQueries { get; set; } = 15;

//Returns the port for connecting to LDAP. Will always respect a user's overridden config over anything else
public int GetPort(bool ssl)
Expand All @@ -32,5 +34,23 @@ public int GetGCPort(bool ssl)
{
return ssl ? 3269 : 3268;
}

public override string ToString() {
var sb = new StringBuilder();
sb.AppendLine($"Server: {Server}");
sb.AppendLine($"Port: {Port}");
sb.AppendLine($"SSLPort: {GetPort(true)}");
sb.AppendLine($"ForceSSL: {GetPort(false)}");
sb.AppendLine($"AuthType: {AuthType.ToString()}");
sb.AppendLine($"MaxConcurrentQueries: {MaxConcurrentQueries}");
if (!string.IsNullOrWhiteSpace(Username)) {
sb.AppendLine($"Username: {Username}");
}

if (!string.IsNullOrWhiteSpace(Password)) {
sb.AppendLine($"Password: {new string('*', Password.Length)}");
}
return sb.ToString();
}
}
}
Loading

0 comments on commit 58f47ba

Please sign in to comment.