Skip to content

Commit

Permalink
Add async LoadFromQueryable
Browse files Browse the repository at this point in the history
It is primarily based on the IAsyncEnumerable interface,
but requires some hacks to get access to the CountAsync method.
  • Loading branch information
exyi committed Sep 7, 2024
1 parent 2bbf954 commit 94d99b9
Show file tree
Hide file tree
Showing 7 changed files with 289 additions and 1 deletion.
40 changes: 40 additions & 0 deletions src/Framework/Core/Controls/GridViewDataSetExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace DotVVM.Framework.Controls
Expand Down Expand Up @@ -35,6 +37,44 @@ public static void LoadFromQueryable<T>(this IGridViewDataSet<T> dataSet, IQuery
dataSet.IsRefreshRequired = false;
}

public static async Task LoadFromQueryableAsync<T>(this IGridViewDataSet<T> dataSet, IQueryable<T> queryable, CancellationToken cancellationToken = default)
{
if (dataSet.FilteringOptions is not IApplyToQueryable filteringOptions)
{
throw new ArgumentException($"The FilteringOptions of {dataSet.GetType()} must implement IApplyToQueryable!");
}
if (dataSet.SortingOptions is not IApplyToQueryable sortingOptions)
{
throw new ArgumentException($"The SortingOptions of {dataSet.GetType()} must implement IApplyToQueryable!");
}
if (dataSet.PagingOptions is not IApplyToQueryable pagingOptions)
{
throw new ArgumentException($"The PagingOptions of {dataSet.GetType()} must implement IApplyToQueryable!");
}

var filtered = filteringOptions.ApplyToQueryable(queryable);
var sorted = sortingOptions.ApplyToQueryable(filtered);
var paged = pagingOptions.ApplyToQueryable(sorted);
if (paged is not IAsyncEnumerable<T> asyncPaged)
{
throw new ArgumentException($"The specified IQueryable ({queryable.GetType().FullName}), does not support async enumeration. Please use the LoadFromQueryable method.", nameof(queryable));
}

var result = new List<T>();
await foreach (var item in asyncPaged.WithCancellation(cancellationToken))
{
result.Add(item);
}
dataSet.Items = result;

if (pagingOptions is IPagingOptionsLoadingPostProcessor pagingOptionsLoadingPostProcessor)
{
await pagingOptionsLoadingPostProcessor.ProcessLoadedItemsAsync(filtered, result, cancellationToken);
}

dataSet.IsRefreshRequired = false;
}

public static void GoToFirstPageAndRefresh(this IPageableGridViewDataSet<IPagingFirstPageCapability> dataSet)
{
dataSet.PagingOptions.GoToFirstPage();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace DotVVM.Framework.Controls;

/// <summary> Provides an extension point to the <see cref="GridViewDataSetExtensions.LoadFromQueryable{T}(DotVVM.Framework.Controls.IGridViewDataSet{T}, IQueryable{T})" /> method, which is invoked after the items are loaded from database. </summary>
public interface IPagingOptionsLoadingPostProcessor
{
void ProcessLoadedItems<T>(IQueryable<T> filteredQueryable, IList<T> items);
Task ProcessLoadedItemsAsync<T>(IQueryable<T> filteredQueryable, IList<T> items, CancellationToken cancellationToken);
}
86 changes: 85 additions & 1 deletion src/Framework/Core/Controls/Options/PagingImplementation.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
using System.Linq;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

namespace DotVVM.Framework.Controls
{
public static class PagingImplementation
{
public static Func<IQueryable, CancellationToken, Task<int?>>? CustomAsyncQueryableCountDelegate;

/// <summary>
/// Applies paging to the <paramref name="queryable" /> after the total number
Expand All @@ -17,6 +23,84 @@ public static IQueryable<T> ApplyPagingToQueryable<T, TPagingOptions>(IQueryable
? queryable.Skip(options.PageSize * options.PageIndex).Take(options.PageSize)
: queryable;
}

/// <summary> Attempts to count the queryable asynchronously. EF Core IQueryables are supported, and IQueryables which return IAsyncEnumerable from GroupBy operator also work correctly. Otherwise, a synchronous fallback is user, or <see cref="CustomAsyncQueryableCountDelegate" /> may be set to add support for an ORM mapper of choice. </summary>
public static async Task<int> QueryableAsyncCount<T>(IQueryable<T> queryable, CancellationToken ct = default)
{
if (CustomAsyncQueryableCountDelegate is {} customDelegate)
{
var result = await customDelegate(queryable, ct);
if (result.HasValue)
{
return result.Value;
}
}

var queryableType = queryable.GetType();
return await (
EfCoreAsyncCountHack(queryable, queryableType, ct) ?? // TODO: test this
Ef6AsyncCountHack(queryable, ct) ?? // TODO: test this
StandardAsyncCountHack(queryable, ct)
);
}

static MethodInfo? efMethodCache;
static Task<int>? EfCoreAsyncCountHack<T>(IQueryable<T> queryable, Type queryableType, CancellationToken ct)
{
if (!(
queryableType.Namespace == "Microsoft.EntityFrameworkCore.Query.Internal" && queryableType.Name == "EntityQueryable`1" ||
queryableType.Namespace == "Microsoft.EntityFrameworkCore.Internal" && queryableType.Name == "InternalDbSet`1"
))
return null;

var countMethod = efMethodCache ?? queryableType.Assembly.GetType("Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions")!.GetMethods().SingleOrDefault(m => m.Name == "CountAsync" && m.GetParameters() is { Length: 2 } parameters && parameters[1].ParameterType == typeof(CancellationToken));
if (countMethod is null)
return null;

if (efMethodCache is null)
Interlocked.CompareExchange(ref efMethodCache, countMethod, null);

var countMethodGeneric = countMethod.MakeGenericMethod(typeof(T));
return (Task<int>)countMethodGeneric.Invoke(null, new object[] { queryable, ct })!;
}

static readonly Type? ef6IDbAsyncQueryProvider = Type.GetType("System.Data.Entity.Infrastructure.IDbAsyncQueryProvider, EntityFramework"); // https://learn.microsoft.com/en-us/dotnet/api/system.data.entity.infrastructure.idbasyncqueryprovider?view=entity-framework-6.2.0
static MethodInfo? ef6MethodCache;
static Task<int>? Ef6AsyncCountHack<T>(IQueryable<T> queryable, CancellationToken ct)
{
if (ef6IDbAsyncQueryProvider is null)
return null;
if (!ef6IDbAsyncQueryProvider.IsInstanceOfType(queryable.Provider))
return null;

var countMethod = ef6MethodCache ?? Type.GetType("System.Data.Entity.QueryableExtensions, EntityFramework")!.GetMethods().SingleOrDefault(m => m.Name == "CountAsync" && m.GetParameters() is { Length: 2 } parameters && parameters[1].ParameterType == typeof(CancellationToken))!;
if (countMethod is null)
return null;

if (ef6MethodCache is null)
Interlocked.CompareExchange(ref ef6MethodCache, countMethod, null);

var countMethodGeneric = countMethod.MakeGenericMethod(typeof(T));
return (Task<int>)countMethodGeneric.Invoke(null, new object[] { queryable, ct })!;
}

static Task<int> StandardAsyncCountHack<T>(IQueryable<T> queryable, CancellationToken ct)
{
var countGroupHack = queryable.GroupBy(_ => 1).Select(group => group.Count());
// if not IAsyncEnumerable, just use synchronous Count
if (countGroupHack is not IAsyncEnumerable<int> countGroupEnumerable)
{
return Task.FromResult(queryable.Count());
}

return FirstOrDefaultAsync(countGroupEnumerable, ct);
}

static async Task<T?> FirstOrDefaultAsync<T>(IAsyncEnumerable<T> enumerable, CancellationToken ct)

Check failure on line 99 in src/Framework/Core/Controls/Options/PagingImplementation.cs

View workflow job for this annotation

GitHub Actions / Build all projects without errors

The type or namespace name 'IAsyncEnumerable<>' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 99 in src/Framework/Core/Controls/Options/PagingImplementation.cs

View workflow job for this annotation

GitHub Actions / Build all projects without errors

The type or namespace name 'IAsyncEnumerable<>' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 99 in src/Framework/Core/Controls/Options/PagingImplementation.cs

View workflow job for this annotation

GitHub Actions / .NET unit tests (windows-2022)

The type or namespace name 'IAsyncEnumerable<>' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 99 in src/Framework/Core/Controls/Options/PagingImplementation.cs

View workflow job for this annotation

GitHub Actions / .NET unit tests (windows-2022)

The type or namespace name 'IAsyncEnumerable<>' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 99 in src/Framework/Core/Controls/Options/PagingImplementation.cs

View workflow job for this annotation

GitHub Actions / UI tests (chrome, windows-2022, Development, Default)

The type or namespace name 'IAsyncEnumerable<>' could not be found (are you missing a using directive or an assembly reference?)
{
await using var enumerator = enumerable.GetAsyncEnumerator(ct);
return await enumerator.MoveNextAsync() ? enumerator.Current : default;
}
}

}
7 changes: 7 additions & 0 deletions src/Framework/Core/Controls/Options/PagingOptions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DotVVM.Framework.ViewModel;

namespace DotVVM.Framework.Controls
Expand Down Expand Up @@ -125,5 +127,10 @@ public virtual void ProcessLoadedItems<T>(IQueryable<T> filteredQueryable, IList
{
TotalItemsCount = filteredQueryable.Count();
}
public async Task ProcessLoadedItemsAsync<T>(IQueryable<T> filteredQueryable, IList<T> items, CancellationToken cancellationToken)
{
TotalItemsCount = await PagingImplementation.QueryableAsyncCount(filteredQueryable, cancellationToken);
}

}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DotVVM.Framework.Controls;
using DotVVM.Framework.ViewModel;
Expand Down Expand Up @@ -123,6 +124,12 @@ public void ProcessLoadedItems<T>(IQueryable<T> filteredQueryable, IList<T> item
NextPageToken = lastToken.ToString();
}
}

Task IPagingOptionsLoadingPostProcessor.ProcessLoadedItemsAsync<T>(IQueryable<T> filteredQueryable, IList<T> items, CancellationToken cancellationToken)
{
ProcessLoadedItems(filteredQueryable, items);
return Task.CompletedTask;
}
}

public class NextTokenHistoryGridViewDataSet() : GenericGridViewDataSet<CustomerData, NoFilteringOptions, SortingOptions, CustomerDataNextTokenHistoryPagingOptions, RowInsertOptions<CustomerData>, RowEditOptions>(
Expand Down Expand Up @@ -165,6 +172,12 @@ public void ProcessLoadedItems<T>(IQueryable<T> filteredQueryable, IList<T> item
TokenHistory.Add((lastToken ?? 0).ToString());
}
}

Task IPagingOptionsLoadingPostProcessor.ProcessLoadedItemsAsync<T>(IQueryable<T> filteredQueryable, IList<T> items, CancellationToken cancellationToken)
{
ProcessLoadedItems(filteredQueryable, items);
return Task.CompletedTask;
}
}

public class MultiSortGridViewDataSet() : GenericGridViewDataSet<CustomerData, NoFilteringOptions, MultiCriteriaSortingOptions, PagingOptions, RowInsertOptions<CustomerData>, RowEditOptions>(
Expand Down
2 changes: 2 additions & 0 deletions src/Tests/DotVVM.Framework.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net8.0'">
<ProjectReference Include="../Framework/Hosting.AspNetCore/DotVVM.Framework.Hosting.AspNetCore.csproj" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="8.0.8" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.8" />
<PackageReference Include="CheckTestOutput" Version="0.6.3" />
</ItemGroup>
<ItemGroup>
Expand Down
139 changes: 139 additions & 0 deletions src/Tests/ViewModel/EFCoreGridViewDataSetTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#if NET8_0_OR_GREATER
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DotVVM.Framework.Controls;
using DotVVM.Framework.ViewModel;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace DotVVM.Framework.Tests.ViewModel
{
[TestClass]
public class EFCoreGridViewDataSetTests
{
private readonly DbContextOptions<MyDbContext> contextOptions;

public EFCoreGridViewDataSetTests()
{
contextOptions = new DbContextOptionsBuilder<MyDbContext>()
.UseInMemoryDatabase("BloggingControllerTest")
.ConfigureWarnings(b => b.Ignore(InMemoryEventId.TransactionIgnoredWarning))
.Options;
}

class MyDbContext: DbContext
{
public MyDbContext(DbContextOptions options) : base(options)
{
}

public DbSet<Entry> Entries { get; set; }
}

record Entry(int Id, string Name, int SomethingElse = 0);

MyDbContext Init()
{
var context = new MyDbContext(contextOptions);
context.Database.EnsureDeleted();
context.Database.EnsureCreated();
context.Entries.AddRange([
new (1, "Z"),
new (2, "Y"),
new (3, "X"),
new (4, "W"),
new (5, "V"),
new (6, "U", 5),
new (7, "T", 5),
new (8, "S", 5),
new (9, "R", 3),
new (10, "Q", 3),
]);
context.SaveChanges();
return context;
}

[TestMethod]
public void LoadData_PagingSorting()
{
using var context = Init();

var dataSet = new GridViewDataSet<Entry>()
{
PagingOptions = { PageSize = 3, PageIndex = 0 },
SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false },
};

dataSet.LoadFromQueryable(context.Entries);

Assert.AreEqual(3, dataSet.Items.Count);
Assert.AreEqual(10, dataSet.PagingOptions.TotalItemsCount);
Assert.AreEqual(10, dataSet.Items[0].Id);
Assert.AreEqual(9, dataSet.Items[1].Id);
Assert.AreEqual(8, dataSet.Items[2].Id);
}

[TestMethod]
public void LoadData_PagingSorting_PreFiltered()
{
using var context = Init();

var dataSet = new GridViewDataSet<Entry>()
{
PagingOptions = { PageSize = 3, PageIndex = 0 },
SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false },
};

dataSet.LoadFromQueryable(context.Entries.Where(e => e.SomethingElse == 3));

Assert.AreEqual(2, dataSet.Items.Count);
Assert.AreEqual(2, dataSet.PagingOptions.TotalItemsCount);
Assert.AreEqual(10, dataSet.Items[0].Id);
Assert.AreEqual(9, dataSet.Items[1].Id);
}

[TestMethod]
public async Task LoadData_PagingSortingAsync()
{
using var context = Init();

var dataSet = new GridViewDataSet<Entry>()
{
PagingOptions = { PageSize = 3, PageIndex = 0 },
SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false },
};

await dataSet.LoadFromQueryableAsync(context.Entries);

Assert.AreEqual(3, dataSet.Items.Count);
Assert.AreEqual(10, dataSet.PagingOptions.TotalItemsCount);
Assert.AreEqual(10, dataSet.Items[0].Id);
Assert.AreEqual(9, dataSet.Items[1].Id);
Assert.AreEqual(8, dataSet.Items[2].Id);
}

[TestMethod]
public async Task LoadData_PagingSorting_PreFilteredAsync()
{
using var context = Init();

var dataSet = new GridViewDataSet<Entry>()
{
PagingOptions = { PageSize = 3, PageIndex = 0 },
SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false },
};

await dataSet.LoadFromQueryableAsync(context.Entries.Where(e => e.SomethingElse == 3));

Assert.AreEqual(2, dataSet.Items.Count);
Assert.AreEqual(2, dataSet.PagingOptions.TotalItemsCount);
Assert.AreEqual(10, dataSet.Items[0].Id);
Assert.AreEqual(9, dataSet.Items[1].Id);
}
}
}
#endif

0 comments on commit 94d99b9

Please sign in to comment.