diff --git a/src/Framework/Core/Controls/GridViewDataSetExtensions.cs b/src/Framework/Core/Controls/GridViewDataSetExtensions.cs index 18cd22cdc5..b707d22b42 100644 --- a/src/Framework/Core/Controls/GridViewDataSetExtensions.cs +++ b/src/Framework/Core/Controls/GridViewDataSetExtensions.cs @@ -1,5 +1,7 @@ using System; +using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace DotVVM.Framework.Controls @@ -35,6 +37,44 @@ public static void LoadFromQueryable(this IGridViewDataSet dataSet, IQuery dataSet.IsRefreshRequired = false; } + public static async Task LoadFromQueryableAsync(this IGridViewDataSet dataSet, IQueryable queryable, CancellationToken cancellationToken = default) + { + if (dataSet.FilteringOptions is not IApplyToQueryable filteringOptions) + { + throw new ArgumentException($"The FilteringOptions of {dataSet.GetType()} must implement IApplyToQueryable!"); + } + if (dataSet.SortingOptions is not IApplyToQueryable sortingOptions) + { + throw new ArgumentException($"The SortingOptions of {dataSet.GetType()} must implement IApplyToQueryable!"); + } + if (dataSet.PagingOptions is not IApplyToQueryable pagingOptions) + { + throw new ArgumentException($"The PagingOptions of {dataSet.GetType()} must implement IApplyToQueryable!"); + } + + var filtered = filteringOptions.ApplyToQueryable(queryable); + var sorted = sortingOptions.ApplyToQueryable(filtered); + var paged = pagingOptions.ApplyToQueryable(sorted); + if (paged is not IAsyncEnumerable asyncPaged) + { + throw new ArgumentException($"The specified IQueryable ({queryable.GetType().FullName}), does not support async enumeration. Please use the LoadFromQueryable method.", nameof(queryable)); + } + + var result = new List(); + await foreach (var item in asyncPaged.WithCancellation(cancellationToken)) + { + result.Add(item); + } + dataSet.Items = result; + + if (pagingOptions is IPagingOptionsLoadingPostProcessor pagingOptionsLoadingPostProcessor) + { + await pagingOptionsLoadingPostProcessor.ProcessLoadedItemsAsync(filtered, result, cancellationToken); + } + + dataSet.IsRefreshRequired = false; + } + public static void GoToFirstPageAndRefresh(this IPageableGridViewDataSet dataSet) { dataSet.PagingOptions.GoToFirstPage(); diff --git a/src/Framework/Core/Controls/Options/IPagingOptionsLoadingPostProcessor.cs b/src/Framework/Core/Controls/Options/IPagingOptionsLoadingPostProcessor.cs index 6356489788..e298922cb6 100644 --- a/src/Framework/Core/Controls/Options/IPagingOptionsLoadingPostProcessor.cs +++ b/src/Framework/Core/Controls/Options/IPagingOptionsLoadingPostProcessor.cs @@ -1,5 +1,7 @@ using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; namespace DotVVM.Framework.Controls; @@ -7,4 +9,5 @@ namespace DotVVM.Framework.Controls; public interface IPagingOptionsLoadingPostProcessor { void ProcessLoadedItems(IQueryable filteredQueryable, IList items); + Task ProcessLoadedItemsAsync(IQueryable filteredQueryable, IList items, CancellationToken cancellationToken); } diff --git a/src/Framework/Core/Controls/Options/PagingImplementation.cs b/src/Framework/Core/Controls/Options/PagingImplementation.cs index 60a9d77851..75cf1bcef2 100644 --- a/src/Framework/Core/Controls/Options/PagingImplementation.cs +++ b/src/Framework/Core/Controls/Options/PagingImplementation.cs @@ -1,9 +1,15 @@ -using System.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; namespace DotVVM.Framework.Controls { public static class PagingImplementation { + public static Func>? CustomAsyncQueryableCountDelegate; /// /// Applies paging to the after the total number @@ -17,6 +23,87 @@ public static IQueryable ApplyPagingToQueryable(IQueryable ? queryable.Skip(options.PageSize * options.PageIndex).Take(options.PageSize) : queryable; } - } + /// Attempts to count the queryable asynchronously. EF Core IQueryables are supported, and IQueryables which return IAsyncEnumerable from GroupBy operator also work correctly. Otherwise, a synchronous fallback is user, or may be set to add support for an ORM mapper of choice. + public static async Task QueryableAsyncCount(IQueryable queryable, CancellationToken ct = default) + { + if (CustomAsyncQueryableCountDelegate is {} customDelegate) + { + var result = await customDelegate(queryable, ct); + if (result.HasValue) + { + return result.Value; + } + } + + var queryableType = queryable.GetType(); + return await ( + EfCoreAsyncCountHack(queryable, queryableType, ct) ?? // TODO: test this + Ef6AsyncCountHack(queryable, ct) ?? // TODO: test this + StandardAsyncCountHack(queryable, ct) + ); + } + + static MethodInfo? efMethodCache; + static Task? EfCoreAsyncCountHack(IQueryable queryable, Type queryableType, CancellationToken ct) + { + if (!( + queryableType.Namespace == "Microsoft.EntityFrameworkCore.Query.Internal" && queryableType.Name == "EntityQueryable`1" || + queryableType.Namespace == "Microsoft.EntityFrameworkCore.Internal" && queryableType.Name == "InternalDbSet`1" + )) + return null; + + var countMethod = efMethodCache ?? queryableType.Assembly.GetType("Microsoft.EntityFrameworkCore.EntityFrameworkQueryableExtensions")!.GetMethods().SingleOrDefault(m => m.Name == "CountAsync" && m.GetParameters() is { Length: 2 } parameters && parameters[1].ParameterType == typeof(CancellationToken)); + if (countMethod is null) + return null; + + if (efMethodCache is null) + Interlocked.CompareExchange(ref efMethodCache, countMethod, null); + + var countMethodGeneric = countMethod.MakeGenericMethod(typeof(T)); + return (Task)countMethodGeneric.Invoke(null, new object[] { queryable, ct })!; + } + + static readonly Type? ef6IDbAsyncQueryProvider = Type.GetType("System.Data.Entity.Infrastructure.IDbAsyncQueryProvider, EntityFramework"); // https://learn.microsoft.com/en-us/dotnet/api/system.data.entity.infrastructure.idbasyncqueryprovider?view=entity-framework-6.2.0 + static MethodInfo? ef6MethodCache; + static Task? Ef6AsyncCountHack(IQueryable queryable, CancellationToken ct) + { + if (ef6IDbAsyncQueryProvider is null) + return null; + if (!ef6IDbAsyncQueryProvider.IsInstanceOfType(queryable.Provider)) + return null; + + var countMethod = ef6MethodCache ?? Type.GetType("System.Data.Entity.QueryableExtensions, EntityFramework")!.GetMethods().SingleOrDefault(m => m.Name == "CountAsync" && m.GetParameters() is { Length: 2 } parameters && parameters[1].ParameterType == typeof(CancellationToken))!; + if (countMethod is null) + return null; + + if (ef6MethodCache is null) + Interlocked.CompareExchange(ref ef6MethodCache, countMethod, null); + + var countMethodGeneric = countMethod.MakeGenericMethod(typeof(T)); + return (Task)countMethodGeneric.Invoke(null, new object[] { queryable, ct })!; + } + + static Task StandardAsyncCountHack(IQueryable queryable, CancellationToken ct) + { +#if NETSTANDARD2_1_OR_GREATER + var countGroupHack = queryable.GroupBy(_ => 1).Select(group => group.Count()); + // if not IAsyncEnumerable, just use synchronous Count + if (countGroupHack is not IAsyncEnumerable countGroupEnumerable) + { + return Task.FromResult(queryable.Count()); + } + + return FirstOrDefaultAsync(countGroupEnumerable, ct); + } + + static async Task FirstOrDefaultAsync(IAsyncEnumerable enumerable, CancellationToken ct) + { + await using var enumerator = enumerable.GetAsyncEnumerator(ct); + return await enumerator.MoveNextAsync() ? enumerator.Current : default; +#else + throw new Exception("IAsyncEnumerable is not supported on .NET Framework and the queryable does not support EntityFramework CountAsync."); +#endif + } + } } diff --git a/src/Framework/Core/Controls/Options/PagingOptions.cs b/src/Framework/Core/Controls/Options/PagingOptions.cs index 83e009520e..5957db0705 100644 --- a/src/Framework/Core/Controls/Options/PagingOptions.cs +++ b/src/Framework/Core/Controls/Options/PagingOptions.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using DotVVM.Framework.ViewModel; namespace DotVVM.Framework.Controls @@ -125,5 +127,10 @@ public virtual void ProcessLoadedItems(IQueryable filteredQueryable, IList { TotalItemsCount = filteredQueryable.Count(); } + public async Task ProcessLoadedItemsAsync(IQueryable filteredQueryable, IList items, CancellationToken cancellationToken) + { + TotalItemsCount = await PagingImplementation.QueryableAsyncCount(filteredQueryable, cancellationToken); + } + } } diff --git a/src/Samples/Common/ViewModels/ControlSamples/GridView/GridViewStaticCommandViewModel.cs b/src/Samples/Common/ViewModels/ControlSamples/GridView/GridViewStaticCommandViewModel.cs index a41ce1b149..b2b06ca49d 100644 --- a/src/Samples/Common/ViewModels/ControlSamples/GridView/GridViewStaticCommandViewModel.cs +++ b/src/Samples/Common/ViewModels/ControlSamples/GridView/GridViewStaticCommandViewModel.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using DotVVM.Framework.Controls; using DotVVM.Framework.ViewModel; @@ -123,6 +124,12 @@ public void ProcessLoadedItems(IQueryable filteredQueryable, IList item NextPageToken = lastToken.ToString(); } } + + Task IPagingOptionsLoadingPostProcessor.ProcessLoadedItemsAsync(IQueryable filteredQueryable, IList items, CancellationToken cancellationToken) + { + ProcessLoadedItems(filteredQueryable, items); + return Task.CompletedTask; + } } public class NextTokenHistoryGridViewDataSet() : GenericGridViewDataSet, RowEditOptions>( @@ -165,6 +172,12 @@ public void ProcessLoadedItems(IQueryable filteredQueryable, IList item TokenHistory.Add((lastToken ?? 0).ToString()); } } + + Task IPagingOptionsLoadingPostProcessor.ProcessLoadedItemsAsync(IQueryable filteredQueryable, IList items, CancellationToken cancellationToken) + { + ProcessLoadedItems(filteredQueryable, items); + return Task.CompletedTask; + } } public class MultiSortGridViewDataSet() : GenericGridViewDataSet, RowEditOptions>( diff --git a/src/Tests/DotVVM.Framework.Tests.csproj b/src/Tests/DotVVM.Framework.Tests.csproj index 13ad10a6e9..aa755a5de6 100644 --- a/src/Tests/DotVVM.Framework.Tests.csproj +++ b/src/Tests/DotVVM.Framework.Tests.csproj @@ -42,6 +42,8 @@ + + diff --git a/src/Tests/ViewModel/EFCoreGridViewDataSetTests.cs b/src/Tests/ViewModel/EFCoreGridViewDataSetTests.cs new file mode 100644 index 0000000000..b61a14592a --- /dev/null +++ b/src/Tests/ViewModel/EFCoreGridViewDataSetTests.cs @@ -0,0 +1,139 @@ +#if NET8_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using DotVVM.Framework.Controls; +using DotVVM.Framework.ViewModel; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace DotVVM.Framework.Tests.ViewModel +{ + [TestClass] + public class EFCoreGridViewDataSetTests + { + private readonly DbContextOptions contextOptions; + + public EFCoreGridViewDataSetTests() + { + contextOptions = new DbContextOptionsBuilder() + .UseInMemoryDatabase("BloggingControllerTest") + .ConfigureWarnings(b => b.Ignore(InMemoryEventId.TransactionIgnoredWarning)) + .Options; + } + + class MyDbContext: DbContext + { + public MyDbContext(DbContextOptions options) : base(options) + { + } + + public DbSet Entries { get; set; } + } + + record Entry(int Id, string Name, int SomethingElse = 0); + + MyDbContext Init() + { + var context = new MyDbContext(contextOptions); + context.Database.EnsureDeleted(); + context.Database.EnsureCreated(); + context.Entries.AddRange([ + new (1, "Z"), + new (2, "Y"), + new (3, "X"), + new (4, "W"), + new (5, "V"), + new (6, "U", 5), + new (7, "T", 5), + new (8, "S", 5), + new (9, "R", 3), + new (10, "Q", 3), + ]); + context.SaveChanges(); + return context; + } + + [TestMethod] + public void LoadData_PagingSorting() + { + using var context = Init(); + + var dataSet = new GridViewDataSet() + { + PagingOptions = { PageSize = 3, PageIndex = 0 }, + SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false }, + }; + + dataSet.LoadFromQueryable(context.Entries); + + Assert.AreEqual(3, dataSet.Items.Count); + Assert.AreEqual(10, dataSet.PagingOptions.TotalItemsCount); + Assert.AreEqual(10, dataSet.Items[0].Id); + Assert.AreEqual(9, dataSet.Items[1].Id); + Assert.AreEqual(8, dataSet.Items[2].Id); + } + + [TestMethod] + public void LoadData_PagingSorting_PreFiltered() + { + using var context = Init(); + + var dataSet = new GridViewDataSet() + { + PagingOptions = { PageSize = 3, PageIndex = 0 }, + SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false }, + }; + + dataSet.LoadFromQueryable(context.Entries.Where(e => e.SomethingElse == 3)); + + Assert.AreEqual(2, dataSet.Items.Count); + Assert.AreEqual(2, dataSet.PagingOptions.TotalItemsCount); + Assert.AreEqual(10, dataSet.Items[0].Id); + Assert.AreEqual(9, dataSet.Items[1].Id); + } + + [TestMethod] + public async Task LoadData_PagingSortingAsync() + { + using var context = Init(); + + var dataSet = new GridViewDataSet() + { + PagingOptions = { PageSize = 3, PageIndex = 0 }, + SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false }, + }; + + await dataSet.LoadFromQueryableAsync(context.Entries); + + Assert.AreEqual(3, dataSet.Items.Count); + Assert.AreEqual(10, dataSet.PagingOptions.TotalItemsCount); + Assert.AreEqual(10, dataSet.Items[0].Id); + Assert.AreEqual(9, dataSet.Items[1].Id); + Assert.AreEqual(8, dataSet.Items[2].Id); + } + + [TestMethod] + public async Task LoadData_PagingSorting_PreFilteredAsync() + { + using var context = Init(); + + var dataSet = new GridViewDataSet() + { + PagingOptions = { PageSize = 3, PageIndex = 0 }, + SortingOptions = { SortExpression = nameof(Entry.Name), SortDescending = false }, + }; + + await dataSet.LoadFromQueryableAsync(context.Entries.Where(e => e.SomethingElse == 3)); + + Assert.AreEqual(2, dataSet.Items.Count); + Assert.AreEqual(2, dataSet.PagingOptions.TotalItemsCount); + Assert.AreEqual(10, dataSet.Items[0].Id); + Assert.AreEqual(9, dataSet.Items[1].Id); + } + } +} +#endif