diff --git a/src/CommonLib/ConnectionPoolManager.cs b/src/CommonLib/ConnectionPoolManager.cs index f7f015f8..e649a653 100644 --- a/src/CommonLib/ConnectionPoolManager.cs +++ b/src/CommonLib/ConnectionPoolManager.cs @@ -7,7 +7,17 @@ using SharpHoundCommonLib.Processors; namespace SharpHoundCommonLib { - public class ConnectionPoolManager : IDisposable{ + public interface ILdapConnectionProvider { + Task<(bool Success, string Message)> TestDomainConnection(string identifier, bool globalCatalog); + Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetLdapConnection( + string identifier, bool globalCatalog); + Task<(bool Success, LdapConnectionWrapper connectionWrapper, string Message)> GetLdapConnectionForServer( + string identifier, string server, bool globalCatalog); + void ReleaseConnection(LdapConnectionWrapper connectionWrapper, bool connectionFaulted = false); + void Dispose(); + } + + public class ConnectionPoolManager : ILdapConnectionProvider, IDisposable{ private readonly ConcurrentDictionary _pools = new(); private readonly LdapConfig _ldapConfig; private readonly string[] _translateNames = { "Administrator", "admin" }; diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index bc087b69..5aedc4aa 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -53,7 +53,7 @@ private readonly ConcurrentDictionary private readonly string[] _translateNames = { "Administrator", "admin" }; private LdapConfig _ldapConfig = new(); - private ConnectionPoolManager _connectionPool; + private ILdapConnectionProvider _connectionPool; private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2); private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20); @@ -82,6 +82,13 @@ public LdapUtils() { _connectionPool = new ConnectionPoolManager(_ldapConfig, _log); } + public LdapUtils(ILdapConnectionProvider ldapConnectionProvider) { + _nativeMethods = new NativeMethods(); + _portScanner = new PortScanner(); + _log = Logging.LogProvider.CreateLogger("LDAPUtils"); + _connectionPool = ldapConnectionProvider; + } + public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, ILogger log = null) { _nativeMethods = nativeMethods ?? new NativeMethods(); _portScanner = scanner ?? new PortScanner(); diff --git a/test/unit/LdapUtilsQueryTest.cs b/test/unit/LdapUtilsQueryTest.cs new file mode 100644 index 00000000..c7a2b54a --- /dev/null +++ b/test/unit/LdapUtilsQueryTest.cs @@ -0,0 +1,135 @@ +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Moq; +using System.DirectoryServices.Protocols; +using SharpHoundCommonLib; + +public class RangedRetrievalTests +{ + private Mock _mockConnectionPool; + private LdapUtils _utils; + + public RangedRetrievalTests() + { + _mockConnectionPool = new Mock(); + _utils = new LdapUtils(); + } + + // [Fact] + // public async Task RangedRetrieval_SuccessfulRetrieval_ReturnsExpectedResults() + // { + // // Arrange + // var distinguishedName = "CN=TestUser,DC=example,DC=com"; + // var attributeName = "member"; + // var domain = "example.com"; + + // var connectionWrapper = new Mock(); + // var connection = new Mock(); + // connectionWrapper.SetupGet(x => x.Connection).Returns(connection.Object); + + // _mockConnectionPool.Setup(x => x.GetLdapConnection(domain, false)) + // .ReturnsAsync((true, connectionWrapper.Object, null)); + + // var searchResponse = new Mock(); + // var entry = new SearchResultEntry + // { + // Attributes = + // { + // new DirectoryAttribute("member;range=0-*", "CN=Member1,DC=example,DC=com", "CN=Member2,DC=example,DC=com") + // } + // }; + // searchResponse.Entries.Add(entry); + + // connection.Setup(x => x.SendRequest(It.IsAny())) + // .Returns(searchResponse); + + // // Act + // var results = new List>(); + // await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName)) + // { + // results.Add(result); + // } + + // // Assert + // Assert.Equal(2, results.Count); + // Assert.True(results[0].IsSuccess); + // Assert.Equal("CN=Member1,DC=example,DC=com", results[0].Value); + // Assert.True(results[1].IsSuccess); + // Assert.Equal("CN=Member2,DC=example,DC=com", results[1].Value); + // } + + [Fact] + public async Task RangedRetrieval_ConnectionFailure_ReturnsFailResult() + { + // Arrange + var distinguishedName = "CN=TestUser,DC=example,DC=com"; + var attributeName = "member"; + + // Act + var results = new List>(); + await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName)) + { + results.Add(result); + } + + // Assert + Assert.Single(results); + Assert.False(results[0].IsSuccess); + Assert.Equal("All attempted connections failed", results[0].Error); + } + + // [Fact] + // public async Task RangedRetrieval_ServerDown_RetriesAndRecovers() + // { + // // Arrange + // var distinguishedName = "CN=TestUser,DC=example,DC=com"; + // var attributeName = "member"; + // var domain = "example.com"; + + // var connectionWrapper = new Mock(); + // var connection = new Mock(); + + // // TODO : setup + + // // Act + // var results = new List>(); + // await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName)) + // { + // results.Add(result); + // } + + // // TODO Assert + // } + + [Fact] + public async Task RangedRetrieval_CancellationRequested_StopsRetrieval() + { + // Arrange + var distinguishedName = "CN=TestUser,DC=example,DC=com"; + var attributeName = "member"; + var domain = "example.com"; + + var connectionWrapper = new Mock(null, null, false, string.Empty); + var connection = new Mock(); + + _mockConnectionPool.Setup(x => x.GetLdapConnection(domain, false)) + .ReturnsAsync((true, connectionWrapper.Object, null)); + + _utils = new LdapUtils(_mockConnectionPool.Object); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act + var results = new List>(); + await foreach (var result in _utils.RangedRetrieval(distinguishedName, attributeName, cts.Token)) + { + results.Add(result); + } + + // Assert + Assert.False(results[0].IsSuccess); + } +} \ No newline at end of file