Skip to content

Commit

Permalink
WIP: utils query tests, especially around exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
definitelynotagoblin committed Jul 29, 2024
1 parent 1ea7d1e commit 8a9ef4e
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/CommonLib/ConnectionPoolManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
using SharpHoundCommonLib.Processors;

namespace SharpHoundCommonLib {
public class ConnectionPoolManager : IDisposable{
public interface ILdapConnectionProvider {
Task<(bool Success, string Message)> TestDomainConnection(string identifier, bool globalCatalog);
Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection(
string identifier, bool globalCatalog);
Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> GetLdapConnectionForServer(
string identifier, string server, bool globalCatalog);
void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false);
void Dispose();
}

public class ConnectionPoolManager : ILdapConnectionProvider, IDisposable{
private readonly ConcurrentDictionary<string, LdapConnectionPool> _pools = new();
private readonly LdapConfig _ldapConfig;
private readonly string[] _translateNames = { "Administrator", "admin" };
Expand Down
9 changes: 8 additions & 1 deletion src/CommonLib/LdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ private readonly ConcurrentDictionary<string, string>
private readonly string[] _translateNames = { "Administrator", "admin" };
private LdapConfig _ldapConfig = new();

private ConnectionPoolManager _connectionPool;
private ILdapConnectionProvider _connectionPool;

private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2);
private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20);
Expand Down Expand Up @@ -82,6 +82,13 @@ public LdapUtils() {
_connectionPool = new ConnectionPoolManager(_ldapConfig, _log);
}

public LdapUtils(ILdapConnectionProvider ldapConnectionProvider) {
_nativeMethods = new NativeMethods();
_portScanner = new PortScanner();
_log = Logging.LogProvider.CreateLogger("LDAPUtils");
_connectionPool = ldapConnectionProvider;
}

public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, ILogger log = null) {
_nativeMethods = nativeMethods ?? new NativeMethods();
_portScanner = scanner ?? new PortScanner();
Expand Down
135 changes: 135 additions & 0 deletions test/unit/LdapUtilsQueryTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Moq;
using System.DirectoryServices.Protocols;
using SharpHoundCommonLib;

public class RangedRetrievalTests
{
private Mock<ILdapConnectionProvider> _mockConnectionPool;
private LdapUtils _utils;

public RangedRetrievalTests()
{
_mockConnectionPool = new Mock<ILdapConnectionProvider>();
_utils = new LdapUtils();
}

// [Fact]
// public async Task RangedRetrieval_SuccessfulRetrieval_ReturnsExpectedResults()
// {
// // Arrange
// var distinguishedName = "CN=TestUser,DC=example,DC=com";
// var attributeName = "member";
// var domain = "example.com";

// var connectionWrapper = new Mock<LdapConnectionWrapper>();
// var connection = new Mock<LdapConnection>();
// connectionWrapper.SetupGet(x => x.Connection).Returns(connection.Object);

// _mockConnectionPool.Setup(x => x.GetLdapConnection(domain, false))
// .ReturnsAsync((true, connectionWrapper.Object, null));

// var searchResponse = new Mock<SearchResponse>();
// var entry = new SearchResultEntry
// {
// Attributes =
// {
// new DirectoryAttribute("member;range=0-*", "CN=Member1,DC=example,DC=com", "CN=Member2,DC=example,DC=com")
// }
// };
// searchResponse.Entries.Add(entry);

// connection.Setup(x => x.SendRequest(It.IsAny<SearchRequest>()))
// .Returns(searchResponse);

// // Act
// var results = new List<Result<string>>();
// await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName))
// {
// results.Add(result);
// }

// // Assert
// Assert.Equal(2, results.Count);
// Assert.True(results[0].IsSuccess);
// Assert.Equal("CN=Member1,DC=example,DC=com", results[0].Value);
// Assert.True(results[1].IsSuccess);
// Assert.Equal("CN=Member2,DC=example,DC=com", results[1].Value);
// }

[Fact]
public async Task RangedRetrieval_ConnectionFailure_ReturnsFailResult()
{
// Arrange
var distinguishedName = "CN=TestUser,DC=example,DC=com";
var attributeName = "member";

// Act
var results = new List<Result<string>>();
await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName))
{
results.Add(result);
}

// Assert
Assert.Single(results);
Assert.False(results[0].IsSuccess);
Assert.Equal("All attempted connections failed", results[0].Error);
}

// [Fact]
// public async Task RangedRetrieval_ServerDown_RetriesAndRecovers()
// {
// // Arrange
// var distinguishedName = "CN=TestUser,DC=example,DC=com";
// var attributeName = "member";
// var domain = "example.com";

// var connectionWrapper = new Mock<LdapConnectionWrapper>();
// var connection = new Mock<LdapConnection>();

// // TODO : setup

// // Act
// var results = new List<Result<string>>();
// await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName))
// {
// results.Add(result);
// }

// // TODO Assert
// }

[Fact]
public async Task RangedRetrieval_CancellationRequested_StopsRetrieval()
{
// Arrange
var distinguishedName = "CN=TestUser,DC=example,DC=com";
var attributeName = "member";
var domain = "example.com";

var connectionWrapper = new Mock<LdapConnectionWrapper>(null, null, false, string.Empty);
var connection = new Mock<LdapConnection>();

_mockConnectionPool.Setup(x => x.GetLdapConnection(domain, false))
.ReturnsAsync((true, connectionWrapper.Object, null));

_utils = new LdapUtils(_mockConnectionPool.Object);

var cts = new CancellationTokenSource();
cts.Cancel();

// Act
var results = new List<Result<string>>();
await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName, cts.Token))
{
results.Add(result);
}

// Assert
Assert.False(results[0].IsSuccess);
}
}

0 comments on commit 8a9ef4e

Please sign in to comment.