Skip to content

Commit

Permalink
chore: run formatter, add a small sleep to prevent race
Browse files Browse the repository at this point in the history
  • Loading branch information
rvazarkar committed Feb 6, 2024
1 parent efa6a28 commit 42f0805
Showing 1 changed file with 52 additions and 27 deletions.
79 changes: 52 additions & 27 deletions src/CommonLib/LDAPUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private static readonly ConcurrentDictionary<string, ResolvedWellKnownPrincipal>
private LDAPConfig _ldapConfig = new();
private readonly ManualResetEvent _connectionResetEvent = new(false);
private readonly object _lockObj = new();


/// <summary>
/// Creates a new instance of LDAP Utils with defaults
Expand Down Expand Up @@ -104,11 +104,13 @@ public void SetLDAPConfig(LDAPConfig config)
{
kv.Value.Dispose();
}

_globalCatalogConnections.Clear();
foreach (var kv in _ldapConnections)
{
kv.Value.Dispose();
}

_ldapConnections.Clear();
}

Expand Down Expand Up @@ -232,28 +234,35 @@ public TypedPrincipal ResolveIDAndType(string id, string fallbackDomain)
return new TypedPrincipal(id, type);
}

public TypedPrincipal ResolveCertTemplateByProperty(string propValue, string propertyName, string containerDN, string domainName)
public TypedPrincipal ResolveCertTemplateByProperty(string propValue, string propertyName, string containerDN,
string domainName)
{
var filter = new LDAPFilter().AddCertificateTemplates().AddFilter(propertyName + "=" + propValue, true);
var res = QueryLDAP(filter.GetFilter(), SearchScope.OneLevel,
CommonProperties.TypeResolutionProps, adsPath: containerDN, domainName: domainName);
CommonProperties.TypeResolutionProps, adsPath: containerDN, domainName: domainName);

if (res == null)
{
_log.LogWarning("Could not find certificate template with '{propertyName}:{propValue}' under {containerDN}; null result", propertyName, propValue, containerDN);
_log.LogWarning(
"Could not find certificate template with '{propertyName}:{propValue}' under {containerDN}; null result",
propertyName, propValue, containerDN);
return null;
}

List<ISearchResultEntry> resList = new List<ISearchResultEntry>(res);
if (resList.Count == 0)
{
_log.LogWarning("Could not find certificate template with '{propertyName}:{propValue}' under {containerDN}; empty list", propertyName, propValue, containerDN);
_log.LogWarning(
"Could not find certificate template with '{propertyName}:{propValue}' under {containerDN}; empty list",
propertyName, propValue, containerDN);
return null;
}

if (resList.Count > 1)
{
_log.LogWarning("Found more than one certificate template with '{propertyName}:{propValue}' under {containerDN}", propertyName, propValue, containerDN);
_log.LogWarning(
"Found more than one certificate template with '{propertyName}:{propValue}' under {containerDN}",
propertyName, propValue, containerDN);
return null;
}

Expand Down Expand Up @@ -523,7 +532,8 @@ public IEnumerable<string> DoRangedRetrieval(string distinguishedName, string at
}
catch (Exception e)
{
_log.LogError(e, "Error doing ranged retrieval for {Attribute} on {Dn}", attributeName, distinguishedName);
_log.LogError(e, "Error doing ranged retrieval for {Attribute} on {Dn}", attributeName,
distinguishedName);
yield break;
}

Expand Down Expand Up @@ -847,7 +857,8 @@ public IEnumerable<ISearchResultEntry> QueryLDAP(string ldapFilter, SearchScope
if (queryParams.Exception != null)
{
_log.LogWarning("Failed to setup LDAP Query Filter: {Message}", queryParams.Exception.Message);
if (throwException) throw new LDAPQueryException("Failed to setup LDAP Query Filter", queryParams.Exception);
if (throwException)
throw new LDAPQueryException("Failed to setup LDAP Query Filter", queryParams.Exception);
yield break;
}

Expand All @@ -871,21 +882,21 @@ public IEnumerable<ISearchResultEntry> QueryLDAP(string ldapFilter, SearchScope
if (response != null)
pageResponse = (PageResultResponseControl)response.Controls
.Where(x => x is PageResultResponseControl).DefaultIfEmpty(null).FirstOrDefault();
}catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown &&
retryCount < MaxRetries)
}
catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown &&
retryCount < MaxRetries)
{
/*A ServerDown exception indicates that our connection is no longer valid for one of many reasons.
However, this function is generally called by multiple threads, so we need to be careful in recreating
the connection. Using a semaphore, we can ensure that only one thread is actually recreating the connection
while the other threads that hit the ServerDown exception simply wait. The initial caller will hold the semaphore
and do a backoff delay before trying to make a new connection which will replace the existing connection in the
and do a backoff delay before trying to make a new connection which will replace the existing connection in the
_ldapConnections cache. Other threads will retrieve the new connection from the cache instead of making a new one
This minimizes overhead of new connections while still fixing our core problem.*/


//Always increment retry count
retryCount++;

//Attempt to acquire a lock
if (Monitor.TryEnter(_lockObj))
{
Expand Down Expand Up @@ -917,13 +928,17 @@ public IEnumerable<ISearchResultEntry> QueryLDAP(string ldapFilter, SearchScope
{
//If someone else is holding the reset event, we want to just wait and then pull the newly created connection out of the cache
//This event will be released after the first entrant thread is done making a new connection
//The thread.sleep is to prevent a potential, very unlikely race
Thread.Sleep(50);
_connectionResetEvent.WaitOne();
conn = CreateNewConnection(domainName, globalCatalog);
}

backoffDelay = GetNextBackoff(retryCount);
continue;
}catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.Busy && retryCount < MaxRetries) {
}
catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.Busy && retryCount < MaxRetries)
{
retryCount++;
backoffDelay = GetNextBackoff(retryCount);
continue;
Expand All @@ -950,7 +965,7 @@ public IEnumerable<ISearchResultEntry> QueryLDAP(string ldapFilter, SearchScope
$"LDAP Exception in Loop: {le.ErrorCode}. {le.ServerErrorMessage}. {le.Message}. Filter: {ldapFilter}. Domain: {domainName}",
le);
}

yield break;
}
catch (Exception e)
Expand Down Expand Up @@ -984,7 +999,8 @@ public IEnumerable<ISearchResultEntry> QueryLDAP(string ldapFilter, SearchScope
}
}

private LdapConnection CreateNewConnection(string domainName = null, bool globalCatalog = false, bool skipCache = false)
private LdapConnection CreateNewConnection(string domainName = null, bool globalCatalog = false,
bool skipCache = false)
{
var task = globalCatalog
? Task.Run(() => CreateGlobalCatalogConnection(domainName, _ldapConfig.AuthType))
Expand Down Expand Up @@ -1394,7 +1410,8 @@ private SearchRequest CreateSearchRequest(string filter, SearchScope scope, stri
var domain = GetDomain(domainName)?.Name ?? domainName;

if (domain == null)
throw new LDAPQueryException($"Unable to create search request: GetDomain call failed for {domainName}");
throw new LDAPQueryException(
$"Unable to create search request: GetDomain call failed for {domainName}");

var adPath = adsPath?.Replace("LDAP://", "") ?? $"DC={domain.Replace(".", ",DC=")}";

Expand Down Expand Up @@ -1425,7 +1442,9 @@ private async Task<LdapConnection> CreateGlobalCatalogConnection(string domainNa
var domain = GetDomain(domainName);
if (domain == null)
{
_log.LogDebug("Unable to create global catalog connection for domain {DomainName}: GetDomain failed", domainName);
_log.LogDebug(
"Unable to create global catalog connection for domain {DomainName}: GetDomain failed",
domainName);
throw new LDAPQueryException($"GetDomain call failed for {domainName}");
}

Expand Down Expand Up @@ -1479,8 +1498,10 @@ private async Task<LdapConnection> CreateLDAPConnection(string domainName = null
var domain = GetDomain(domainName);
if (domain == null)
{
_log.LogDebug("Unable to create ldap connection for domain {DomainName}: GetDomain failed", domainName);
throw new LDAPQueryException($"Error creating LDAP connection: GetDomain call failed for {domainName}");
_log.LogDebug("Unable to create ldap connection for domain {DomainName}: GetDomain failed",
domainName);
throw new LDAPQueryException(
$"Error creating LDAP connection: GetDomain call failed for {domainName}");
}

if (!_domainControllerCache.TryGetValue(domain.Name, out targetServer))
Expand Down Expand Up @@ -1623,7 +1644,8 @@ public int GetDomainRangeSize(string domainName = null, int defaultRangeSize = 7
//Default to a page size of 750 for safety
if (domainPath == null)
{
_log.LogDebug("Unable to resolve domain {Domain} to distinguishedname to get page size", domainName ?? "current domain");
_log.LogDebug("Unable to resolve domain {Domain} to distinguishedname to get page size",
domainName ?? "current domain");
return defaultRangeSize;
}

Expand All @@ -1635,7 +1657,8 @@ public int GetDomainRangeSize(string domainName = null, int defaultRangeSize = 7
var configPath = CommonPaths.CreateDNPath(CommonPaths.QueryPolicyPath, domainPath);
var enumerable = QueryLDAP("(objectclass=*)", SearchScope.Base, null, adsPath: configPath);
var config = enumerable.DefaultIfEmpty(null).FirstOrDefault();
var pageSize = config?.GetArrayProperty(LDAPProperties.LdapAdminLimits).FirstOrDefault(x => x.StartsWith("MaxPageSize", StringComparison.OrdinalIgnoreCase));
var pageSize = config?.GetArrayProperty(LDAPProperties.LdapAdminLimits)
.FirstOrDefault(x => x.StartsWith("MaxPageSize", StringComparison.OrdinalIgnoreCase));
if (pageSize == null)
{
_log.LogDebug("No LDAPAdminLimits object found for {Domain}", domainName);
Expand All @@ -1646,7 +1669,8 @@ public int GetDomainRangeSize(string domainName = null, int defaultRangeSize = 7
if (int.TryParse(pageSize.Split('=').Last(), out parsedPageSize))
{
_ldapRangeSizeCache.TryAdd(domainPath.ToUpper(), parsedPageSize);
_log.LogInformation("Found page size {PageSize} for {Domain}", parsedPageSize, domainName ?? "current domain");
_log.LogInformation("Found page size {PageSize} for {Domain}", parsedPageSize,
domainName ?? "current domain");
return parsedPageSize;
}

Expand Down Expand Up @@ -1700,12 +1724,13 @@ public string GetSchemaPath(string domainName)

public bool IsDomainController(string computerObjectId, string domainName)
{
var filter = new LDAPFilter().AddFilter(LDAPProperties.ObjectSID + "=" + computerObjectId, true).AddFilter(CommonFilters.DomainControllers, true);
var filter = new LDAPFilter().AddFilter(LDAPProperties.ObjectSID + "=" + computerObjectId, true)
.AddFilter(CommonFilters.DomainControllers, true);
var res = QueryLDAP(filter.GetFilter(), SearchScope.Subtree,
CommonProperties.ObjectID, domainName: domainName);
CommonProperties.ObjectID, domainName: domainName);
if (res.Count() > 0)
return true;
return false;
}
}
}
}

0 comments on commit 42f0805

Please sign in to comment.