Skip to content

Commit

Permalink
Return CancellationToken
Browse files Browse the repository at this point in the history
  • Loading branch information
neuecc committed Jan 4, 2024
1 parent 71b9e7f commit 8a0c97d
Showing 1 changed file with 51 additions and 17 deletions.
68 changes: 51 additions & 17 deletions src/R3/Factories/Return.cs
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
using System;

namespace R3;
namespace R3;

public static partial class Observable
{
// TODO: CancellationToken?

public static Observable<T> Return<T>(T value)
{
return new ImmediateScheduleReturn<T>(value); // immediate
}

public static Observable<T> Return<T>(T value, TimeProvider timeProvider)
public static Observable<T> Return<T>(T value, TimeProvider timeProvider, CancellationToken cancellationToken = default)
{
return Return(value, TimeSpan.Zero, timeProvider);
return Return(value, TimeSpan.Zero, timeProvider, cancellationToken);
}

public static Observable<T> Return<T>(T value, TimeSpan dueTime, TimeProvider timeProvider)
public static Observable<T> Return<T>(T value, TimeSpan dueTime, TimeProvider timeProvider, CancellationToken cancellationToken = default)
{
if (dueTime == TimeSpan.Zero)
{
if (timeProvider == TimeProvider.System)
{
return new ThreadPoolScheduleReturn<T>(value); // optimize for SystemTimeProvidr, use ThreadPool.UnsafeQueueUserWorkItem
return new ThreadPoolScheduleReturn<T>(value, cancellationToken); // optimize for SystemTimeProvidr, use ThreadPool.UnsafeQueueUserWorkItem
}
}

return new Return<T>(value, dueTime.Normalize(), timeProvider); // use ITimer
return new Return<T>(value, dueTime.Normalize(), timeProvider, cancellationToken); // use ITimer
}

// Optimized case
Expand All @@ -53,27 +49,37 @@ public static Observable<int> Return(int value)

// util

public static Observable<Unit> Yield()
public static Observable<Unit> Yield(CancellationToken cancellationToken = default)
{
return new ThreadPoolScheduleReturn<Unit>(default);
return new ThreadPoolScheduleReturn<Unit>(default, cancellationToken);
}

public static Observable<Unit> Yield(TimeProvider timeProvider)
public static Observable<Unit> Yield(TimeProvider timeProvider, CancellationToken cancellationToken = default)
{
if (timeProvider == TimeProvider.System)
{
return new ThreadPoolScheduleReturn<Unit>(default);
return new ThreadPoolScheduleReturn<Unit>(default, cancellationToken);
}
return new Return<Unit>(default, TimeSpan.Zero, timeProvider);
return new Return<Unit>(default, TimeSpan.Zero, timeProvider, cancellationToken);
}
}

internal sealed class Return<T>(T value, TimeSpan dueTime, TimeProvider timeProvider) : Observable<T>
internal sealed class Return<T>(T value, TimeSpan dueTime, TimeProvider timeProvider, CancellationToken cancellationToken) : Observable<T>
{
protected override IDisposable SubscribeCore(Observer<T> observer)
{
var method = new _Return(value, observer);
method.Timer = timeProvider.CreateStoppedTimer(_Return.timerCallback, method);

if (cancellationToken.CanBeCanceled)
{
method.cancellationTokenRegistration = cancellationToken.UnsafeRegister(static state =>
{
var s = (_Return)state!;
s.CompleteDispose();
}, method);
}

method.Timer.InvokeOnce(dueTime);
return method;
}
Expand All @@ -82,6 +88,8 @@ sealed class _Return(T value, Observer<T> observer) : IDisposable
{
public static readonly TimerCallback timerCallback = NextTick;

internal CancellationTokenRegistration cancellationTokenRegistration;

readonly T value = value;
readonly Observer<T> observer = observer;

Expand All @@ -94,8 +102,15 @@ static void NextTick(object? state)
self.observer.OnCompleted();
}

public void CompleteDispose()
{
observer.OnCompleted();
Dispose();
}

public void Dispose()
{
cancellationTokenRegistration.Dispose();
Timer?.Dispose();
Timer = null;
}
Expand All @@ -112,11 +127,21 @@ protected override IDisposable SubscribeCore(Observer<T> observer)
}
}

internal sealed class ThreadPoolScheduleReturn<T>(T value) : Observable<T>
internal sealed class ThreadPoolScheduleReturn<T>(T value, CancellationToken cancellationToken) : Observable<T>
{
protected override IDisposable SubscribeCore(Observer<T> observer)
{
var method = new _Return(value, observer);

if (cancellationToken.CanBeCanceled)
{
method.cancellationTokenRegistration = cancellationToken.UnsafeRegister(static state =>
{
var s = (_Return)state!;
s.CompleteDispose();
}, method);
}

ThreadPool.UnsafeQueueUserWorkItem(method, preferLocal: false);
return method;
}
Expand All @@ -125,6 +150,8 @@ sealed class _Return(T value, Observer<T> observer) : IDisposable, IThreadPoolWo
{
bool stop;

internal CancellationTokenRegistration cancellationTokenRegistration;

public void Execute()
{
if (stop) return;
Expand All @@ -133,8 +160,15 @@ public void Execute()
observer.OnCompleted();
}

public void CompleteDispose()
{
observer.OnCompleted();
Dispose();
}

public void Dispose()
{
cancellationTokenRegistration.Dispose();
stop = true;
}
}
Expand Down

0 comments on commit 8a0c97d

Please sign in to comment.