Skip to content

Commit

Permalink
Merge pull request #155 from Cysharp/takeskip-async
Browse files Browse the repository at this point in the history
Add TakeUntil(async), SkipUntil(async)
  • Loading branch information
neuecc authored Mar 2, 2024
2 parents b484456 + 6659cb6 commit 302e520
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/R3/Operators/Skip.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public static Observable<T> Skip<T>(this Observable<T> source, TimeSpan duration

public static Observable<T> Skip<T>(this Observable<T> source, TimeSpan duration, TimeProvider timeProvider)
{
return new SkipTime<T>(source, duration, timeProvider);
return new SkipTime<T>(source, duration.Normalize(), timeProvider);
}

// SkipFrame
Expand All @@ -30,7 +30,7 @@ public static Observable<T> SkipFrame<T>(this Observable<T> source, int frameCou

public static Observable<T> SkipFrame<T>(this Observable<T> source, int frameCount, FrameProvider frameProvider)
{
return new SkipFrame<T>(source, frameCount, frameProvider);
return new SkipFrame<T>(source, frameCount.NormalizeFrame(), frameProvider);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/R3/Operators/SkipLast.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static Observable<T> SkipLast<T>(this Observable<T> source, TimeSpan dura

public static Observable<T> SkipLast<T>(this Observable<T> source, TimeSpan duration, TimeProvider timeProvider)
{
return new SkipLastTime<T>(source, duration, timeProvider);
return new SkipLastTime<T>(source, duration.Normalize(), timeProvider);
}

// SkipLastFrame
Expand All @@ -29,7 +29,7 @@ public static Observable<T> SkipLastFrame<T>(this Observable<T> source, int fram

public static Observable<T> SkipLastFrame<T>(this Observable<T> source, int frameCount, FrameProvider frameProvider)
{
return new SkipLastFrame<T>(source, frameCount, frameProvider);
return new SkipLastFrame<T>(source, frameCount.NormalizeFrame(), frameProvider);
}
}

Expand Down
72 changes: 72 additions & 0 deletions src/R3/Operators/SkipUntil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ public static Observable<T> SkipUntil<T>(this Observable<T> source, Task task)
{
return new SkipUntilT<T>(source, task);
}

public static Observable<T> SkipUntil<T>(this Observable<T> source, Func<T, CancellationToken, ValueTask> asyncFunc, bool configureAwait = true)
{
return new SkipUntilAsync<T>(source, asyncFunc, configureAwait);
}
}

internal sealed class SkipUntil<T, TOther>(Observable<T> source, Observable<TOther> other) : Observable<T>
Expand Down Expand Up @@ -195,3 +200,70 @@ async void TaskAwait(Task task)
}
}
}

internal sealed class SkipUntilAsync<T>(Observable<T> source, Func<T, CancellationToken, ValueTask> asyncFunc, bool configureAwait) : Observable<T>
{
protected override IDisposable SubscribeCore(Observer<T> observer)
{
return source.Subscribe(new _SkipUntil(observer, asyncFunc, configureAwait));
}

sealed class _SkipUntil(Observer<T> observer, Func<T, CancellationToken, ValueTask> asyncFunc, bool configureAwait) : Observer<T>, IDisposable
{
readonly CancellationTokenSource cancellationTokenSource = new();
int isTaskRunning;
bool open;

protected override void OnNextCore(T value)
{
var isFirstValue = (Interlocked.Exchange(ref isTaskRunning, 1) == 0);
if (isFirstValue)
{
TaskStart(value);
}

if (Volatile.Read(ref open))
{
observer.OnNext(value);
}
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void OnCompletedCore(Result result)
{
cancellationTokenSource.Cancel(); // cancel executing async process first
observer.OnCompleted(result);
}

protected override void DisposeCore()
{
cancellationTokenSource.Cancel();
}

async void TaskStart(T value)
{
try
{
await asyncFunc(value, cancellationTokenSource.Token).ConfigureAwait(configureAwait);
}
catch (Exception ex)
{
if (ex is OperationCanceledException oce && oce.CancellationToken == cancellationTokenSource.Token)
{
return;
}

// error is Stop
observer.OnCompleted(Result.Failure(ex));
return;
}

Volatile.Write(ref open, true);
}
}
}

4 changes: 2 additions & 2 deletions src/R3/Operators/Take.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static Observable<T> Take<T>(this Observable<T> source, TimeSpan duration

public static Observable<T> Take<T>(this Observable<T> source, TimeSpan duration, TimeProvider timeProvider)
{
return new TakeTime<T>(source, duration, timeProvider);
return new TakeTime<T>(source, duration.Normalize(), timeProvider);
}

// TakeFrame
Expand All @@ -35,7 +35,7 @@ public static Observable<T> TakeFrame<T>(this Observable<T> source, int frameCou

public static Observable<T> TakeFrame<T>(this Observable<T> source, int frameCount, FrameProvider frameProvider)
{
return new TakeFrame<T>(source, frameCount, frameProvider);
return new TakeFrame<T>(source, frameCount.NormalizeFrame(), frameProvider);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/R3/Operators/TakeLast.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static Observable<T> TakeLast<T>(this Observable<T> source, TimeSpan dura

public static Observable<T> TakeLast<T>(this Observable<T> source, TimeSpan duration, TimeProvider timeProvider)
{
return new TakeLastTime<T>(source, duration, timeProvider);
return new TakeLastTime<T>(source, duration.Normalize(), timeProvider);
}

// TakeLastFrame
Expand All @@ -29,7 +29,7 @@ public static Observable<T> TakeLastFrame<T>(this Observable<T> source, int fram

public static Observable<T> TakeLastFrame<T>(this Observable<T> source, int frameCount, FrameProvider frameProvider)
{
return new TakeLastFrame<T>(source, frameCount, frameProvider);
return new TakeLastFrame<T>(source, frameCount.NormalizeFrame(), frameProvider);
}
}

Expand Down
68 changes: 68 additions & 0 deletions src/R3/Operators/TakeUntil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ public static Observable<T> TakeUntil<T>(this Observable<T> source, Task task)
{
return new TakeUntilT<T>(source, task);
}

public static Observable<T> TakeUntil<T>(this Observable<T> source, Func<T, CancellationToken, ValueTask> asyncFunc, bool configureAwait = true)
{
return new TakeUntilAsync<T>(source, asyncFunc, configureAwait);
}
}

internal sealed class TakeUntil<T, TOther>(Observable<T> source, Observable<TOther> other) : Observable<T>
Expand Down Expand Up @@ -190,3 +195,66 @@ async void TaskAwait(Task task)
}
}
}

internal sealed class TakeUntilAsync<T>(Observable<T> source, Func<T, CancellationToken, ValueTask> asyncFunc, bool configureAwait) : Observable<T>
{
protected override IDisposable SubscribeCore(Observer<T> observer)
{
return source.Subscribe(new _TakeUntil(observer, asyncFunc, configureAwait));
}

sealed class _TakeUntil(Observer<T> observer, Func<T, CancellationToken, ValueTask> asyncFunc, bool configureAwait) : Observer<T>, IDisposable
{
readonly CancellationTokenSource cancellationTokenSource = new();
int isTaskRunning;

protected override void OnNextCore(T value)
{
var isFirstValue = (Interlocked.Exchange(ref isTaskRunning, 1) == 0);
if (isFirstValue)
{
TaskStart(value);
}

observer.OnNext(value);
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void OnCompletedCore(Result result)
{
cancellationTokenSource.Cancel(); // cancel executing async process first
observer.OnCompleted(result);
}

protected override void DisposeCore()
{
cancellationTokenSource.Cancel();
}

async void TaskStart(T value)
{

try
{
await asyncFunc(value, cancellationTokenSource.Token).ConfigureAwait(configureAwait);
}
catch (Exception ex)
{
if (ex is OperationCanceledException oce && oce.CancellationToken == cancellationTokenSource.Token)
{
return;
}

// error is Stop
observer.OnCompleted(Result.Failure(ex));
return;
}

observer.OnCompleted();
}
}
}
24 changes: 24 additions & 0 deletions tests/R3.Tests/OperatorTests/SkipUntilTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@ public async Task TaskT()
await Task.Delay(100); // wait for completion


publisher1.OnNext(999999);
publisher1.OnNext(9999990);

list.AssertEqual([999999, 9999990]);
publisher1.OnCompleted();
list.AssertIsCompleted();
}

[Fact]
public void Async()
{
SynchronizationContext.SetSynchronizationContext(null);

var publisher1 = new Subject<int>();
var tcs = new TaskCompletionSource();
var list = publisher1.SkipUntil(async (x,ct) => await tcs.Task).ToLiveList();

publisher1.OnNext(1);
publisher1.OnNext(2);
publisher1.OnNext(3);
list.AssertEqual([]);

tcs.TrySetResult();

publisher1.OnNext(999999);
publisher1.OnNext(9999990);

Expand Down
21 changes: 21 additions & 0 deletions tests/R3.Tests/OperatorTests/TakeUntilTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,25 @@ public async Task TaskT()
list.AssertEqual([1, 2, 3]);
list.AssertIsCompleted();
}

[Fact]
public void Async()
{
SynchronizationContext.SetSynchronizationContext(null);

var publisher1 = new Subject<int>();
var tcs = new TaskCompletionSource();
var list = publisher1.TakeUntil(async (x,ct) => await tcs.Task).ToLiveList();

publisher1.OnNext(1);
publisher1.OnNext(2);
publisher1.OnNext(3);
list.AssertEqual([1, 2, 3]);

tcs.TrySetResult();

list.AssertEqual([1, 2, 3]);
list.AssertIsCompleted();

}
}

0 comments on commit 302e520

Please sign in to comment.