From 19e93155ada36a17d809e8170c305188360ca23b Mon Sep 17 00:00:00 2001 From: neuecc Date: Mon, 1 Jan 2024 23:25:54 +0900 Subject: [PATCH] SubscribeOn, Synchronize --- src/R3/Operators/ObserveOn.cs | 9 +- src/R3/Operators/SubscribeOn.cs | 281 ++++++++++++++++++ src/R3/Operators/Synchronize.cs | 50 ++++ src/R3/Operators/_Operators.cs | 2 - .../R3.Tests/OperatorTests/SubscribeOnTest.cs | 128 ++++++++ .../OperatorTests/SynchronizationTest.cs | 20 ++ 6 files changed, 486 insertions(+), 4 deletions(-) create mode 100644 src/R3/Operators/SubscribeOn.cs create mode 100644 src/R3/Operators/Synchronize.cs create mode 100644 tests/R3.Tests/OperatorTests/SubscribeOnTest.cs create mode 100644 tests/R3.Tests/OperatorTests/SynchronizationTest.cs diff --git a/src/R3/Operators/ObserveOn.cs b/src/R3/Operators/ObserveOn.cs index 3d90b3f2..ce538199 100644 --- a/src/R3/Operators/ObserveOn.cs +++ b/src/R3/Operators/ObserveOn.cs @@ -5,16 +5,21 @@ namespace R3; public static partial class ObservableExtensions { /// ObserveOn SynchronizationContext.Current - public static Observable ObserveOnCurrent(this Observable source) + public static Observable ObserveOnCurrentSynchronizationContext(this Observable source) { return ObserveOn(source, SynchronizationContext.Current); } + public static Observable ObserveOnThreadPool(this Observable source) + { + return new ObserveOnThreadPool(source); + } + public static Observable ObserveOn(this Observable source, SynchronizationContext? synchronizationContext) { if (synchronizationContext == null) { - return ObserveOn(source, TimeProvider.System); // use ThreadPool instead + return new ObserveOnThreadPool(source); // use ThreadPool instead } return new ObserveOnSynchronizationContext(source, synchronizationContext); diff --git a/src/R3/Operators/SubscribeOn.cs b/src/R3/Operators/SubscribeOn.cs new file mode 100644 index 00000000..d7333012 --- /dev/null +++ b/src/R3/Operators/SubscribeOn.cs @@ -0,0 +1,281 @@ +namespace R3; + +public static partial class ObservableExtensions +{ + public static Observable SubscribeOnCurrentSynchronizationContext(this Observable source) + { + return SubscribeOn(source, SynchronizationContext.Current); + } + + public static Observable SubscribeOnThreadPool(this Observable source) + { + return new SubscribeOnThreadPool(source); + } + + public static Observable SubscribeOn(this Observable source, SynchronizationContext? synchronizationContext) + { + if (synchronizationContext == null) + { + return new SubscribeOnThreadPool(source); // use ThreadPool instead + } + + return new SubscribeOnSynchronizationContext(source, synchronizationContext); + } + + public static Observable SubscribeOn(this Observable source, TimeProvider timeProvider) + { + if (timeProvider == TimeProvider.System) + { + return new SubscribeOnThreadPool(source); + } + + return new SubscribeOnTimeProvider(source, timeProvider); + } + + public static Observable SubscribeOn(this Observable source, FrameProvider frameProvider) + { + return new SubscribeOnFrameProvider(source, frameProvider); + } +} + +internal sealed class SubscribeOnSynchronizationContext(Observable source, SynchronizationContext synchronizationContext) : Observable +{ + protected override IDisposable SubscribeCore(Observer observer) + { + return new _SubscribeOn(observer, source, synchronizationContext).Run(); + } + + sealed class _SubscribeOn : Observer + { + static readonly SendOrPostCallback postCallback = Subscribe; + + readonly Observer observer; + readonly Observable source; + readonly SynchronizationContext synchronizationContext; + SingleAssignmentDisposableCore disposable; + + public _SubscribeOn(Observer observer, Observable source, SynchronizationContext synchronizationContext) + { + this.observer = observer; + this.source = source; + this.synchronizationContext = synchronizationContext; + } + + public IDisposable Run() + { + synchronizationContext.Post(postCallback, this); + return this; + } + + static void Subscribe(object? state) + { + var self = (_SubscribeOn)state!; + self.disposable.Disposable = self.source.Subscribe(self); + } + + protected override void OnNextCore(T value) + { + observer.OnNext(value); + } + + protected override void OnErrorResumeCore(Exception error) + { + observer.OnErrorResume(error); + } + + protected override void OnCompletedCore(Result result) + { + observer.OnCompleted(result); + } + + protected override void DisposeCore() + { + disposable.Dispose(); + } + } +} + +internal sealed class SubscribeOnThreadPool(Observable source) : Observable +{ + protected override IDisposable SubscribeCore(Observer observer) + { + return new _SubscribeOn(observer, source).Run(); + } + + sealed class _SubscribeOn : Observer, IThreadPoolWorkItem + { + readonly Observer observer; + readonly Observable source; + SingleAssignmentDisposableCore disposable; + + public _SubscribeOn(Observer observer, Observable source) + { + this.observer = observer; + this.source = source; + } + + public IDisposable Run() + { + ThreadPool.UnsafeQueueUserWorkItem(this, preferLocal: false); + return this; + } + + public void Execute() + { + try + { + disposable.Disposable = source.Subscribe(this); + } + catch (Exception ex) + { + ObservableSystem.GetUnhandledExceptionHandler().Invoke(ex); + Dispose(); + } + } + + protected override void OnNextCore(T value) + { + observer.OnNext(value); + } + + protected override void OnErrorResumeCore(Exception error) + { + observer.OnErrorResume(error); + } + + protected override void OnCompletedCore(Result result) + { + observer.OnCompleted(result); + } + + protected override void DisposeCore() + { + disposable.Dispose(); + } + } +} + +internal sealed class SubscribeOnTimeProvider(Observable source, TimeProvider timeProvider) : Observable +{ + protected override IDisposable SubscribeCore(Observer observer) + { + return new _SubscribeOn(observer, source, timeProvider).Run(); + } + + sealed class _SubscribeOn : Observer + { + static readonly TimerCallback timerCallback = Subscribe; + + readonly Observer observer; + readonly Observable source; + readonly TimeProvider timeProvider; + readonly ITimer timer; + SingleAssignmentDisposableCore disposable; + + public _SubscribeOn(Observer observer, Observable source, TimeProvider timeProvider) + { + this.observer = observer; + this.source = source; + this.timeProvider = timeProvider; + this.timer = timeProvider.CreateStoppedTimer(timerCallback, this); + } + + public IDisposable Run() + { + timer.RestartImmediately(); + return this; + } + + static void Subscribe(object? state) + { + var self = (_SubscribeOn)state!; + self.disposable.Disposable = self.source.Subscribe(self); + } + + protected override void OnNextCore(T value) + { + observer.OnNext(value); + } + + protected override void OnErrorResumeCore(Exception error) + { + observer.OnErrorResume(error); + } + + protected override void OnCompletedCore(Result result) + { + observer.OnCompleted(result); + } + + protected override void DisposeCore() + { + timer.Dispose(); + disposable.Dispose(); + } + } +} + +internal sealed class SubscribeOnFrameProvider(Observable source, FrameProvider frameProvider) : Observable +{ + protected override IDisposable SubscribeCore(Observer observer) + { + return new _SubscribeOn(observer, source, frameProvider).Run(); + } + + sealed class _SubscribeOn : Observer, IFrameRunnerWorkItem + { + static readonly SendOrPostCallback postCallback = Subscribe; + + readonly Observer observer; + readonly Observable source; + readonly FrameProvider frameProvider; + SingleAssignmentDisposableCore disposable; + + public _SubscribeOn(Observer observer, Observable source, FrameProvider frameProvider) + { + this.observer = observer; + this.source = source; + this.frameProvider = frameProvider; + } + + public IDisposable Run() + { + frameProvider.Register(this); + return this; + } + + static void Subscribe(object? state) + { + var self = (_SubscribeOn)state!; + self.disposable.Disposable = self.source.Subscribe(self); + } + + bool IFrameRunnerWorkItem.MoveNext(long frameCount) + { + if (disposable.IsDisposed) return false; + + disposable.Disposable = source.Subscribe(this); + return false; + } + + protected override void OnNextCore(T value) + { + observer.OnNext(value); + } + + protected override void OnErrorResumeCore(Exception error) + { + observer.OnErrorResume(error); + } + + protected override void OnCompletedCore(Result result) + { + observer.OnCompleted(result); + } + + protected override void DisposeCore() + { + disposable.Dispose(); + } + } +} diff --git a/src/R3/Operators/Synchronize.cs b/src/R3/Operators/Synchronize.cs new file mode 100644 index 00000000..4ab42f42 --- /dev/null +++ b/src/R3/Operators/Synchronize.cs @@ -0,0 +1,50 @@ +namespace R3; + +public static partial class ObservableExtensions +{ + public static Observable Synchronize(this Observable source) + { + return new Synchronize(source, new object()); + } + + public static Observable Synchronize(this Observable source, object gate) + { + return new Synchronize(source, gate); + } +} + + +internal sealed class Synchronize(Observable source, object gate) : Observable +{ + protected override IDisposable SubscribeCore(Observer observer) + { + return source.Subscribe(new _Synchronize(observer, gate)); + } + + sealed class _Synchronize(Observer observer, object gate) : Observer + { + protected override void OnNextCore(T value) + { + lock (gate) + { + observer.OnNext(value); + } + } + + protected override void OnErrorResumeCore(Exception error) + { + lock (gate) + { + observer.OnErrorResume(error); + } + } + + protected override void OnCompletedCore(Result result) + { + lock (gate) + { + observer.OnCompleted(result); + } + } + } +} diff --git a/src/R3/Operators/_Operators.cs b/src/R3/Operators/_Operators.cs index 930d6fc6..f9ef7b27 100644 --- a/src/R3/Operators/_Operators.cs +++ b/src/R3/Operators/_Operators.cs @@ -15,8 +15,6 @@ public static partial class ObservableExtensions // Buffer + BUfferFrame => Chunk, ChunkFrame - // SubscribeOn, Synchronize - // Rx Merging: // CombineLatest, Zip, WithLatestFrom, ZipLatest, Switch diff --git a/tests/R3.Tests/OperatorTests/SubscribeOnTest.cs b/tests/R3.Tests/OperatorTests/SubscribeOnTest.cs new file mode 100644 index 00000000..4dab58c2 --- /dev/null +++ b/tests/R3.Tests/OperatorTests/SubscribeOnTest.cs @@ -0,0 +1,128 @@ +using Newtonsoft.Json.Linq; + +namespace R3.Tests.OperatorTests; + +public class SubscribeOnTest +{ + + // null synccontext(TimeProvider.System) + [Fact] + public async Task ThreadPool() + { + var values = await Observable.Range(1, 10) + .Do(onSubscribe: () => Thread.CurrentThread.IsThreadPoolThread.Should().BeTrue()) + .SubscribeOnThreadPool() + .ToArrayAsync(); + values.Should().Equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } + + [Fact] + public async Task SyncContext() + { + var syncContext = new CustomSyncContext(); + syncContext.IsInSyncContext.Should().BeFalse(); + var values = await Observable.Range(1, 10) + .Do(onSubscribe: () => syncContext.IsInSyncContext.Should().BeTrue()) + .SubscribeOn(syncContext) + .ToArrayAsync(); + values.Should().Equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + syncContext.PostCount.Should().Be(1); + } + + [Fact] + public void TimeProvider() + { + var fakeTime = new ImmediateFakeTiemr(); + var subscribed = false; + using var list = Observable.Range(1, 10) + .Do(onSubscribe: () => subscribed = true) + .SubscribeOn(fakeTime) + .ToLiveList(); + + + subscribed.Should().BeFalse(); + + fakeTime.Advance(); + subscribed.Should().BeTrue(); + + list.AssertEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } + + + [Fact] + public void FrameProvider() + { + + var fakeTime = new ManualFrameProvider(); + var subscribed = false; + using var list = Observable.Range(1, 10) + .Do(onSubscribe: () => subscribed = true) + .SubscribeOn(fakeTime) + .ToLiveList(); + + + subscribed.Should().BeFalse(); + + fakeTime.Advance(); + subscribed.Should().BeTrue(); + + list.AssertEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } +} + +file class CustomSyncContext : SynchronizationContext +{ + public int PostCount; + public bool IsInSyncContext { get; set; } + + public override void Post(SendOrPostCallback d, object? state) + { + IsInSyncContext = true; + PostCount++; + d(state); + IsInSyncContext = false; + } +} + +file class ImmediateFakeTiemr : TimeProvider +{ + List timers = new(); + + public override ITimer CreateTimer(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period) + { + var t = new Timer(callback, state); + timers.Add(t); + return t; + } + + public void Advance() + { + foreach (var item in timers) + { + item.Wakeup(); + } + } + + class Timer(TimerCallback callback, object? state) : ITimer + { + public bool Change(TimeSpan dueTime, TimeSpan period) + { + return false; + } + + public bool Wakeup() + { + callback(state); + return true; + } + + public void Dispose() + { + } + + public ValueTask DisposeAsync() + { + return default; + } + } +} diff --git a/tests/R3.Tests/OperatorTests/SynchronizationTest.cs b/tests/R3.Tests/OperatorTests/SynchronizationTest.cs new file mode 100644 index 00000000..b98c2bba --- /dev/null +++ b/tests/R3.Tests/OperatorTests/SynchronizationTest.cs @@ -0,0 +1,20 @@ +namespace R3.Tests.OperatorTests; + +public class SynchronizationTest(ITestOutputHelper output) +{ + [Fact] + public void Test() + { + var subject = new Subject(); + + var count = 0; + var no_sync = 0; + subject.Subscribe(x => no_sync++); + subject.Synchronize().Subscribe(x => count++); + + Parallel.For(0, 100, x => subject.OnNext(x)); + + count.Should().Be(100); + output.WriteLine($"Count: {count}, no_sync: {no_sync}"); + } +}