Skip to content

Commit

Permalink
wip: implement connection pool
Browse files Browse the repository at this point in the history
  • Loading branch information
rvazarkar committed Jun 18, 2024
1 parent 3fcbc9d commit 6d5cd65
Show file tree
Hide file tree
Showing 4 changed files with 632 additions and 63 deletions.
284 changes: 222 additions & 62 deletions src/CommonLib/LDAPUtilsNew.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,48 +40,179 @@ public class LDAPUtilsNew {
private readonly object _lockObj = new();
private readonly ManualResetEvent _connectionResetEvent = new(false);

public async IAsyncEnumerable<LdapResult<ISearchResultEntry>> PagedQuery(LdapQueryParameters queryParameters,
public async IAsyncEnumerable<LdapResult<ISearchResultEntry>> Query(LdapQueryParameters queryParameters,
[EnumeratorCancellation] CancellationToken cancellationToken = new()) {
//Always force create a new connection
var (success, connectionWrapper, message) = await GetLdapConnection(queryParameters.DomainName,
queryParameters.GlobalCatalog, true);
if (!success) {
_log.LogDebug("PagedQuery failure: unable to create a connection: {Reason}\n{Info}", message,
var setupResult = await SetupLdapQuery(queryParameters, true);

if (!setupResult.Success) {
_log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.Message,
queryParameters.GetQueryInfo());
yield return new LdapResult<ISearchResultEntry> {
Error = $"Unable to create a connection: {message}",
QueryInfo = queryParameters.GetQueryInfo()
};
yield break;
}

//This should never happen as far as I know, so just checking for safety
if (connectionWrapper == null) {
_log.LogError("PagedQuery failure: ldap connection is null\n{Info}", queryParameters.GetQueryInfo());
yield return new LdapResult<ISearchResultEntry> {
Error = "Connection is null",
QueryInfo = queryParameters.GetQueryInfo()
};
var searchRequest = setupResult.SearchRequest;
var connectionWrapper = setupResult.ConnectionWrapper;
var connection = connectionWrapper.Connection;
var serverName = setupResult.Server;

if (serverName == null) {
_log.LogWarning("PagedQuery - Failed to get a server name for connection, retry not possible");
}

if (cancellationToken.IsCancellationRequested) {
yield break;
}

var connection = connectionWrapper.Connection;
var queryRetryCount = 0;
var busyRetryCount = 0;
LdapResult<ISearchResultEntry> tempResult = null;
while (queryRetryCount < MaxRetries) {
SearchResponse response = null;
try {
_log.LogTrace("Sending ldap request - {Info}", queryParameters.GetQueryInfo());
response = (SearchResponse)connection.SendRequest(searchRequest);
}
catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) {
/*
* If we dont have a servername, we're not going to be able to re-establish a connection here. Page cookies are only valid for the server they were generated on. Bail out.
*/
if (serverName == null) {
_log.LogError(
"Query - Received server down exception without a known servername. Unable to generate new connection\n{Info}",
queryParameters.GetQueryInfo());
yield break;
}

/*A ServerDown exception indicates that our connection is no longer valid for one of many reasons.
However, this function is generally called by multiple threads, so we need to be careful in recreating
the connection. Using a semaphore, we can ensure that only one thread is actually recreating the connection
while the other threads that hit the ServerDown exception simply wait. The initial caller will hold the semaphore
and do a backoff delay before trying to make a new connection which will replace the existing connection in the
_ldapConnections cache. Other threads will retrieve the new connection from the cache instead of making a new one
This minimizes overhead of new connections while still fixing our core problem.*/

//Increment our query retry count
queryRetryCount++;

//Attempt to acquire a lock
if (Monitor.TryEnter(_lockObj)) {
//Signal the reset event to ensure no everyone else waits
_connectionResetEvent.Reset();
try {
//Try up to MaxRetries time to make a new connection, ensuring we're not using the cache
for (var retryCount = 0; retryCount < MaxRetries; retryCount++) {
var backoffDelay = GetNextBackoff(retryCount);
await Task.Delay(backoffDelay, cancellationToken);
var (success, newConnectionWrapper, _) = await GetLdapConnection(queryParameters, true);
if (success) {
newConnectionWrapper.CopyContexts(connectionWrapper);
connectionWrapper.Connection.Dispose();
connectionWrapper = newConnectionWrapper;
break;
}

if (retryCount == MaxRetries - 1) {
_log.LogError("Query - Failed to get a new connection after ServerDown.\n{Info}",
queryParameters.GetQueryInfo());
yield break;
}
}
}finally{
_connectionResetEvent.Set();
Monitor.Exit(_lockObj);
}
}
else {
//If someone else is holding the reset event, we want to just wait and then pull the newly created connection out of the cache
//This event will be released after the first entrant thread is done making a new connection
//The thread.sleep is to prevent a potential, very unlikely race
Thread.Sleep(50);
_connectionResetEvent.WaitOne();

//At this point, our connection reset event should be tripped, and there should be a new connection on the cache
var (success, newConnectionWrapper, _) = await GetLdapConnection(queryParameters);
if (!success) {
_log.LogError("Query - Failed to recover from ServerDown error\n{Info}", queryParameters.GetQueryInfo());
yield break;
}

newConnectionWrapper.CopyContexts(connectionWrapper);
connectionWrapper = newConnectionWrapper;
connection = connectionWrapper.Connection;
}
}
catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) {
/*
* If we get a busy error, we want to do an exponential backoff, but maintain the current connection
* The expectation is that given enough time, the server should stop being busy and service our query appropriately
*/
busyRetryCount++;
var backoffDelay = GetNextBackoff(busyRetryCount);
await Task.Delay(backoffDelay, cancellationToken);
}
catch (LdapException le) {
//No point in printing local exceptions because they're literally worthless
tempResult = new LdapResult<ISearchResultEntry>() {
Error =
$"Query - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})",
QueryInfo = queryParameters.GetQueryInfo()
};
}
catch (Exception e) {
tempResult = new LdapResult<ISearchResultEntry> {
Error =
$"PagedQuery - Caught unrecoverable exception: {e.Message}",
QueryInfo = queryParameters.GetQueryInfo()
};
}

//Pull the server name from the connection for retry logic later
if (!connectionWrapper.GetServer(out var serverName)) {
_log.LogDebug("PagedQuery: Failed to get server value");
serverName = null;
if (tempResult != null) {
yield return tempResult;
yield break;
}
}
}

if (!CreateSearchRequest(queryParameters, ref connectionWrapper, out var searchRequest)) {
_log.LogError("PagedQuery failure: unable to resolve search base\n{Info}", queryParameters.GetQueryInfo());
yield return new LdapResult<ISearchResultEntry> {
Error = "Unable to create search request",
QueryInfo = queryParameters.GetQueryInfo()
};
private bool SendSearchRequestWithRetryHandling(LdapConnectionWrapperNew connectionWrapper,
SearchRequest searchRequest, out SearchResponse searchResponse) {
for (var retryCount = 0; retryCount < MaxRetries; retryCount++) {
try {
searchResponse = (SearchResponse)connectionWrapper.Connection.SendRequest(searchRequest);
return true;
}
catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown) {

}
}
}

private bool SendPagedSearchRequestWithRetryHandling(LdapConnectionWrapperNew connectionWrapper,
SearchRequest searchRequest, out SearchResponse searchResponse) {




}

public async IAsyncEnumerable<LdapResult<ISearchResultEntry>> PagedQuery(LdapQueryParameters queryParameters,
[EnumeratorCancellation] CancellationToken cancellationToken = new()) {
var setupResult = await SetupLdapQuery(queryParameters, true);

if (!setupResult.Success) {
_log.LogInformation("PagedQuery - Failure during query setup: {Reason}\n{Info}", setupResult.Message,
queryParameters.GetQueryInfo());
yield break;
}

var searchRequest = setupResult.SearchRequest;
var connectionWrapper = setupResult.ConnectionWrapper;
var connection = connectionWrapper.Connection;
var serverName = setupResult.Server;

if (serverName == null) {
_log.LogWarning("PagedQuery - Failed to get a server name for connection, retry not possible");
}

var pageControl = new PageResultRequestControl(500);
searchRequest.Controls.Add(pageControl);

Expand Down Expand Up @@ -188,17 +319,16 @@ public async IAsyncEnumerable<LdapResult<ISearchResultEntry>> PagedQuery(LdapQue
Value = entry
};
}

if (pageResponse.Cookie.Length == 0 || response.Entries.Count == 0 ||
cancellationToken.IsCancellationRequested)
yield break;

pageControl.Cookie = pageResponse.Cookie;
}
}

private static TimeSpan GetNextBackoff(int retryCount)
{

private static TimeSpan GetNextBackoff(int retryCount) {
return TimeSpan.FromSeconds(Math.Min(
MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount),
MaxBackoffDelay.TotalSeconds));
Expand Down Expand Up @@ -303,6 +433,44 @@ private bool
}
}

private async Task<LdapQuerySetupResult> SetupLdapQuery(LdapQueryParameters queryParameters,
bool forceNewConnection = false) {
var result = new LdapQuerySetupResult();
var (success, connectionWrapper, message) = await GetLdapConnection(queryParameters, forceNewConnection);
if (!success) {
result.Success = false;
result.Message = $"Unable to create a connection: {message}";
return result;
}

//This should never happen as far as I know, so just checking for safety
if (connectionWrapper.Connection == null) {
result.Success = false;
result.Message = $"Connection object is null";
return result;
}

if (!CreateSearchRequest(queryParameters, ref connectionWrapper, out var searchRequest)) {
result.Success = false;
result.Message = "Failed to create search request";
return result;
}

if (GetServerFromConnection(connectionWrapper.Connection, out var server)) {
result.Server = server;
}

result.Success = true;
result.SearchRequest = searchRequest;
result.ConnectionWrapper = connectionWrapper;
return result;
}

private Task<(bool Success, LdapConnectionWrapperNew Connection, string Message )> GetLdapConnection(
LdapQueryParameters queryParameters, bool forceCreateNewConnection = false) {
return GetLdapConnection(queryParameters.DomainName, queryParameters.GlobalCatalog, forceCreateNewConnection);
}

private async Task<(bool Success, LdapConnectionWrapperNew Connection, string Message )> GetLdapConnection(
string domainName, bool globalCatalog = false,
bool forceCreateNewConnection = false) {
Expand Down Expand Up @@ -442,7 +610,8 @@ private bool
}
}

private async Task<(bool success, LdapConnectionWrapperNew connection)> CreateLDAPConnectionWithPortCheck(string target, bool globalCatalog) {
private async Task<(bool success, LdapConnectionWrapperNew connection)> CreateLDAPConnectionWithPortCheck(
string target, bool globalCatalog) {
if (globalCatalog) {
if (await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(true)) || (!_ldapConfig.ForceSSL &&
await _portScanner.CheckPort(target, _ldapConfig.GetGCPort(false))))
Expand All @@ -457,20 +626,18 @@ await _portScanner.CheckPort(target, _ldapConfig.GetPort(false))))
return (false, null);
}

private LdapConnectionWrapperNew CheckCacheConnection(LdapConnectionWrapperNew connectionWrapper, string domainName, bool globalCatalog, bool forceCreateNewConnection)
{
private LdapConnectionWrapperNew CheckCacheConnection(LdapConnectionWrapperNew connectionWrapper, string domainName,
bool globalCatalog, bool forceCreateNewConnection) {
string cacheIdentifier;
if (_ldapConfig.Server != null)
{
if (_ldapConfig.Server != null) {
cacheIdentifier = _ldapConfig.Server;
}
else
{
if (!GetDomainSidFromDomainName(domainName, out cacheIdentifier))
{
else {
if (!GetDomainSidFromDomainName(domainName, out cacheIdentifier)) {
//This is kinda gross, but its another way to get the correct domain sid
if (!connectionWrapper.Connection.GetNamingContextSearchBase(NamingContext.Default, out var searchBase) || !GetDomainSidFromConnection(connectionWrapper.Connection, searchBase, out cacheIdentifier))
{
if (!connectionWrapper.Connection.GetNamingContextSearchBase(NamingContext.Default,
out var searchBase) || !GetDomainSidFromConnection(connectionWrapper.Connection, searchBase,
out cacheIdentifier)) {
/*
* If we get here, we couldn't resolve a domain sid, which is hella bad, but we also want to keep from creating a shitton of new connections
* Cache using the domainname and pray it all works out
Expand All @@ -479,27 +646,22 @@ private LdapConnectionWrapperNew CheckCacheConnection(LdapConnectionWrapperNew c
}
}
}

if (forceCreateNewConnection)
{

if (forceCreateNewConnection) {
return _ldapConnectionCache.AddOrUpdate(cacheIdentifier, globalCatalog, connectionWrapper);
}

return _ldapConnectionCache.TryAdd(cacheIdentifier, globalCatalog, connectionWrapper);
}

private bool GetCachedConnection(string domain, bool globalCatalog, out LdapConnectionWrapperNew connection)
{

private bool GetCachedConnection(string domain, bool globalCatalog, out LdapConnectionWrapperNew connection) {
//If server is set via our config, we'll always just use this as the cache key
if (_ldapConfig.Server != null)
{
if (_ldapConfig.Server != null) {
return _ldapConnectionCache.TryGet(_ldapConfig.Server, globalCatalog, out connection);
}

if (GetDomainSidFromDomainName(domain, out var domainSid))
{
if (_ldapConnectionCache.TryGet(domainSid, globalCatalog, out connection))
{

if (GetDomainSidFromDomainName(domain, out var domainSid)) {
if (_ldapConnectionCache.TryGet(domainSid, globalCatalog, out connection)) {
return true;
}
}
Expand Down Expand Up @@ -599,10 +761,8 @@ private LdapConnection CreateBaseConnection(string directoryIdentifier, bool ssl
connection.SessionOptions.ReferralChasing = ReferralChasingOptions.None;
if (ssl) connection.SessionOptions.SecureSocketLayer = true;

if (_ldapConfig.DisableSigning) {
connection.SessionOptions.Sealing = false;
connection.SessionOptions.Signing = false;
}
connection.SessionOptions.Sealing = !_ldapConfig.DisableSigning;
connection.SessionOptions.Signing = !_ldapConfig.DisableSigning;

if (_ldapConfig.DisableCertVerification)
connection.SessionOptions.VerifyServerCertificate = (_, _) => true;
Expand Down Expand Up @@ -683,7 +843,7 @@ private bool TestLdapConnection(LdapConnection connection, string identifier, ou
return true;
}

private SearchRequest CreateSearchRequest(string distinguishedName, string ldapFilter, SearchScope searchScope,
public static SearchRequest CreateSearchRequest(string distinguishedName, string ldapFilter, SearchScope searchScope,
string[] attributes) {
var searchRequest = new SearchRequest(distinguishedName, ldapFilter,
searchScope, attributes);
Expand Down
Loading

0 comments on commit 6d5cd65

Please sign in to comment.