diff --git a/src/R3/Factories/Return.cs b/src/R3/Factories/Return.cs index 4f93fb7c..e05ef237 100644 --- a/src/R3/Factories/Return.cs +++ b/src/R3/Factories/Return.cs @@ -1,32 +1,28 @@ -using System; - -namespace R3; +namespace R3; public static partial class Observable { - // TODO: CancellationToken? - public static Observable Return(T value) { return new ImmediateScheduleReturn(value); // immediate } - public static Observable Return(T value, TimeProvider timeProvider) + public static Observable Return(T value, TimeProvider timeProvider, CancellationToken cancellationToken = default) { - return Return(value, TimeSpan.Zero, timeProvider); + return Return(value, TimeSpan.Zero, timeProvider, cancellationToken); } - public static Observable Return(T value, TimeSpan dueTime, TimeProvider timeProvider) + public static Observable Return(T value, TimeSpan dueTime, TimeProvider timeProvider, CancellationToken cancellationToken = default) { if (dueTime == TimeSpan.Zero) { if (timeProvider == TimeProvider.System) { - return new ThreadPoolScheduleReturn(value); // optimize for SystemTimeProvidr, use ThreadPool.UnsafeQueueUserWorkItem + return new ThreadPoolScheduleReturn(value, cancellationToken); // optimize for SystemTimeProvidr, use ThreadPool.UnsafeQueueUserWorkItem } } - return new Return(value, dueTime.Normalize(), timeProvider); // use ITimer + return new Return(value, dueTime.Normalize(), timeProvider, cancellationToken); // use ITimer } // Optimized case @@ -53,27 +49,37 @@ public static Observable Return(int value) // util - public static Observable Yield() + public static Observable Yield(CancellationToken cancellationToken = default) { - return new ThreadPoolScheduleReturn(default); + return new ThreadPoolScheduleReturn(default, cancellationToken); } - public static Observable Yield(TimeProvider timeProvider) + public static Observable Yield(TimeProvider timeProvider, CancellationToken cancellationToken = default) { if (timeProvider == TimeProvider.System) { - return new ThreadPoolScheduleReturn(default); + return new ThreadPoolScheduleReturn(default, cancellationToken); } - return new Return(default, TimeSpan.Zero, timeProvider); + return new Return(default, TimeSpan.Zero, timeProvider, cancellationToken); } } -internal sealed class Return(T value, TimeSpan dueTime, TimeProvider timeProvider) : Observable +internal sealed class Return(T value, TimeSpan dueTime, TimeProvider timeProvider, CancellationToken cancellationToken) : Observable { protected override IDisposable SubscribeCore(Observer observer) { var method = new _Return(value, observer); method.Timer = timeProvider.CreateStoppedTimer(_Return.timerCallback, method); + + if (cancellationToken.CanBeCanceled) + { + method.cancellationTokenRegistration = cancellationToken.UnsafeRegister(static state => + { + var s = (_Return)state!; + s.CompleteDispose(); + }, method); + } + method.Timer.InvokeOnce(dueTime); return method; } @@ -82,6 +88,8 @@ sealed class _Return(T value, Observer observer) : IDisposable { public static readonly TimerCallback timerCallback = NextTick; + internal CancellationTokenRegistration cancellationTokenRegistration; + readonly T value = value; readonly Observer observer = observer; @@ -94,8 +102,15 @@ static void NextTick(object? state) self.observer.OnCompleted(); } + public void CompleteDispose() + { + observer.OnCompleted(); + Dispose(); + } + public void Dispose() { + cancellationTokenRegistration.Dispose(); Timer?.Dispose(); Timer = null; } @@ -112,11 +127,21 @@ protected override IDisposable SubscribeCore(Observer observer) } } -internal sealed class ThreadPoolScheduleReturn(T value) : Observable +internal sealed class ThreadPoolScheduleReturn(T value, CancellationToken cancellationToken) : Observable { protected override IDisposable SubscribeCore(Observer observer) { var method = new _Return(value, observer); + + if (cancellationToken.CanBeCanceled) + { + method.cancellationTokenRegistration = cancellationToken.UnsafeRegister(static state => + { + var s = (_Return)state!; + s.CompleteDispose(); + }, method); + } + ThreadPool.UnsafeQueueUserWorkItem(method, preferLocal: false); return method; } @@ -125,6 +150,8 @@ sealed class _Return(T value, Observer observer) : IDisposable, IThreadPoolWo { bool stop; + internal CancellationTokenRegistration cancellationTokenRegistration; + public void Execute() { if (stop) return; @@ -133,8 +160,15 @@ public void Execute() observer.OnCompleted(); } + public void CompleteDispose() + { + observer.OnCompleted(); + Dispose(); + } + public void Dispose() { + cancellationTokenRegistration.Dispose(); stop = true; } }