Skip to content

Commit

Permalink
Replace System.Linq.Async with custom implementations (#137)
Browse files Browse the repository at this point in the history
* chore: remove system.linq.async, replace with custom functions for file size issues

* chore: add some tests covering new async enumerables

* chore: minor fixes

* chore: add overload for FirstOrDefaultAsync

* chore: add some more tests
  • Loading branch information
rvazarkar authored Jul 23, 2024
1 parent f062342 commit 5d8905c
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 16 deletions.
23 changes: 23 additions & 0 deletions src/CommonLib/AsyncEnumerable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace SharpHoundCommonLib;

public static class AsyncEnumerable {
public static IAsyncEnumerable<T> Empty<T>() => EmptyAsyncEnumerable<T>.Instance;

private sealed class EmptyAsyncEnumerable<T> : IAsyncEnumerable<T> {
public static readonly EmptyAsyncEnumerable<T> Instance = new();
private readonly IAsyncEnumerator<T> _enumerator = new EmptyAsyncEnumerator<T>();
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken()) {
return _enumerator;
}
}

private sealed class EmptyAsyncEnumerator<T> : IAsyncEnumerator<T> {
public ValueTask DisposeAsync() => default;
public ValueTask<bool> MoveNextAsync() => new(false);
public T Current => default;
}
}
2 changes: 1 addition & 1 deletion src/CommonLib/DirectoryObjects/SearchResultEntryWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public SearchResultEntryWrapper(SearchResultEntry entry) {
}

public bool TryGetDistinguishedName(out string value) {
return TryGetProperty(LDAPProperties.DistinguishedName, out value);
return TryGetProperty(LDAPProperties.DistinguishedName, out value) && !string.IsNullOrWhiteSpace(value);
}

public bool TryGetProperty(string propertyName, out string value) {
Expand Down
114 changes: 106 additions & 8 deletions src/CommonLib/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using System;
using System.Collections.Generic;
using System.DirectoryServices;
using System.Linq;
using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Enums;

Expand All @@ -16,14 +19,109 @@ static Extensions()
Log = Logging.LogProvider.CreateLogger("Extensions");
}

// internal static async Task<List<T>> ToListAsync<T>(this IAsyncEnumerable<T> items)
// {
// var results = new List<T>();
// await foreach (var item in items
// .ConfigureAwait(false))
// results.Add(item);
// return results;
// }
public static async Task<List<T>> ToListAsync<T>(this IAsyncEnumerable<T> items)
{
if (items == null) {
return new List<T>();
}
var results = new List<T>();
await foreach (var item in items
.ConfigureAwait(false))
results.Add(item);
return results;
}

public static async Task<T[]> ToArrayAsync<T>(this IAsyncEnumerable<T> items)
{
if (items == null) {
return Array.Empty<T>();
}
var results = new List<T>();
await foreach (var item in items
.ConfigureAwait(false))
results.Add(item);
return results.ToArray();
}

public static async Task<T> FirstOrDefaultAsync<T>(this IAsyncEnumerable<T> source,
CancellationToken cancellationToken = default) {
if (source == null) {
return default;
}

await using (var enumerator = source.GetAsyncEnumerator(cancellationToken)) {
var first = await enumerator.MoveNextAsync() ? enumerator.Current : default;
return first;
}
}

public static async Task<T> FirstOrDefaultAsync<T>(this IAsyncEnumerable<T> source, T defaultValue,
CancellationToken cancellationToken = default) {
if (source == null) {
return defaultValue;
}

await using (var enumerator = source.GetAsyncEnumerator(cancellationToken)) {
var first = await enumerator.MoveNextAsync() ? enumerator.Current : defaultValue;
return first;
}
}

public static IAsyncEnumerable<T> DefaultIfEmpty<T>(this IAsyncEnumerable<T> source,
T defaultValue, CancellationToken cancellationToken = default) {
return new DefaultIfEmptyAsyncEnumerable<T>(source, defaultValue);
}

private sealed class DefaultIfEmptyAsyncEnumerable<T> : IAsyncEnumerable<T> {
private readonly DefaultIfEmptyAsyncEnumerator<T> _enumerator;
public DefaultIfEmptyAsyncEnumerable(IAsyncEnumerable<T> source, T defaultValue) {
_enumerator = new DefaultIfEmptyAsyncEnumerator<T>(source, defaultValue);
}
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken()) {
return _enumerator;
}
}

private sealed class DefaultIfEmptyAsyncEnumerator<T> : IAsyncEnumerator<T> {
private readonly IAsyncEnumerable<T> _source;
private readonly T _defaultValue;
private T _current;
private bool _enumeratorDisposed;

private IAsyncEnumerator<T> _enumerator;

public DefaultIfEmptyAsyncEnumerator(IAsyncEnumerable<T> source, T defaultValue) {
_source = source;
_defaultValue = defaultValue;
}

public async ValueTask DisposeAsync() {
_enumeratorDisposed = true;
if (_enumerator != null) {
await _enumerator.DisposeAsync().ConfigureAwait(false);
_enumerator = null;
}
}

public async ValueTask<bool> MoveNextAsync() {
if (_enumeratorDisposed) {
return false;
}
_enumerator ??= _source.GetAsyncEnumerator();

if (await _enumerator.MoveNextAsync().ConfigureAwait(false)) {
_current = _enumerator.Current;
return true;
}

_current = _defaultValue;
await DisposeAsync().ConfigureAwait(false);
return true;
}

public T Current => _current;
}


public static string LdapValue(this SecurityIdentifier s)
{
Expand Down
2 changes: 2 additions & 0 deletions src/CommonLib/ILdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ IAsyncEnumerable<Result<string>> RangedRetrieval(string distinguishedName,
/// <param name="domain">The Domain object</param>
/// <returns>True if the domain was found, false if not</returns>
bool GetDomain(out System.DirectoryServices.ActiveDirectory.Domain domain);

Task<(bool Success, string ForestName)> GetForest(string domain);
/// <summary>
/// Attempts to resolve an account name to its corresponding typed principal
/// </summary>
Expand Down
8 changes: 7 additions & 1 deletion src/CommonLib/LdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ await _connectionPool.GetLdapConnectionForServer(
LDAPFilter = new LdapFilter().AddAllObjects().GetFilter(),
};

var result = await Query(queryParameters).FirstAsync();
var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
if (result.IsSuccess &&
result.Value.TryGetProperty(LDAPProperties.RootDomainNamingContext, out var rootNamingContext)) {
return (true, Helpers.DistinguishedNameToDomain(rootNamingContext).ToUpper());
Expand Down Expand Up @@ -1296,6 +1296,9 @@ public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() {
}

public async Task<bool> IsDomainController(string computerObjectId, string domainName) {
if (DomainControllers.ContainsKey(computerObjectId)) {
return true;
}
var resDomain = await GetDomainNameFromSid(domainName) is (false, var tempDomain) ? tempDomain : domainName;
var filter = new LdapFilter().AddFilter(CommonFilters.SpecificSID(computerObjectId), true)
.AddFilter(CommonFilters.DomainControllers, true);
Expand All @@ -1304,6 +1307,9 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai
Attributes = CommonProperties.ObjectID,
LDAPFilter = filter.GetFilter(),
}).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
if (result.IsSuccess) {
DomainControllers.TryAdd(computerObjectId, new byte());
}
return result.IsSuccess;
}

Expand Down
1 change: 0 additions & 1 deletion src/CommonLib/SharpHoundCommonLib.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
<PackageReference Include="AntiXSS" Version="4.3.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.0" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
</ItemGroup>
<ItemGroup>
<Reference Include="System.DirectoryServices" />
Expand Down
92 changes: 92 additions & 0 deletions test/unit/AsyncEnumerableTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using SharpHoundCommonLib;
using Xunit;

namespace CommonLibTest;

public class AsyncEnumerableTests {
[Fact]
public async Task AsyncEnumerable_DefaultIfEmpty_Empty() {
var enumerable = AsyncEnumerable.Empty<int>().DefaultIfEmpty(1);
var e = enumerable.GetAsyncEnumerator();
var res = await e.MoveNextAsync();
Assert.True(res);
Assert.Equal(1, e.Current);
Assert.False(await e.MoveNextAsync());
}

[Fact]
public async Task AsyncEnumerable_FirstOrDefault() {
var enumerable = AsyncEnumerable.Empty<int>();
var res = await enumerable.FirstOrDefaultAsync();
Assert.Equal(0, res);
}

[Fact]
public async Task AsyncEnumerable_FirstOrDefault_WithDefault() {
var enumerable = AsyncEnumerable.Empty<int>();
var res = await enumerable.FirstOrDefaultAsync(10);
Assert.Equal(10, res);
}

[Fact]
public async Task AsyncEnumerable_CombinedOperators() {
var enumerable = AsyncEnumerable.Empty<string>();
var res = await enumerable.DefaultIfEmpty("abc").FirstOrDefaultAsync();
Assert.Equal("abc", res);
}

[Fact]
public async Task AsyncEnumerable_ToAsyncEnumerable() {
var collection = new[] {
"a", "b", "c"
};

var test = collection.ToAsyncEnumerable();

var index = 0;
await foreach (var item in test) {
Assert.Equal(collection[index], item);
index++;
}
}

[Fact]
public async Task AsyncEnumerable_FirstOrDefaultFunction() {
var test = await TestFunc().FirstOrDefaultAsync();
Assert.Equal("a", test);
}

[Fact]
public async Task AsyncEnumerable_CombinedFunction() {
var test = await TestFunc().DefaultIfEmpty("d").FirstOrDefaultAsync();
Assert.Equal("a", test);
}

[Fact]
public async Task AsyncEnumerable_FirstOrDefaultEmptyFunction() {
var test = await EmptyFunc().FirstOrDefaultAsync();
Assert.Null(test);
}

[Fact]
public async Task AsyncEnumerable_CombinedEmptyFunction() {
var test = await EmptyFunc().DefaultIfEmpty("d").FirstOrDefaultAsync();
Assert.Equal("d", test);
}

private async IAsyncEnumerable<string> TestFunc() {

Check warning on line 79 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 79 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 79 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 79 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
var collection = new[] {
"a", "b", "c"
};

foreach (var i in collection) {
yield return i;
}
}

private async IAsyncEnumerable<string> EmptyFunc() {

Check warning on line 89 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 89 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 89 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 89 in test/unit/AsyncEnumerableTests.cs

View workflow job for this annotation

GitHub Actions / build

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
yield break;
}
}
10 changes: 5 additions & 5 deletions test/unit/Facades/MockLdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ public virtual IAsyncEnumerable<Result<string>> RangedRetrieval(string distingui

public async Task<(bool Success, TypedPrincipal WellKnownPrincipal)> GetWellKnownPrincipal(string securityIdentifier, string objectDomain) {
if (!WellKnownPrincipal.GetWellKnownPrincipal(securityIdentifier, out var commonPrincipal)) return (false, default);
commonPrincipal.ObjectIdentifier = ConvertWellKnownPrincipal(securityIdentifier, objectDomain);
commonPrincipal.ObjectIdentifier = await ConvertWellKnownPrincipal(securityIdentifier, objectDomain);
_seenWellKnownPrincipals.TryAdd(commonPrincipal.ObjectIdentifier, securityIdentifier);
return (true, commonPrincipal);
}
Expand Down Expand Up @@ -747,13 +747,13 @@ public bool GetDomain(out Domain domain) {
throw new NotImplementedException();
}

public string ConvertWellKnownPrincipal(string sid, string domain)
public async Task<string> ConvertWellKnownPrincipal(string sid, string domain)
{
if (!WellKnownPrincipal.GetWellKnownPrincipal(sid, out _)) return sid;

if (sid != "S-1-5-9") return $"{domain}-{sid}".ToUpper();

var forest = GetForest(domain)?.Name;
var (success, forest) = await GetForest(domain);
return $"{forest}-{sid}".ToUpper();
}

Expand Down Expand Up @@ -1052,9 +1052,9 @@ Task<bool> ILdapUtils.IsDomainController(string computerObjectId, string domainN
throw new NotImplementedException();
}

public Forest GetForest(string domainName = null)
public async Task<(bool Success, string ForestName)> GetForest(string domainName = null)
{
return _forest;
return (true, _forest.Name);
}

public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor()
Expand Down
Loading

0 comments on commit 5d8905c

Please sign in to comment.