Skip to content

Commit

Permalink
Make ToAsyncEnumerable() set the QueryStatistics.TotalResult
Browse files Browse the repository at this point in the history
Note: TotalResult will not be available before enumeration of
the result started!
  • Loading branch information
e-tobi authored and jeremydmiller committed May 20, 2024
1 parent 3d612cb commit 3298fd1
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
27 changes: 27 additions & 0 deletions src/LinqTests/Operators/async_enumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,31 @@ public async Task query_to_async_enumerable()
}

#endregion

[Fact]
public async Task query_to_async_enumerable_with_query_statistics()
{
var targets = Target.GenerateRandomData(20).ToArray();
await theStore.BulkInsertAsync(targets);

var ids = new List<Guid>();

var results = theSession.Query<Target>()
.Stats(out var stats)
.ToAsyncEnumerable();

stats.TotalResults.ShouldBe(0);

await foreach (var target in results)
{
stats.TotalResults.ShouldBe(20);
ids.Add(target.Id);
}

ids.Count.ShouldBe(20);
foreach (var target in targets)
{
ids.ShouldContain(target.Id);
}
}
}
8 changes: 6 additions & 2 deletions src/Marten/Linq/MartenLinqQueryProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Threading.Tasks;
using Marten.Exceptions;
using Marten.Internal.Sessions;
using Marten.Linq.Includes;
using Marten.Linq.Parsing;
using Marten.Linq.QueryHandlers;
using Marten.Linq.Selectors;
Expand Down Expand Up @@ -153,7 +152,7 @@ public T ExecuteHandler<T>(IQueryHandler<T> handler)


public async IAsyncEnumerable<T> ExecuteAsyncEnumerable<T>(Expression expression,
[EnumeratorCancellation] CancellationToken token)
MartenLinqQueryProvider martenProvider, [EnumeratorCancellation] CancellationToken token)
{
var parser = new LinqQueryParser(this, _session, expression);
var statements = parser.BuildStatements();
Expand All @@ -166,8 +165,13 @@ public async IAsyncEnumerable<T> ExecuteAsyncEnumerable<T>(Expression expression
var cmd = _session.BuildCommand(statement);

await using var reader = await _session.ExecuteReaderAsync(cmd, token).ConfigureAwait(false);
var totalRowsColumnIndex = martenProvider.Statistics != null ? reader.GetOrdinal("total_rows") : -1;
while (await reader.ReadAsync(token).ConfigureAwait(false))
{
if (martenProvider.Statistics != null)
{
martenProvider.Statistics.TotalResults = await reader.GetFieldValueAsync<int>(totalRowsColumnIndex, token).ConfigureAwait(false);
}
yield return await selector.ResolveAsync(reader, token).ConfigureAwait(false);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/Marten/Linq/MartenLinqQueryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public async Task<IReadOnlyList<TResult>> ToListAsync<TResult>(CancellationToken

public IAsyncEnumerable<T> ToAsyncEnumerable(CancellationToken token = default)
{
return MartenProvider.ExecuteAsyncEnumerable<T>(Expression, token);
return MartenProvider.ExecuteAsyncEnumerable<T>(Expression, MartenProvider, token);
}

public Task<bool> AnyAsync(CancellationToken token)
Expand Down

0 comments on commit 3298fd1

Please sign in to comment.