From 3298fd104e76a1ef35d429ea2af48f3dfac8be1a Mon Sep 17 00:00:00 2001 From: Tobias Grimm Date: Sun, 19 May 2024 18:12:15 +0200 Subject: [PATCH] Make ToAsyncEnumerable() set the QueryStatistics.TotalResult Note: TotalResult will not be available before enumeration of the result started! --- src/LinqTests/Operators/async_enumerable.cs | 27 +++++++++++++++++++++ src/Marten/Linq/MartenLinqQueryProvider.cs | 8 ++++-- src/Marten/Linq/MartenLinqQueryable.cs | 2 +- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/LinqTests/Operators/async_enumerable.cs b/src/LinqTests/Operators/async_enumerable.cs index 155ad51c8d..21067bea7f 100644 --- a/src/LinqTests/Operators/async_enumerable.cs +++ b/src/LinqTests/Operators/async_enumerable.cs @@ -41,4 +41,31 @@ public async Task query_to_async_enumerable() } #endregion + + [Fact] + public async Task query_to_async_enumerable_with_query_statistics() + { + var targets = Target.GenerateRandomData(20).ToArray(); + await theStore.BulkInsertAsync(targets); + + var ids = new List(); + + var results = theSession.Query() + .Stats(out var stats) + .ToAsyncEnumerable(); + + stats.TotalResults.ShouldBe(0); + + await foreach (var target in results) + { + stats.TotalResults.ShouldBe(20); + ids.Add(target.Id); + } + + ids.Count.ShouldBe(20); + foreach (var target in targets) + { + ids.ShouldContain(target.Id); + } + } } diff --git a/src/Marten/Linq/MartenLinqQueryProvider.cs b/src/Marten/Linq/MartenLinqQueryProvider.cs index 7c08f1dc7c..89184f00bf 100644 --- a/src/Marten/Linq/MartenLinqQueryProvider.cs +++ b/src/Marten/Linq/MartenLinqQueryProvider.cs @@ -8,7 +8,6 @@ using System.Threading.Tasks; using Marten.Exceptions; using Marten.Internal.Sessions; -using Marten.Linq.Includes; using Marten.Linq.Parsing; using Marten.Linq.QueryHandlers; using Marten.Linq.Selectors; @@ -153,7 +152,7 @@ public T ExecuteHandler(IQueryHandler handler) public async IAsyncEnumerable ExecuteAsyncEnumerable(Expression expression, - [EnumeratorCancellation] CancellationToken token) + MartenLinqQueryProvider martenProvider, [EnumeratorCancellation] CancellationToken token) { var parser = new LinqQueryParser(this, _session, expression); var statements = parser.BuildStatements(); @@ -166,8 +165,13 @@ public async IAsyncEnumerable ExecuteAsyncEnumerable(Expression expression var cmd = _session.BuildCommand(statement); await using var reader = await _session.ExecuteReaderAsync(cmd, token).ConfigureAwait(false); + var totalRowsColumnIndex = martenProvider.Statistics != null ? reader.GetOrdinal("total_rows") : -1; while (await reader.ReadAsync(token).ConfigureAwait(false)) { + if (martenProvider.Statistics != null) + { + martenProvider.Statistics.TotalResults = await reader.GetFieldValueAsync(totalRowsColumnIndex, token).ConfigureAwait(false); + } yield return await selector.ResolveAsync(reader, token).ConfigureAwait(false); } } diff --git a/src/Marten/Linq/MartenLinqQueryable.cs b/src/Marten/Linq/MartenLinqQueryable.cs index ac20191aea..03a9fbd678 100644 --- a/src/Marten/Linq/MartenLinqQueryable.cs +++ b/src/Marten/Linq/MartenLinqQueryable.cs @@ -150,7 +150,7 @@ public async Task> ToListAsync(CancellationToken public IAsyncEnumerable ToAsyncEnumerable(CancellationToken token = default) { - return MartenProvider.ExecuteAsyncEnumerable(Expression, token); + return MartenProvider.ExecuteAsyncEnumerable(Expression, MartenProvider, token); } public Task AnyAsync(CancellationToken token)