Skip to content

Commit

Permalink
CreateNewConnection shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
definitelynotagoblin committed Jul 25, 2024
1 parent 0c27e8b commit 9e1700c
Showing 1 changed file with 137 additions and 89 deletions.
226 changes: 137 additions & 89 deletions src/CommonLib/LdapConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.DirectoryServices.ActiveDirectory;
using System.DirectoryServices.Protocols;
using System.Net;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -39,14 +40,14 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c
public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetConnectionAsync() {
await _semaphore.WaitAsync();
if (!_connections.TryTake(out var connectionWrapper)) {
var (success, connection, message) = await CreateNewConnection();
if (!success) {
var result = await CreateNewConnection();
if (!result.IsSuccess) {
//If we didn't get a connection, immediately release the semaphore so we don't have hanging ones
_semaphore.Release();
return (false, null, message);
return (false, null, result.Error);
}

connectionWrapper = connection;
connectionWrapper = result.Value;
}

return (true, connectionWrapper, null);
Expand All @@ -57,25 +58,25 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c
await _semaphore.WaitAsync();

var result= CreateNewConnectionForServer(server, globalCatalog);
if (!result.Success) {
if (!result.IsSuccess) {
//If we didn't get a connection, immediately release the semaphore so we don't have hanging ones
_semaphore.Release();
}

return result;
return (true, result.Value, null);
}

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetGlobalCatalogConnectionAsync() {
await _semaphore.WaitAsync();
if (!_globalCatalogConnection.TryTake(out var connectionWrapper)) {
var (success, connection, message) = await CreateNewConnection(true);
if (!success) {
var result = await CreateNewConnection(true);
if (!result.IsSuccess) {
//If we didn't get a connection, immediately release the semaphore so we don't have hanging ones
_semaphore.Release();
return (false, null, message);
return (false, null, result.Error);
}

connectionWrapper = connection;
connectionWrapper = result.Value;
}

return (true, connectionWrapper, null);
Expand All @@ -102,98 +103,145 @@ public void Dispose() {
}
}

private async Task<(bool Success, LdapConnectionWrapper Connection, string Message)> CreateNewConnection(bool globalCatalog = false) {
try {
if (!string.IsNullOrWhiteSpace(_ldapConfig.Server)) {
return CreateNewConnectionForServer(_ldapConfig.Server, globalCatalog);
}
private async Task<Result<LdapConnectionWrapper>> CreateNewConnection(bool globalCatalog = false)
{
try
{
Result<LdapConnectionWrapper> result;

if (CreateLdapConnection(_identifier.ToUpper().Trim(), globalCatalog, out var connectionWrapper)) {
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 1. SSL: {SSl}", _identifier, connectionWrapper.Connection.SessionOptions.SecureSocketLayer);
return (true, connectionWrapper, "");
result = CreateConnectionUsingConfiguredServer(globalCatalog);
if (result.IsSuccess)
{
return result;
}

string tempDomainName;

var dsGetDcNameResult = _nativeMethods.CallDsGetDcName(null, _identifier,
(uint)(NetAPIEnums.DSGETDCNAME_FLAGS.DS_FORCE_REDISCOVERY |
NetAPIEnums.DSGETDCNAME_FLAGS.DS_RETURN_DNS_NAME |
NetAPIEnums.DSGETDCNAME_FLAGS.DS_DIRECTORY_SERVICE_REQUIRED));
if (dsGetDcNameResult.IsSuccess) {
tempDomainName = dsGetDcNameResult.Value.DomainName;

if (!tempDomainName.Equals(_identifier, StringComparison.OrdinalIgnoreCase) &&
CreateLdapConnection(tempDomainName, globalCatalog, out connectionWrapper)) {
_log.LogDebug(
"Successfully created ldap connection for domain: {Domain} using strategy 2 with name {NewName}",
_identifier, tempDomainName);
return (true, connectionWrapper, "");
}

var server = dsGetDcNameResult.Value.DomainControllerName.TrimStart('\\');

var result =
await CreateLDAPConnectionWithPortCheck(server, globalCatalog);
if (result.success) {
_log.LogDebug(
"Successfully created ldap connection for domain: {Domain} using strategy 3 to server {Server}",
_identifier, server);
return (true, result.connection, "");
}
result = CreateConnectionUsingIdentifier(globalCatalog);
if (result.IsSuccess)
{
return result;
}

if (!LdapUtils.GetDomain(_identifier, _ldapConfig, out var domainObject) || domainObject.Name == null) {
//If we don't get a result here, we effectively have no other ways to resolve this domain, so we'll just have to exit out
_log.LogDebug(
"Could not get domain object from GetDomain, unable to create ldap connection for domain {Domain}",
_identifier);
return (false, null, "Unable to get domain object for further strategies");
result = await CreateConnectionUsingDsGetDcName(globalCatalog);
if (result.IsSuccess)
{
return result;
}
tempDomainName = domainObject.Name.ToUpper().Trim();

if (!tempDomainName.Equals(_identifier, StringComparison.OrdinalIgnoreCase) &&
CreateLdapConnection(tempDomainName, globalCatalog, out connectionWrapper)) {
_log.LogDebug(
"Successfully created ldap connection for domain: {Domain} using strategy 4 with name {NewName}",
_identifier, tempDomainName);
return (true, connectionWrapper, "");
result = await CreateConnectionUsingGetDomain(globalCatalog);
if (result.IsSuccess)
{
return result;
}

var primaryDomainController = domainObject.PdcRoleOwner.Name;
var portConnectionResult =
await CreateLDAPConnectionWithPortCheck(primaryDomainController, globalCatalog);
if (portConnectionResult.success) {
_log.LogDebug(
"Successfully created ldap connection for domain: {Domain} using strategy 5 with to pdc {Server}",
_identifier, primaryDomainController);
return (true, portConnectionResult.connection, "");

return Result<LdapConnectionWrapper>.Fail("All attempted connections failed");
}
catch (Exception e)
{
_log.LogInformation(e, "Unable to connect to domain {Domain} using any strategy", _identifier);
return Result<LdapConnectionWrapper>.Fail($"Exception occurred: {e.Message}");
}
}

private Result<LdapConnectionWrapper> CreateConnectionUsingConfiguredServer(bool globalCatalog)
{
if (!string.IsNullOrWhiteSpace(_ldapConfig.Server))
{
return CreateNewConnectionForServer(_ldapConfig.Server, globalCatalog);
}
return Result<LdapConnectionWrapper>.Fail("No server configured");
}

private Result<LdapConnectionWrapper> CreateConnectionUsingIdentifier(bool globalCatalog)
{
if (CreateLdapConnection(_identifier.ToUpper().Trim(), globalCatalog, out var connectionWrapper))
{
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 1. SSL: {SSl}", _identifier, connectionWrapper.Connection.SessionOptions.SecureSocketLayer);
return Result<LdapConnectionWrapper>.Ok(connectionWrapper);
}
return Result<LdapConnectionWrapper>.Fail("Failed to create connection using identifier");
}

private async Task<Result<LdapConnectionWrapper>> CreateConnectionUsingDsGetDcName(bool globalCatalog)
{
var dsGetDcNameResult = _nativeMethods.CallDsGetDcName(null, _identifier,
(uint)(NetAPIEnums.DSGETDCNAME_FLAGS.DS_FORCE_REDISCOVERY |
NetAPIEnums.DSGETDCNAME_FLAGS.DS_RETURN_DNS_NAME |
NetAPIEnums.DSGETDCNAME_FLAGS.DS_DIRECTORY_SERVICE_REQUIRED));

if (!dsGetDcNameResult.IsSuccess)
{
return Result<LdapConnectionWrapper>.Fail("DsGetDcName call failed");
}

var tempDomainName = dsGetDcNameResult.Value.DomainName;

if (!tempDomainName.Equals(_identifier, StringComparison.OrdinalIgnoreCase))
{
if (CreateLdapConnection(tempDomainName, globalCatalog, out var connectionWrapper))
{
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 2 with name {NewName}", _identifier, tempDomainName);
return Result<LdapConnectionWrapper>.Ok(connectionWrapper);
}

foreach (DomainController dc in domainObject.DomainControllers) {
portConnectionResult =
await CreateLDAPConnectionWithPortCheck(dc.Name, globalCatalog);
if (portConnectionResult.success) {
_log.LogDebug(
"Successfully created ldap connection for domain: {Domain} using strategy 6 with to pdc {Server}",
_identifier, primaryDomainController);
return (true, portConnectionResult.connection, "");
}
}

var server = dsGetDcNameResult.Value.DomainControllerName.TrimStart('\\');
var result = await CreateLDAPConnectionWithPortCheck(server, globalCatalog);

if (result.success)
{
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 3 to server {Server}", _identifier, server);
return Result<LdapConnectionWrapper>.Ok(result.connection);
}

return Result<LdapConnectionWrapper>.Fail("Failed to create connection using DsGetDcName");
}

private async Task<Result<LdapConnectionWrapper>> CreateConnectionUsingGetDomain(bool globalCatalog)
{
if (!LdapUtils.GetDomain(_identifier, _ldapConfig, out var domainObject) || domainObject.Name == null)
{
_log.LogDebug("Could not get domain object from GetDomain, unable to create ldap connection for domain {Domain}", _identifier);
return Result<LdapConnectionWrapper>.Fail("Unable to get domain object for further strategies");
}

var tempDomainName = domainObject.Name.ToUpper().Trim();

if (!tempDomainName.Equals(_identifier, StringComparison.OrdinalIgnoreCase) &&
CreateLdapConnection(tempDomainName, globalCatalog, out var connectionWrapper))
{
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 4 with name {NewName}", _identifier, tempDomainName);
return Result<LdapConnectionWrapper>.Ok(connectionWrapper);
}

var primaryDomainController = domainObject.PdcRoleOwner.Name;
var portConnectionResult = await CreateLDAPConnectionWithPortCheck(primaryDomainController, globalCatalog);

if (portConnectionResult.success)
{
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 5 with to pdc {Server}", _identifier, primaryDomainController);
return Result<LdapConnectionWrapper>.Ok(portConnectionResult.connection); ;
}

foreach (DomainController dc in domainObject.DomainControllers)
{
portConnectionResult = await CreateLDAPConnectionWithPortCheck(dc.Name, globalCatalog);
if (portConnectionResult.success)
{
_log.LogDebug("Successfully created ldap connection for domain: {Domain} using strategy 6 with to dc {Server}", _identifier, dc.Name);
return Result<LdapConnectionWrapper>.Ok(portConnectionResult.connection);
}
} catch (Exception e) {
_log.LogInformation(e, "We will not be able to connect to domain {Domain} by any strategy, leaving it.", _identifier);
}

return (false, null, "All attempted connections failed");
return Result<LdapConnectionWrapper>.Fail("Failed to create connection using GetDomain");
}

private (bool Success, LdapConnectionWrapper Connection, string Message ) CreateNewConnectionForServer(string identifier, bool globalCatalog = false) {
if (CreateLdapConnection(identifier, globalCatalog, out var serverConnection)) {
return (true, serverConnection, "");

private Result<LdapConnectionWrapper> CreateNewConnectionForServer(string identifier, bool globalCatalog = false)
{
if (CreateLdapConnection(identifier, globalCatalog, out var serverConnection))
{
return Result<LdapConnectionWrapper>.Ok(serverConnection);
}

return (false, null, $"Failed to create ldap connection for {identifier}");
return Result<LdapConnectionWrapper>.Fail($"Failed to create ldap connection for {identifier}");
}

private bool CreateLdapConnection(string target, bool globalCatalog,
out LdapConnectionWrapper connection) {
var baseConnection = CreateBaseConnection(target, true, globalCatalog);
Expand Down

0 comments on commit 9e1700c

Please sign in to comment.