Skip to content

Commit

Permalink
Add SetAll operator to IUpdateBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit committed Dec 5, 2024
1 parent de529ce commit 01b9694
Show file tree
Hide file tree
Showing 13 changed files with 328 additions and 135 deletions.
2 changes: 1 addition & 1 deletion src/DataAccess/src/SIL.DataAccess/ArrayPosition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ public static class ArrayPosition
{
public const int FirstMatching = int.MaxValue;
public const int All = int.MaxValue - 1;
public const int ArrayFilter = int.MaxValue - 2;
internal const int ArrayFilter = int.MaxValue - 2;
}
11 changes: 7 additions & 4 deletions src/DataAccess/src/SIL.DataAccess/DataAccessFieldDefinition.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
namespace SIL.DataAccess;

public class DataAccessFieldDefinition<TDocument, TField>(Expression<Func<TDocument, TField>> expression)
: FieldDefinition<TDocument, TField>
public class DataAccessFieldDefinition<TDocument, TField>(
Expression<Func<TDocument, TField>> expression,
string arrayFilterId = ""
) : FieldDefinition<TDocument, TField>
{
private readonly ExpressionFieldDefinition<TDocument, TField> _internalDef = new(expression);
private readonly string _arrayFilterId = arrayFilterId;

public override RenderedFieldDefinition<TField> Render(
IBsonSerializer<TDocument> documentSerializer,
Expand All @@ -17,11 +20,11 @@ LinqProvider linqProvider
linqProvider
);
string fieldName = rendered.FieldName.Replace(ArrayPosition.All.ToString(CultureInfo.InvariantCulture), "$[]");
fieldName = fieldName.Replace(ArrayPosition.FirstMatching.ToString(CultureInfo.InvariantCulture), "$");
fieldName = fieldName.Replace(
ArrayPosition.ArrayFilter.ToString(CultureInfo.InvariantCulture),
"$[arrayFilter]"
$"$[{_arrayFilterId}]"
);
fieldName = fieldName.Replace(ArrayPosition.FirstMatching.ToString(CultureInfo.InvariantCulture), "$");
if (fieldName != rendered.FieldName)
{
return new RenderedFieldDefinition<TField>(
Expand Down
10 changes: 10 additions & 0 deletions src/DataAccess/src/SIL.DataAccess/ExpressionHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,14 @@ Expression expression
finder.Visit(expression);
return finder.Value;
}

public static Expression<Func<TIn, TOut>> Concatenate<TIn, TInter, TOut>(
Expression<Func<TIn, TInter>> left,
Expression<Func<TInter, TOut>> right
)
{
ParameterReplacer replacer = new(right.Parameters[0], left.Body);
Expression merged = replacer.Visit(right.Body);
return Expression.Lambda<Func<TIn, TOut>>(merged, left.Parameters[0]);
}
}
10 changes: 2 additions & 8 deletions src/DataAccess/src/SIL.DataAccess/IRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,20 @@ public interface IRepository<T>

Task InsertAsync(T entity, CancellationToken cancellationToken = default);
Task InsertAllAsync(IReadOnlyCollection<T> entities, CancellationToken cancellationToken = default);

Task<T?> UpdateAsync(
Expression<Func<T, bool>> filter,
Action<IUpdateBuilder<T>> update,
bool upsert = false,
bool returnOriginal = false,
CancellationToken cancellationToken = default
);
Task<int> UpdateAllAsync<TFilter>(
Expression<Func<T, bool>> filter,
Action<IUpdateBuilder<T>> update,
string jsonArrayFilterDefinition,
CancellationToken cancellationToken = default
);

Task<int> UpdateAllAsync(
Expression<Func<T, bool>> filter,
Action<IUpdateBuilder<T>> update,
UpdateOptions? updateOptions = null,
CancellationToken cancellationToken = default
);

Task<T?> DeleteAsync(Expression<Func<T, bool>> filter, CancellationToken cancellationToken = default);
Task<int> DeleteAllAsync(Expression<Func<T, bool>> filter, CancellationToken cancellationToken = default);
Task<ISubscription<T>> SubscribeAsync(
Expand Down
9 changes: 8 additions & 1 deletion src/DataAccess/src/SIL.DataAccess/IUpdateBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ public interface IUpdateBuilder<T>

IUpdateBuilder<T> RemoveAll<TItem>(
Expression<Func<T, IEnumerable<TItem>?>> field,
Expression<Func<TItem, bool>> predicate
Expression<Func<TItem, bool>>? predicate = null
);

IUpdateBuilder<T> Remove<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value);

IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value);

IUpdateBuilder<T> SetAll<TItem, TField>(
Expression<Func<T, IEnumerable<TItem>?>> collectionField,
Expression<Func<TItem, TField>> itemField,
TField value,
Expression<Func<TItem, bool>>? predicate = null
);
}
11 changes: 0 additions & 11 deletions src/DataAccess/src/SIL.DataAccess/MemoryRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,20 +233,9 @@ public async Task InsertAllAsync(IReadOnlyCollection<T> entities, CancellationTo
return returnOriginal ? original : entity;
}

public async Task<int> UpdateAllAsync<TFilter>(
Expression<Func<T, bool>> filter,
Action<IUpdateBuilder<T>> update,
string jsonArrayFilterDefinition,
CancellationToken cancellationToken = default
)
{
return await UpdateAllAsync(filter, update, null, cancellationToken);
}

public async Task<int> UpdateAllAsync(
Expression<Func<T, bool>> filter,
Action<IUpdateBuilder<T>> update,
UpdateOptions? updateOptions = null,
CancellationToken cancellationToken = default
)
{
Expand Down
92 changes: 66 additions & 26 deletions src/DataAccess/src/SIL.DataAccess/MemoryUpdateBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,20 @@ public class MemoryUpdateBuilder<T>(Expression<Func<T, bool>> filter, T entity,

public IUpdateBuilder<T> Set<TField>(Expression<Func<T, TField>> field, TField value)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
prop.SetValue(owner, value, indices);
Set(_entity, _filter, field, value);
return this;
}

public IUpdateBuilder<T> SetOnInsert<TField>(Expression<Func<T, TField>> field, TField value)
{
if (_isInsert)
Set(field, value);
Set(_entity, _filter, field, value);
return this;
}

public IUpdateBuilder<T> Unset<TField>(Expression<Func<T, TField>> field)
{
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(_entity, _filter, field);
if (index != null)
{
// remove value from a dictionary
Expand All @@ -49,7 +46,7 @@ public IUpdateBuilder<T> Unset<TField>(Expression<Func<T, TField>> field)

public IUpdateBuilder<T> Inc(Expression<Func<T, int>> field, int value = 1)
{
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
{
Expand All @@ -62,20 +59,20 @@ public IUpdateBuilder<T> Inc(Expression<Func<T, int>> field, int value = 1)

public IUpdateBuilder<T> RemoveAll<TItem>(
Expression<Func<T, IEnumerable<TItem>?>> field,
Expression<Func<TItem, bool>> predicate
Expression<Func<TItem, bool>>? predicate = null
)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
Func<TItem, bool> predicateFunc = predicate.Compile();
Func<TItem, bool>? predicateFunc = predicate?.Compile();
foreach (object owner in owners)
{
var collection = (IEnumerable<TItem>?)prop.GetValue(owner, indices);
MethodInfo? removeMethod = collection?.GetType().GetMethod("Remove");
if (collection is not null && removeMethod is not null)
{
// the collection is mutable, so use Remove method to remove item
TItem[] toRemove = collection.Where(predicateFunc).ToArray();
TItem[] toRemove = collection.Where(i => predicateFunc?.Invoke(i) ?? true).ToArray();
foreach (TItem item in toRemove)
removeMethod.Invoke(collection, [item]);
}
Expand All @@ -84,14 +81,17 @@ Expression<Func<TItem, bool>> predicate
if (prop.PropertyType.IsArray || prop.PropertyType.IsInterface)
{
// the collection type is an array or interface, so construct a new array and set property
TItem[] newValue = collection.Where(i => !predicateFunc(i)).ToArray();
TItem[] newValue = collection.Where(i => !(predicateFunc?.Invoke(i) ?? false)).ToArray();
prop.SetValue(owner, newValue, indices);
}
else
{
// the collection type is a collection class, so construct a new collection and set property
var newValue = (IEnumerable<TItem>?)
Activator.CreateInstance(prop.PropertyType, collection.Where(i => !predicateFunc(i)).ToArray());
Activator.CreateInstance(
prop.PropertyType,
collection.Where(i => !(predicateFunc?.Invoke(i) ?? false)).ToArray()
);
prop.SetValue(owner, newValue, indices);
}
}
Expand All @@ -101,7 +101,7 @@ Expression<Func<TItem, bool>> predicate

public IUpdateBuilder<T> Remove<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
{
Expand Down Expand Up @@ -134,7 +134,7 @@ public IUpdateBuilder<T> Remove<TItem>(Expression<Func<T, IEnumerable<TItem>?>>

public IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
{
Expand All @@ -147,7 +147,7 @@ public IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> fie
}
else
{
collection ??= Array.Empty<TItem>();
collection ??= [];
if (prop.PropertyType.IsArray || prop.PropertyType.IsInterface)
{
// the collection type is an array or interface, so construct a new array and set property
Expand All @@ -166,6 +166,47 @@ public IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> fie
return this;
}

public IUpdateBuilder<T> SetAll<TItem, TField>(
Expression<Func<T, IEnumerable<TItem>?>> collectionField,
Expression<Func<TItem, TField>> itemField,
TField value,
Expression<Func<TItem, bool>>? predicate = null
)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(
_entity,
_filter,
collectionField
);
object[]? indices = index == null ? null : [index];
Func<TItem, bool>? predicateFunc = predicate?.Compile();
foreach (object owner in owners)
{
var collection = (IEnumerable<TItem>?)prop.GetValue(owner, indices);
if (collection is null)
continue;
foreach (TItem item in collection)
{
if (predicateFunc == null || predicateFunc(item))
Set(item, i => true, itemField, value);
}
}
return this;
}

private static void Set<TEntity, TField>(
TEntity entity,
Expression<Func<TEntity, bool>> filter,
Expression<Func<TEntity, TField>> field,
TField value
)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(entity, filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
prop.SetValue(owner, value, indices);
}

private static bool IsAnyMethod(MethodInfo mi)
{
return mi.DeclaringType == typeof(Enumerable) && mi.Name == "Any";
Expand All @@ -180,8 +221,10 @@ private static MethodInfo GetFirstOrDefaultMethod(Type type)
.MakeGenericMethod(type);
}

private (IEnumerable<object> Owners, PropertyInfo Property, object? Index) GetFieldOwners<TField>(
Expression<Func<T, TField>> field
private static (IEnumerable<object> Owners, PropertyInfo Property, object? Index) GetFieldOwners<TEntity, TField>(
TEntity entity,
Expression<Func<TEntity, bool>> filter,
Expression<Func<TEntity, TField>> field
)
{
List<object>? owners = null;
Expand All @@ -192,8 +235,8 @@ Expression<Func<T, TField>> field
var newOwners = new List<object>();
if (owners == null)
{
if (_entity != null)
newOwners.Add(_entity);
if (entity != null)
newOwners.Add(entity);
}
else
{
Expand All @@ -206,17 +249,14 @@ Expression<Func<T, TField>> field
switch (index)
{
case ArrayPosition.FirstMatching:
foreach (Expression expression in ExpressionHelper.Flatten(_filter))
foreach (Expression expression in ExpressionHelper.Flatten(filter))
{
if (expression is MethodCallExpression callExpr && IsAnyMethod(callExpr.Method))
{
var predicate = (LambdaExpression)callExpr.Arguments[1];
Type itemType = predicate.Parameters[0].Type;
MethodInfo firstOrDefault = GetFirstOrDefaultMethod(itemType);
newOwner = firstOrDefault.Invoke(
null,
new object[] { owner, predicate.Compile() }
);
newOwner = firstOrDefault.Invoke(null, [owner, predicate.Compile()]);
if (newOwner != null)
newOwners.Add(newOwner);
break;
Expand Down Expand Up @@ -247,7 +287,7 @@ Expression<Func<T, TField>> field
}
else
{
newOwner = method.Invoke(owner, new object[] { index });
newOwner = method.Invoke(owner, [index]);
if (newOwner != null)
newOwners.Add(newOwner);
}
Expand Down
32 changes: 7 additions & 25 deletions src/DataAccess/src/SIL.DataAccess/MongoRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ await _collection
var updateBuilder = new MongoUpdateBuilder<T>();
update(updateBuilder);
updateBuilder.Inc(e => e.Revision, 1);
UpdateDefinition<T> updateDef = updateBuilder.Build();
(UpdateDefinition<T> updateDef, IReadOnlyList<ArrayFilterDefinition> arrayFilters) = updateBuilder.Build();
var options = new FindOneAndUpdateOptions<T>
{
IsUpsert = upsert,
Expand Down Expand Up @@ -151,50 +151,32 @@ await _collection
return entity;
}

public async Task<int> UpdateAllAsync<TFilter>(
Expression<Func<T, bool>> filter,
Action<IUpdateBuilder<T>> update,
string jsonArrayFilterDefinition,
CancellationToken cancellationToken = default
)
{
var updateOptions = new UpdateOptions
{
ArrayFilters = [new JsonArrayFilterDefinition<TFilter>(jsonArrayFilterDefinition)]
};
return await UpdateAllAsync(filter, update, updateOptions, cancellationToken).ConfigureAwait(false);
}

public async Task<int> UpdateAllAsync(
Expression<Func<T, bool>> filter,
Action<IUpdateBuilder<T>> update,
UpdateOptions? updateOptions = null,
CancellationToken cancellationToken = default
)
{
var updateBuilder = new MongoUpdateBuilder<T>();
update(updateBuilder);
updateBuilder.Inc(e => e.Revision, 1);
UpdateDefinition<T> updateDef = updateBuilder.Build();
(UpdateDefinition<T> updateDef, IReadOnlyList<ArrayFilterDefinition> arrayFilters) = updateBuilder.Build();
UpdateOptions? updateOptions = null;
if (arrayFilters.Count > 0)
updateOptions = new UpdateOptions { ArrayFilters = arrayFilters };
UpdateResult result;
try
{
if (_context.Session is not null)
{
result = await _collection
.UpdateManyAsync(
_context.Session,
filter,
updateDef,
updateOptions,
cancellationToken: cancellationToken
)
.UpdateManyAsync(_context.Session, filter, updateDef, updateOptions, cancellationToken)
.ConfigureAwait(false);
}
else
{
result = await _collection
.UpdateManyAsync(filter, updateDef, updateOptions, cancellationToken: cancellationToken)
.UpdateManyAsync(filter, updateDef, updateOptions, cancellationToken)
.ConfigureAwait(false);
}
}
Expand Down
Loading

0 comments on commit 01b9694

Please sign in to comment.