From edf8ee6c123316614efca04edf89406b4f2b902a Mon Sep 17 00:00:00 2001 From: neuecc Date: Sat, 6 Jan 2024 15:22:20 +0900 Subject: [PATCH] FromEvent supports CancellationToken --- src/R3/Factories/FromEvent.cs | 70 +++++++++++++++----- tests/R3.Tests/FactoryTests/FromEventTest.cs | 34 ++++++++++ 2 files changed, 86 insertions(+), 18 deletions(-) diff --git a/src/R3/Factories/FromEvent.cs b/src/R3/Factories/FromEvent.cs index cf4e9bd5..6eb2c749 100644 --- a/src/R3/Factories/FromEvent.cs +++ b/src/R3/Factories/FromEvent.cs @@ -2,43 +2,43 @@ public static partial class Observable { - public static Observable<(object? sender, EventArgs e)> FromEventHandler(Action addHandler, Action removeHandler) + public static Observable<(object? sender, EventArgs e)> FromEventHandler(Action addHandler, Action removeHandler, CancellationToken cancellationToken = default) { - return new FromEvent(h => (sender, e) => h((sender, e)), addHandler, removeHandler); + return new FromEvent(h => (sender, e) => h((sender, e)), addHandler, removeHandler, cancellationToken); } - public static Observable<(object? sender, TEventArgs e)> FromEventHandler(Action> addHandler, Action> removeHandler) + public static Observable<(object? sender, TEventArgs e)> FromEventHandler(Action> addHandler, Action> removeHandler, CancellationToken cancellationToken = default) { - return new FromEvent, (object? sender, TEventArgs e)>(h => (sender, e) => h((sender, e)), addHandler, removeHandler); + return new FromEvent, (object? sender, TEventArgs e)>(h => (sender, e) => h((sender, e)), addHandler, removeHandler, cancellationToken); } - public static Observable FromEvent(Action addHandler, Action removeHandler) + public static Observable FromEvent(Action addHandler, Action removeHandler, CancellationToken cancellationToken = default) { - return new FromEvent(static h => h, addHandler, removeHandler); + return new FromEvent(static h => h, addHandler, removeHandler, cancellationToken); } - public static Observable FromEvent(Action> addHandler, Action> removeHandler) + public static Observable FromEvent(Action> addHandler, Action> removeHandler, CancellationToken cancellationToken = default) { - return new FromEvent, T>(static h => h, addHandler, removeHandler); + return new FromEvent, T>(static h => h, addHandler, removeHandler, cancellationToken); } - public static Observable FromEvent(Func conversion, Action addHandler, Action removeHandler) + public static Observable FromEvent(Func conversion, Action addHandler, Action removeHandler, CancellationToken cancellationToken = default) { - return new FromEvent(conversion, addHandler, removeHandler); + return new FromEvent(conversion, addHandler, removeHandler, cancellationToken); } - public static Observable FromEvent(Func, TDelegate> conversion, Action addHandler, Action removeHandler) + public static Observable FromEvent(Func, TDelegate> conversion, Action addHandler, Action removeHandler, CancellationToken cancellationToken = default) { - return new FromEvent(conversion, addHandler, removeHandler); + return new FromEvent(conversion, addHandler, removeHandler, cancellationToken); } } -internal sealed class FromEvent(Func conversion, Action addHandler, Action removeHandler) +internal sealed class FromEvent(Func conversion, Action addHandler, Action removeHandler, CancellationToken cancellationToken) : Observable { protected override IDisposable SubscribeCore(Observer observer) { - return new _FromEventPattern(conversion, addHandler, removeHandler, observer); + return new _FromEventPattern(conversion, addHandler, removeHandler, observer, cancellationToken); } sealed class _FromEventPattern : IDisposable @@ -46,13 +46,23 @@ sealed class _FromEventPattern : IDisposable Observer? observer; Action? removeHandler; TDelegate registeredHandler; + CancellationTokenRegistration cancellationTokenRegistration; - public _FromEventPattern(Func conversion, Action addHandler, Action removeHandler, Observer observer) + public _FromEventPattern(Func conversion, Action addHandler, Action removeHandler, Observer observer, CancellationToken cancellationToken) { this.observer = observer; this.removeHandler = removeHandler; this.registeredHandler = conversion(OnNext); addHandler(this.registeredHandler); + + if (cancellationToken.CanBeCanceled) + { + this.cancellationTokenRegistration = cancellationToken.UnsafeRegister(static state => + { + var s = (_FromEventPattern)state!; + s.CompleteDispose(); + }, this); + } } void OnNext() @@ -60,6 +70,12 @@ void OnNext() observer?.OnNext(default); } + void CompleteDispose() + { + observer?.OnCompleted(); + Dispose(); + } + public void Dispose() { var handler = Interlocked.Exchange(ref removeHandler, null); @@ -67,18 +83,19 @@ public void Dispose() { observer = null; removeHandler = null; + cancellationTokenRegistration.Dispose(); handler(this.registeredHandler); } } } } -internal sealed class FromEvent(Func, TDelegate> conversion, Action addHandler, Action removeHandler) +internal sealed class FromEvent(Func, TDelegate> conversion, Action addHandler, Action removeHandler, CancellationToken cancellationToken) : Observable { protected override IDisposable SubscribeCore(Observer observer) { - return new _FromEventPattern(conversion, addHandler, removeHandler, observer); + return new _FromEventPattern(conversion, addHandler, removeHandler, observer, cancellationToken); } sealed class _FromEventPattern : IDisposable @@ -86,13 +103,23 @@ sealed class _FromEventPattern : IDisposable Observer? observer; Action? removeHandler; TDelegate registeredHandler; + CancellationTokenRegistration cancellationTokenRegistration; - public _FromEventPattern(Func, TDelegate> conversion, Action addHandler, Action removeHandler, Observer observer) + public _FromEventPattern(Func, TDelegate> conversion, Action addHandler, Action removeHandler, Observer observer, CancellationToken cancellationToken) { this.observer = observer; this.removeHandler = removeHandler; this.registeredHandler = conversion(OnNext); addHandler(this.registeredHandler); + + if (cancellationToken.CanBeCanceled) + { + this.cancellationTokenRegistration = cancellationToken.UnsafeRegister(static state => + { + var s = (_FromEventPattern)state!; + s.CompleteDispose(); + }, this); + } } void OnNext(T value) @@ -100,6 +127,12 @@ void OnNext(T value) observer?.OnNext(value); } + void CompleteDispose() + { + observer?.OnCompleted(); + Dispose(); + } + public void Dispose() { var handler = Interlocked.Exchange(ref removeHandler, null); @@ -107,6 +140,7 @@ public void Dispose() { observer = null; removeHandler = null; + cancellationTokenRegistration.Dispose(); handler(this.registeredHandler); } } diff --git a/tests/R3.Tests/FactoryTests/FromEventTest.cs b/tests/R3.Tests/FactoryTests/FromEventTest.cs index 062df3df..3ca9391f 100644 --- a/tests/R3.Tests/FactoryTests/FromEventTest.cs +++ b/tests/R3.Tests/FactoryTests/FromEventTest.cs @@ -41,6 +41,40 @@ public void Event() ev.InvocationListCount().Should().Be((0, 0, 0, 0, 0, 0, 0)); } + [Fact] + public void Cancel() + { + var cts = new CancellationTokenSource(); + + var ev = new EventPattern(); + + var l1 = Observable.FromEventHandler(h => ev.E1 += h, h => ev.E1 -= h, cts.Token).ToLiveList(); + var l2 = Observable.FromEventHandler(h => ev.E2 += h, h => ev.E2 -= h, cts.Token).ToLiveList(); + var l3 = Observable.FromEvent(h => ev.A1 += h, h => ev.A1 -= h, cts.Token).ToLiveList(); + var l4 = Observable.FromEvent(h => ev.A2 += h, h => ev.A2 -= h, cts.Token).ToLiveList(); + var l5 = Observable.FromEvent(h => new MyDelegate1(h), h => ev.M1 += h, h => ev.M1 -= h, cts.Token).ToLiveList(); + var l6 = Observable.FromEvent(h => new MyDelegate2(h), h => ev.M2 += h, h => ev.M2 -= h, cts.Token).ToLiveList(); + var l7 = Observable.FromEvent(h => (x, y) => h((x, y)), h => ev.M3 += h, h => ev.M3 -= h, cts.Token).ToLiveList(); + + ev.Raise(10, 20); + ev.Raise(100, 200); + + l1.Should().HaveCount(2); + l3.Should().HaveCount(2); + l5.Should().HaveCount(2); + + l2.Select(x => x.e).Should().Equal([10, 100]); + l4.AssertEqual([10, 100]); + l6.AssertEqual([10, 100]); + l7.AssertEqual([(10, 20), (100, 200)]); + + ev.InvocationListCount().Should().Be((1, 1, 1, 1, 1, 1, 1)); + + cts.Cancel(); + + ev.InvocationListCount().Should().Be((0, 0, 0, 0, 0, 0, 0)); + } + class EventPattern {