Skip to content

Commit

Permalink
Where(State), Select
Browse files Browse the repository at this point in the history
  • Loading branch information
neuecc committed Dec 24, 2023
1 parent a7149ec commit 503d002
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/R3/Operators/Do.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public static Observable<T> Do<T, TState>(this Observable<T> source, TState stat

public static Observable<T> CancelOnCompleted<T>(this Observable<T> source, CancellationTokenSource cancellationTokenSource)
{
return Do(source, cancellationTokenSource, onCompleted: (_, state) => state.Cancel());
return Do(source, cancellationTokenSource, onCompleted: static (_, state) => state.Cancel());
}
}

Expand Down
77 changes: 72 additions & 5 deletions src/R3/Operators/Select.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,33 @@

public static partial class ObservableExtensions
{
// TODO: Element index overload

public static Observable<TResult> Select<T, TResult>(this Observable<T> source, Func<T, TResult> selector)
{
if (source is Where<T> where)
{
return new WhereSelect<T, TResult>(source, selector, where.predicate);
// Optimize for WhereSelect
return new WhereSelect<T, TResult>(where.source, selector, where.predicate);
}

return new Select<T, TResult>(source, selector);
}

public static Observable<TResult> Select<T, TResult>(this Observable<T> source, Func<T, int, TResult> selector)
{
return new SelectIndexed<T, TResult>(source, selector);
}

// TState

public static Observable<TResult> Select<T, TResult, TState>(this Observable<T> source, TState state, Func<T, TState, TResult> selector)
{
return new Select<T, TResult, TState>(source, selector, state);
}

public static Observable<TResult> Select<T, TResult, TState>(this Observable<T> source, TState state, Func<T, int, TState, TResult> selector)
{
return new SelectIndexed<T, TResult, TState>(source, selector, state);
}
}

internal sealed class Select<T, TResult>(Observable<T> source, Func<T, TResult> selector) : Observable<TResult>
Expand Down Expand Up @@ -76,10 +87,10 @@ internal sealed class WhereSelect<T, TResult>(Observable<T> source, Func<T, TRes
{
protected override IDisposable SubscribeCore(Observer<TResult> observer)
{
return source.Subscribe(new _Select(observer, selector, predicate));
return source.Subscribe(new _WhereSelect(observer, selector, predicate));
}

sealed class _Select(Observer<TResult> observer, Func<T, TResult> selector, Func<T, bool> predicate) : Observer<T>
sealed class _WhereSelect(Observer<TResult> observer, Func<T, TResult> selector, Func<T, bool> predicate) : Observer<T>
{
protected override void OnNextCore(T value)
{
Expand All @@ -100,3 +111,59 @@ protected override void OnCompletedCore(Result result)
}
}
}

internal sealed class SelectIndexed<T, TResult>(Observable<T> source, Func<T, int, TResult> selector) : Observable<TResult>
{
protected override IDisposable SubscribeCore(Observer<TResult> observer)
{
return source.Subscribe(new _Select(observer, selector));
}

sealed class _Select(Observer<TResult> observer, Func<T, int, TResult> selector) : Observer<T>
{
int index = 0;

protected override void OnNextCore(T value)
{
observer.OnNext(selector(value, index++));
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void OnCompletedCore(Result result)
{
observer.OnCompleted(result);
}
}
}

internal sealed class SelectIndexed<T, TResult, TState>(Observable<T> source, Func<T, int, TState, TResult> selector, TState state) : Observable<TResult>
{
protected override IDisposable SubscribeCore(Observer<TResult> observer)
{
return source.Subscribe(new _Select(observer, selector, state));
}

sealed class _Select(Observer<TResult> observer, Func<T, int, TState, TResult> selector, TState state) : Observer<T>
{
int index = 0;

protected override void OnNextCore(T value)
{
observer.OnNext(selector(value, index++, state));
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void OnCompletedCore(Result result)
{
observer.OnCompleted(result);
}
}
}
76 changes: 73 additions & 3 deletions src/R3/Operators/Where.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

public static partial class ObservableExtensions
{
// TODO: TState

public static Observable<T> Where<T>(this Observable<T> source, Func<T, bool> predicate)
{
if (source is Where<T> where)
{
// Optimize for Where.Where, create combined predicate.
var p = where.predicate;
return new Where<T>(where.source, x => p(x) && predicate(x));
return new Where<T>(where.source, x => p(x) && predicate(x)); // lambda captured but don't use TState to allow combine more Where
}

return new Where<T>(source, predicate);
Expand All @@ -20,6 +18,18 @@ public static Observable<T> Where<T>(this Observable<T> source, Func<T, int, boo
{
return new WhereIndexed<T>(source, predicate);
}

// TState

public static Observable<T> Where<T, TState>(this Observable<T> source, TState state, Func<T, TState, bool> predicate)
{
return new Where<T, TState>(source, predicate, state);
}

public static Observable<T> Where<T, TState>(this Observable<T> source, TState state, Func<T, int, TState, bool> predicate)
{
return new WhereIndexed<T, TState>(source, predicate, state);
}
}

internal sealed class Where<T>(Observable<T> source, Func<T, bool> predicate) : Observable<T>
Expand Down Expand Up @@ -84,3 +94,63 @@ protected override void OnCompletedCore(Result result)
}
}
}

internal sealed class Where<T, TState>(Observable<T> source, Func<T, TState, bool> predicate, TState state) : Observable<T>
{
protected override IDisposable SubscribeCore(Observer<T> observer)
{
return source.Subscribe(new _Where(observer, predicate, state));
}

class _Where(Observer<T> observer, Func<T, TState, bool> predicate, TState state) : Observer<T>
{
protected override void OnNextCore(T value)
{
if (predicate(value, state))
{
observer.OnNext(value);
}
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void OnCompletedCore(Result result)
{
observer.OnCompleted(result);
}
}
}

internal sealed class WhereIndexed<T, TState>(Observable<T> source, Func<T, int, TState, bool> predicate, TState state) : Observable<T>
{
protected override IDisposable SubscribeCore(Observer<T> observer)
{
return source.Subscribe(new _Where(observer, predicate, state));
}

class _Where(Observer<T> observer, Func<T, int, TState, bool> predicate, TState state) : Observer<T>
{
int index = 0;

protected override void OnNextCore(T value)
{
if (predicate(value, index++, state))
{
observer.OnNext(value);
}
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void OnCompletedCore(Result result)
{
observer.OnCompleted(result);
}
}
}
2 changes: 1 addition & 1 deletion tests/R3.Tests/FactoryTests/DeferTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public void Test()
var list = def.ToLiveList();

called.Should().BeTrue();

list.AssertEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
}
}
Expand Down
89 changes: 89 additions & 0 deletions tests/R3.Tests/OperatorTests/SelectTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
namespace R3.Tests.OperatorTests;

public class SelectTest
{
[Fact]
public void Select()
{
var subject = new Subject<int>();
using var list = subject.Select(x => x * 2).ToLiveList();

subject.OnNext(10);
list.AssertEqual([20]);

subject.OnNext(20);
list.AssertEqual([20, 40]);

subject.OnNext(40);
list.AssertEqual([20, 40, 80]);

subject.OnCompleted();
list.AssertIsCompleted();
}

// WhereSelect
[Fact]
public void WhereSelect()
{
var subject = new Subject<int>();
using var list = subject.Where(x => x % 2 == 0).Select(x => x * 2).ToLiveList();

subject.OnNext(10);
list.AssertEqual([20]);

subject.OnNext(11);
list.AssertEqual([20]);

subject.OnNext(20);
list.AssertEqual([20, 40]);

subject.OnNext(40);
list.AssertEqual([20, 40, 80]);

subject.OnNext(99);
list.AssertEqual([20, 40, 80]);

subject.OnCompleted();
list.AssertIsCompleted();
}

// SelectWithIndex
[Fact]
public void SelectWithIndex()
{
var subject = new Subject<int>();
using var list = subject.Select((x, i) => x * 2 + i).ToLiveList();

subject.OnNext(10);
list.AssertEqual([20]);

subject.OnNext(20);
list.AssertEqual([20, 41]);

subject.OnNext(40);
list.AssertEqual([20, 41, 82]);

subject.OnCompleted();
list.AssertIsCompleted();
}

// Select State
[Fact]
public void SelectState()
{
var subject = new Subject<int>();
using var list = subject.Select("a", (x, state) => x * 2 + state).ToLiveList();

subject.OnNext(10);
list.AssertEqual(["20a"]);

subject.OnNext(20);
list.AssertEqual(["20a", "40a"]);

subject.OnNext(40);
list.AssertEqual(["20a", "40a", "80a"]);

subject.OnCompleted();
list.AssertIsCompleted();
}
}
28 changes: 28 additions & 0 deletions tests/R3.Tests/OperatorTests/WhereTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,32 @@ public void WhereCompletableIndexed()

list.AssertIsCompleted();
}

// test where with state
[Fact]
public void WhereState()
{
var p = new Subject<int>();

var state = new { x = 2, y = 0 };
using var list = p.Where(state, static (x, s) => x % s.x != s.y).ToLiveList();

p.OnNext(2);
list.AssertEqual([]);

p.OnNext(1);
list.AssertEqual([1]);

p.OnNext(3);
list.AssertEqual([1, 3]);

p.OnNext(30);
list.AssertEqual([1, 3]);

list.AssertIsNotCompleted();

p.OnCompleted(default);

list.AssertIsCompleted();
}
}

0 comments on commit 503d002

Please sign in to comment.