From 6659cb686d36c8ab3361e0a45e9282a149e35586 Mon Sep 17 00:00:00 2001 From: neuecc Date: Sat, 2 Mar 2024 17:36:19 +0900 Subject: [PATCH] Add TakeUntil(async), SkipUntil(async) --- src/R3/Operators/Skip.cs | 4 +- src/R3/Operators/SkipLast.cs | 4 +- src/R3/Operators/SkipUntil.cs | 72 +++++++++++++++++++ src/R3/Operators/Take.cs | 4 +- src/R3/Operators/TakeLast.cs | 4 +- src/R3/Operators/TakeUntil.cs | 68 ++++++++++++++++++ tests/R3.Tests/OperatorTests/SkipUntilTest.cs | 24 +++++++ tests/R3.Tests/OperatorTests/TakeUntilTest.cs | 21 ++++++ 8 files changed, 193 insertions(+), 8 deletions(-) diff --git a/src/R3/Operators/Skip.cs b/src/R3/Operators/Skip.cs index 3ca6f89c..2dae31ac 100644 --- a/src/R3/Operators/Skip.cs +++ b/src/R3/Operators/Skip.cs @@ -18,7 +18,7 @@ public static Observable Skip(this Observable source, TimeSpan duration public static Observable Skip(this Observable source, TimeSpan duration, TimeProvider timeProvider) { - return new SkipTime(source, duration, timeProvider); + return new SkipTime(source, duration.Normalize(), timeProvider); } // SkipFrame @@ -30,7 +30,7 @@ public static Observable SkipFrame(this Observable source, int frameCou public static Observable SkipFrame(this Observable source, int frameCount, FrameProvider frameProvider) { - return new SkipFrame(source, frameCount, frameProvider); + return new SkipFrame(source, frameCount.NormalizeFrame(), frameProvider); } } diff --git a/src/R3/Operators/SkipLast.cs b/src/R3/Operators/SkipLast.cs index 8d5e01a6..aebfd06c 100644 --- a/src/R3/Operators/SkipLast.cs +++ b/src/R3/Operators/SkipLast.cs @@ -17,7 +17,7 @@ public static Observable SkipLast(this Observable source, TimeSpan dura public static Observable SkipLast(this Observable source, TimeSpan duration, TimeProvider timeProvider) { - return new SkipLastTime(source, duration, timeProvider); + return new SkipLastTime(source, duration.Normalize(), timeProvider); } // SkipLastFrame @@ -29,7 +29,7 @@ public static Observable SkipLastFrame(this Observable source, int fram public static Observable SkipLastFrame(this Observable source, int frameCount, FrameProvider frameProvider) { - return new SkipLastFrame(source, frameCount, frameProvider); + return new SkipLastFrame(source, frameCount.NormalizeFrame(), frameProvider); } } diff --git a/src/R3/Operators/SkipUntil.cs b/src/R3/Operators/SkipUntil.cs index 91487ac2..4fcee41d 100644 --- a/src/R3/Operators/SkipUntil.cs +++ b/src/R3/Operators/SkipUntil.cs @@ -17,6 +17,11 @@ public static Observable SkipUntil(this Observable source, Task task) { return new SkipUntilT(source, task); } + + public static Observable SkipUntil(this Observable source, Func asyncFunc, bool configureAwait = true) + { + return new SkipUntilAsync(source, asyncFunc, configureAwait); + } } internal sealed class SkipUntil(Observable source, Observable other) : Observable @@ -195,3 +200,70 @@ async void TaskAwait(Task task) } } } + +internal sealed class SkipUntilAsync(Observable source, Func asyncFunc, bool configureAwait) : Observable +{ + protected override IDisposable SubscribeCore(Observer observer) + { + return source.Subscribe(new _SkipUntil(observer, asyncFunc, configureAwait)); + } + + sealed class _SkipUntil(Observer observer, Func asyncFunc, bool configureAwait) : Observer, IDisposable + { + readonly CancellationTokenSource cancellationTokenSource = new(); + int isTaskRunning; + bool open; + + protected override void OnNextCore(T value) + { + var isFirstValue = (Interlocked.Exchange(ref isTaskRunning, 1) == 0); + if (isFirstValue) + { + TaskStart(value); + } + + if (Volatile.Read(ref open)) + { + observer.OnNext(value); + } + } + + protected override void OnErrorResumeCore(Exception error) + { + observer.OnErrorResume(error); + } + + protected override void OnCompletedCore(Result result) + { + cancellationTokenSource.Cancel(); // cancel executing async process first + observer.OnCompleted(result); + } + + protected override void DisposeCore() + { + cancellationTokenSource.Cancel(); + } + + async void TaskStart(T value) + { + try + { + await asyncFunc(value, cancellationTokenSource.Token).ConfigureAwait(configureAwait); + } + catch (Exception ex) + { + if (ex is OperationCanceledException oce && oce.CancellationToken == cancellationTokenSource.Token) + { + return; + } + + // error is Stop + observer.OnCompleted(Result.Failure(ex)); + return; + } + + Volatile.Write(ref open, true); + } + } +} + diff --git a/src/R3/Operators/Take.cs b/src/R3/Operators/Take.cs index fc5a2dac..468ece45 100644 --- a/src/R3/Operators/Take.cs +++ b/src/R3/Operators/Take.cs @@ -23,7 +23,7 @@ public static Observable Take(this Observable source, TimeSpan duration public static Observable Take(this Observable source, TimeSpan duration, TimeProvider timeProvider) { - return new TakeTime(source, duration, timeProvider); + return new TakeTime(source, duration.Normalize(), timeProvider); } // TakeFrame @@ -35,7 +35,7 @@ public static Observable TakeFrame(this Observable source, int frameCou public static Observable TakeFrame(this Observable source, int frameCount, FrameProvider frameProvider) { - return new TakeFrame(source, frameCount, frameProvider); + return new TakeFrame(source, frameCount.NormalizeFrame(), frameProvider); } } diff --git a/src/R3/Operators/TakeLast.cs b/src/R3/Operators/TakeLast.cs index 53e86551..2206b255 100644 --- a/src/R3/Operators/TakeLast.cs +++ b/src/R3/Operators/TakeLast.cs @@ -17,7 +17,7 @@ public static Observable TakeLast(this Observable source, TimeSpan dura public static Observable TakeLast(this Observable source, TimeSpan duration, TimeProvider timeProvider) { - return new TakeLastTime(source, duration, timeProvider); + return new TakeLastTime(source, duration.Normalize(), timeProvider); } // TakeLastFrame @@ -29,7 +29,7 @@ public static Observable TakeLastFrame(this Observable source, int fram public static Observable TakeLastFrame(this Observable source, int frameCount, FrameProvider frameProvider) { - return new TakeLastFrame(source, frameCount, frameProvider); + return new TakeLastFrame(source, frameCount.NormalizeFrame(), frameProvider); } } diff --git a/src/R3/Operators/TakeUntil.cs b/src/R3/Operators/TakeUntil.cs index a4a77ec7..3002762e 100644 --- a/src/R3/Operators/TakeUntil.cs +++ b/src/R3/Operators/TakeUntil.cs @@ -25,6 +25,11 @@ public static Observable TakeUntil(this Observable source, Task task) { return new TakeUntilT(source, task); } + + public static Observable TakeUntil(this Observable source, Func asyncFunc, bool configureAwait = true) + { + return new TakeUntilAsync(source, asyncFunc, configureAwait); + } } internal sealed class TakeUntil(Observable source, Observable other) : Observable @@ -190,3 +195,66 @@ async void TaskAwait(Task task) } } } + +internal sealed class TakeUntilAsync(Observable source, Func asyncFunc, bool configureAwait) : Observable +{ + protected override IDisposable SubscribeCore(Observer observer) + { + return source.Subscribe(new _TakeUntil(observer, asyncFunc, configureAwait)); + } + + sealed class _TakeUntil(Observer observer, Func asyncFunc, bool configureAwait) : Observer, IDisposable + { + readonly CancellationTokenSource cancellationTokenSource = new(); + int isTaskRunning; + + protected override void OnNextCore(T value) + { + var isFirstValue = (Interlocked.Exchange(ref isTaskRunning, 1) == 0); + if (isFirstValue) + { + TaskStart(value); + } + + observer.OnNext(value); + } + + protected override void OnErrorResumeCore(Exception error) + { + observer.OnErrorResume(error); + } + + protected override void OnCompletedCore(Result result) + { + cancellationTokenSource.Cancel(); // cancel executing async process first + observer.OnCompleted(result); + } + + protected override void DisposeCore() + { + cancellationTokenSource.Cancel(); + } + + async void TaskStart(T value) + { + + try + { + await asyncFunc(value, cancellationTokenSource.Token).ConfigureAwait(configureAwait); + } + catch (Exception ex) + { + if (ex is OperationCanceledException oce && oce.CancellationToken == cancellationTokenSource.Token) + { + return; + } + + // error is Stop + observer.OnCompleted(Result.Failure(ex)); + return; + } + + observer.OnCompleted(); + } + } +} diff --git a/tests/R3.Tests/OperatorTests/SkipUntilTest.cs b/tests/R3.Tests/OperatorTests/SkipUntilTest.cs index bdce6495..c0432f87 100644 --- a/tests/R3.Tests/OperatorTests/SkipUntilTest.cs +++ b/tests/R3.Tests/OperatorTests/SkipUntilTest.cs @@ -66,6 +66,30 @@ public async Task TaskT() await Task.Delay(100); // wait for completion + publisher1.OnNext(999999); + publisher1.OnNext(9999990); + + list.AssertEqual([999999, 9999990]); + publisher1.OnCompleted(); + list.AssertIsCompleted(); + } + + [Fact] + public void Async() + { + SynchronizationContext.SetSynchronizationContext(null); + + var publisher1 = new Subject(); + var tcs = new TaskCompletionSource(); + var list = publisher1.SkipUntil(async (x,ct) => await tcs.Task).ToLiveList(); + + publisher1.OnNext(1); + publisher1.OnNext(2); + publisher1.OnNext(3); + list.AssertEqual([]); + + tcs.TrySetResult(); + publisher1.OnNext(999999); publisher1.OnNext(9999990); diff --git a/tests/R3.Tests/OperatorTests/TakeUntilTest.cs b/tests/R3.Tests/OperatorTests/TakeUntilTest.cs index e9d16087..e8d08b93 100644 --- a/tests/R3.Tests/OperatorTests/TakeUntilTest.cs +++ b/tests/R3.Tests/OperatorTests/TakeUntilTest.cs @@ -61,4 +61,25 @@ public async Task TaskT() list.AssertEqual([1, 2, 3]); list.AssertIsCompleted(); } + + [Fact] + public void Async() + { + SynchronizationContext.SetSynchronizationContext(null); + + var publisher1 = new Subject(); + var tcs = new TaskCompletionSource(); + var list = publisher1.TakeUntil(async (x,ct) => await tcs.Task).ToLiveList(); + + publisher1.OnNext(1); + publisher1.OnNext(2); + publisher1.OnNext(3); + list.AssertEqual([1, 2, 3]); + + tcs.TrySetResult(); + + list.AssertEqual([1, 2, 3]); + list.AssertIsCompleted(); + + } }