From 9f9f90f8c0b8cb1064c45ee40290e2641e8e0d23 Mon Sep 17 00:00:00 2001 From: reugn Date: Sat, 17 Aug 2024 21:59:47 +0300 Subject: [PATCH] test(flow): improve coverage and minor refactoring --- .golangci.yml | 1 + flow/batch.go | 36 ++++--- flow/batch_test.go | 42 ++++++-- flow/filter_test.go | 72 +++++++++++++ flow/flat_map_test.go | 94 ++++++++++++++++ flow/flow_test.go | 200 ++++++++++++++++++++++------------- flow/map.go | 1 + flow/map_test.go | 92 ++++++++++++++++ flow/reduce.go | 5 +- flow/reduce_test.go | 64 +++++++++++ flow/session_window_test.go | 44 ++++++-- flow/sliding_window_test.go | 75 +++++++++++-- flow/tumbling_window_test.go | 39 ++++++- flow/util.go | 14 +-- 14 files changed, 654 insertions(+), 125 deletions(-) create mode 100644 flow/filter_test.go create mode 100644 flow/flat_map_test.go create mode 100644 flow/map_test.go create mode 100644 flow/reduce_test.go diff --git a/.golangci.yml b/.golangci.yml index 6908a64..22c53d0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -40,3 +40,4 @@ issues: - errcheck - unparam - prealloc + - funlen diff --git a/flow/batch.go b/flow/batch.go index 6e18140..36fbc2f 100644 --- a/flow/batch.go +++ b/flow/batch.go @@ -8,8 +8,8 @@ import ( ) // Batch processor breaks a stream of elements into batches based on size or timing. -// When the maximum batch size is reached or the batch time is elapsed, and the current buffer -// is not empty, a new batch will be emitted. +// When the maximum batch size is reached or the batch time is elapsed, and the +// current buffer is not empty, a new batch will be emitted. // Note: once a batch is sent downstream, the timer will be reset. // T indicates the incoming element type, and the outgoing element type is []T. type Batch[T any] struct { @@ -17,14 +17,16 @@ type Batch[T any] struct { timeInterval time.Duration in chan any out chan any + buffer []T } // Verify Batch satisfies the Flow interface. var _ streams.Flow = (*Batch[any])(nil) -// NewBatch returns a new Batch operator using the specified maximum batch size and the -// time interval. +// NewBatch returns a new Batch operator using the specified maximum batch size and +// the time interval. // T specifies the incoming element type, and the outgoing element type is []T. +// // NewBatch will panic if the maxBatchSize argument is not positive. func NewBatch[T any](maxBatchSize int, timeInterval time.Duration) *Batch[T] { if maxBatchSize < 1 { @@ -35,7 +37,10 @@ func NewBatch[T any](maxBatchSize int, timeInterval time.Duration) *Batch[T] { timeInterval: timeInterval, in: make(chan any), out: make(chan any), + buffer: make([]T, 0, maxBatchSize), } + + // start stream processing go batchFlow.batchStream() return batchFlow @@ -76,34 +81,37 @@ func (b *Batch[T]) batchStream() { ticker := time.NewTicker(b.timeInterval) defer ticker.Stop() - batch := make([]T, 0, b.maxBatchSize) for { select { case element, ok := <-b.in: if ok { - batch = append(batch, element.(T)) + b.buffer = append(b.buffer, element.(T)) // dispatch the batch if the maximum batch size has been reached - if len(batch) >= b.maxBatchSize { - b.out <- batch - batch = make([]T, 0, b.maxBatchSize) + if len(b.buffer) >= b.maxBatchSize { + b.flush() } // reset the ticker ticker.Reset(b.timeInterval) } else { // send the available buffer elements as a new batch, close the // output channel and return - if len(batch) > 0 { - b.out <- batch + if len(b.buffer) > 0 { + b.flush() } close(b.out) return } case <-ticker.C: // timeout; dispatch and reset the buffer - if len(batch) > 0 { - b.out <- batch - batch = make([]T, 0, b.maxBatchSize) + if len(b.buffer) > 0 { + b.flush() } } } } + +// flush sends the elements in the buffer downstream and resets the buffer. +func (b *Batch[T]) flush() { + b.out <- b.buffer + b.buffer = make([]T, 0, b.maxBatchSize) +} diff --git a/flow/batch_test.go b/flow/batch_test.go index db2cb34..29cebf9 100644 --- a/flow/batch_test.go +++ b/flow/batch_test.go @@ -31,14 +31,10 @@ func TestBatch(t *testing.T) { go func() { source. Via(batch). - Via(flow.NewMap(retransmitStringSlice, 1)). // test generic return type To(sink) }() - var outputValues [][]string - for e := range sink.Out { - outputValues = append(outputValues, e.([]string)) - } + outputValues := readSlice[[]string](sink.Out) fmt.Println(outputValues) assert.Equal(t, 3, len(outputValues)) // [[a b c d] [e f g] [h]] @@ -48,7 +44,41 @@ func TestBatch(t *testing.T) { assert.Equal(t, []string{"h"}, outputValues[2]) } -func TestBatchInvalidArguments(t *testing.T) { +func TestBatch_Ptr(t *testing.T) { + in := make(chan any) + out := make(chan any) + + source := ext.NewChanSource(in) + batch := flow.NewBatch[*string](4, 40*time.Millisecond) + sink := ext.NewChanSink(out) + assert.NotEqual(t, batch.Out(), nil) + + inputValues := ptrSlice([]string{"a", "b", "c", "d", "e", "f", "g"}) + go func() { + for _, e := range inputValues { + ingestDeferred(e, in, 5*time.Millisecond) + } + }() + go ingestDeferred(ptr("h"), in, 90*time.Millisecond) + go closeDeferred(in, 100*time.Millisecond) + + go func() { + source. + Via(batch). + To(sink) + }() + + outputValues := readSlice[[]*string](sink.Out) + fmt.Println(outputValues) + + assert.Equal(t, 3, len(outputValues)) // [[a b c d] [e f g] [h]] + + assert.Equal(t, ptrSlice([]string{"a", "b", "c", "d"}), outputValues[0]) + assert.Equal(t, ptrSlice([]string{"e", "f", "g"}), outputValues[1]) + assert.Equal(t, ptrSlice([]string{"h"}), outputValues[2]) +} + +func TestBatch_InvalidArguments(t *testing.T) { assert.Panics(t, func() { flow.NewBatch[string](0, time.Second) }) diff --git a/flow/filter_test.go b/flow/filter_test.go new file mode 100644 index 0000000..dcc7010 --- /dev/null +++ b/flow/filter_test.go @@ -0,0 +1,72 @@ +package flow_test + +import ( + "testing" + + "github.com/reugn/go-streams" + ext "github.com/reugn/go-streams/extension" + "github.com/reugn/go-streams/flow" + "github.com/reugn/go-streams/internal/assert" +) + +func TestFilter(t *testing.T) { + tests := []struct { + name string + filterFlow streams.Flow + ptr bool + }{ + { + name: "values", + filterFlow: flow.NewFilter(func(e int) bool { + return e%2 != 0 + }, 1), + ptr: false, + }, + { + name: "pointers", + filterFlow: flow.NewFilter(func(e *int) bool { + return *e%2 != 0 + }, 1), + ptr: true, + }, + } + input := []int{1, 2, 3, 4, 5} + expected := []int{1, 3, 5} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := make(chan any, 5) + out := make(chan any, 5) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + + if tt.ptr { + ingestSlice(ptrSlice(input), in) + } else { + ingestSlice(input, in) + } + close(in) + + source. + Via(tt.filterFlow). + To(sink) + + if tt.ptr { + output := readSlicePtr[int](out) + assert.Equal(t, ptrSlice(expected), output) + } else { + output := readSlice[int](out) + assert.Equal(t, expected, output) + } + }) + } +} + +func TestFilter_NonPositiveParallelism(t *testing.T) { + assert.Panics(t, func() { + flow.NewFilter(filterNotContainsA, 0) + }) + assert.Panics(t, func() { + flow.NewFilter(filterNotContainsA, -1) + }) +} diff --git a/flow/flat_map_test.go b/flow/flat_map_test.go new file mode 100644 index 0000000..666d882 --- /dev/null +++ b/flow/flat_map_test.go @@ -0,0 +1,94 @@ +package flow_test + +import ( + "strings" + "testing" + + "github.com/reugn/go-streams" + ext "github.com/reugn/go-streams/extension" + "github.com/reugn/go-streams/flow" + "github.com/reugn/go-streams/internal/assert" +) + +func TestFlatMap(t *testing.T) { + tests := []struct { + name string + flatMapFlow streams.Flow + inPtr bool + outPtr bool + }{ + { + name: "val-val", + inPtr: false, + flatMapFlow: flow.NewFlatMap(func(in string) []string { + return []string{in, strings.ToUpper(in)} + }, 1), + outPtr: false, + }, + { + name: "ptr-val", + inPtr: true, + flatMapFlow: flow.NewFlatMap(func(in *string) []string { + return []string{*in, strings.ToUpper(*in)} + }, 1), + outPtr: false, + }, + { + name: "ptr-ptr", + inPtr: true, + flatMapFlow: flow.NewFlatMap(func(in *string) []*string { + upper := strings.ToUpper(*in) + return []*string{in, &upper} + }, 1), + outPtr: true, + }, + { + name: "val-ptr", + inPtr: false, + flatMapFlow: flow.NewFlatMap(func(in string) []*string { + upper := strings.ToUpper(in) + return []*string{&in, &upper} + }, 1), + outPtr: true, + }, + } + input := []string{"a", "b", "c"} + expected := []string{"a", "A", "b", "B", "c", "C"} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := make(chan any, 3) + out := make(chan any, 6) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + + if tt.inPtr { + ingestSlice(ptrSlice(input), in) + } else { + ingestSlice(input, in) + } + close(in) + + source. + Via(tt.flatMapFlow). + To(sink) + + if tt.outPtr { + output := readSlicePtr[string](out) + assert.Equal(t, ptrSlice(expected), output) + } else { + output := readSlice[string](out) + assert.Equal(t, expected, output) + } + }) + } +} + +func TestFlatMap_NonPositiveParallelism(t *testing.T) { + assert.Panics(t, func() { + flow.NewFlatMap(addAsterisk, 0) + }) + assert.Panics(t, func() { + flow.NewFlatMap(addAsterisk, -1) + }) +} diff --git a/flow/flow_test.go b/flow/flow_test.go index 4f2ab71..9271575 100644 --- a/flow/flow_test.go +++ b/flow/flow_test.go @@ -7,11 +7,36 @@ import ( "testing" "time" + "github.com/reugn/go-streams" ext "github.com/reugn/go-streams/extension" "github.com/reugn/go-streams/flow" "github.com/reugn/go-streams/internal/assert" ) +func ptr[T any](value T) *T { + return &value +} + +func ptrSlice[T any](slice []T) []*T { + result := make([]*T, len(slice)) + for i, e := range slice { + result[i] = ptr(e) + } + return result +} + +func ptrInnerSlice[T any](slice [][]T) [][]*T { + outer := make([][]*T, len(slice)) + for i, s := range slice { + inner := make([]*T, len(s)) + for j, e := range s { + inner[j] = ptr(e) + } + outer[i] = inner + } + return outer +} + var addAsterisk = func(in string) []string { resultSlice := make([]string, 2) resultSlice[0] = in + "*" @@ -23,14 +48,6 @@ var filterNotContainsA = func(in string) bool { return !strings.ContainsAny(in, "aA") } -var reduceSum = func(a int, b int) int { - return a + b -} - -var retransmitStringSlice = func(in []string) []string { - return in -} - var mtx sync.Mutex func ingestSlice[T any](source []T, in chan any) { @@ -55,6 +72,22 @@ func closeDeferred[T any](in chan T, wait time.Duration) { close(in) } +func readSlice[T any](ch <-chan any) []T { + var result []T + for e := range ch { + result = append(result, e.(T)) + } + return result +} + +func readSlicePtr[T any](ch <-chan any) []*T { + var result []*T + for e := range ch { + result = append(result, e.(*T)) + } + return result +} + func TestComplexFlow(t *testing.T) { in := make(chan any) out := make(chan any) @@ -83,12 +116,9 @@ func TestComplexFlow(t *testing.T) { To(sink) }() - var outputValues []string - for e := range sink.Out { - outputValues = append(outputValues, e.(string)) - } - + outputValues := readSlice[string](sink.Out) expectedValues := []string{"B*", "B**", "C*", "C**"} + assert.Equal(t, expectedValues, outputValues) } @@ -111,13 +141,45 @@ func TestSplitFlow(t *testing.T) { flow.Merge(split[0], split[1]). To(sink) + outputValues := readSlice[string](sink.Out) + sort.Strings(outputValues) + expectedValues := []string{"A", "B", "C"} + + assert.Equal(t, expectedValues, outputValues) +} + +func TestSplitFlow_Ptr(t *testing.T) { + in := make(chan any, 3) + out := make(chan any, 3) + + source := ext.NewChanSource(in) + toUpperMapFlow := flow.NewMap(func(s *string) *string { + upper := strings.ToUpper(*s) + return &upper + }, 1) + sink := ext.NewChanSink(out) + + inputValues := ptrSlice([]string{"a", "b", "c"}) + ingestSlice(inputValues, in) + close(in) + + split := flow.Split( + source.Via(toUpperMapFlow), + func(in *string) bool { + return !strings.ContainsAny(*in, "aA") + }) + + flow.Merge(split[0], split[1]). + To(sink) + var outputValues []string for e := range sink.Out { - outputValues = append(outputValues, e.(string)) + v := e.(*string) + outputValues = append(outputValues, *v) } sort.Strings(outputValues) - expectedValues := []string{"A", "B", "C"} + assert.Equal(t, expectedValues, outputValues) } @@ -144,13 +206,10 @@ func TestFanOutFlow(t *testing.T) { To(sink) }() - var outputValues []string - for e := range sink.Out { - outputValues = append(outputValues, e.(string)) - } + outputValues := readSlice[string](sink.Out) sort.Strings(outputValues) - expectedValues := []string{"B", "B", "C", "C"} + assert.Equal(t, expectedValues, outputValues) } @@ -177,65 +236,58 @@ func TestRoundRobinFlow(t *testing.T) { To(sink) }() - var outputValues []string - for e := range sink.Out { - outputValues = append(outputValues, e.(string)) - } + outputValues := readSlice[string](sink.Out) sort.Strings(outputValues) - expectedValues := []string{"B", "C"} - assert.Equal(t, expectedValues, outputValues) -} -func TestReduceFlow(t *testing.T) { - in := make(chan any, 5) - out := make(chan any, 5) - - source := ext.NewChanSource(in) - reduceFlow := flow.NewReduce(reduceSum) - sink := ext.NewChanSink(out) - - inputValues := []int{1, 2, 3, 4, 5} - ingestSlice(inputValues, in) - close(in) - - source. - Via(reduceFlow). - Via(flow.NewPassThrough()). - To(sink) - - var outputValues []int - for e := range sink.Out { - outputValues = append(outputValues, e.(int)) - } - - expectedValues := []int{1, 3, 6, 10, 15} assert.Equal(t, expectedValues, outputValues) } -func TestFilterNonPositiveParallelism(t *testing.T) { - assert.Panics(t, func() { - flow.NewFilter(filterNotContainsA, 0) - }) - assert.Panics(t, func() { - flow.NewFilter(filterNotContainsA, -1) - }) -} - -func TestFlatMapNonPositiveParallelism(t *testing.T) { - assert.Panics(t, func() { - flow.NewFlatMap(addAsterisk, 0) - }) - assert.Panics(t, func() { - flow.NewFlatMap(addAsterisk, -1) - }) -} +func TestFlatten(t *testing.T) { + tests := []struct { + name string + flattenFlow streams.Flow + ptr bool + }{ + { + name: "values", + flattenFlow: flow.Flatten[int](1), + ptr: false, + }, + { + name: "pointers", + flattenFlow: flow.Flatten[*int](1), + ptr: true, + }, + } + input := [][]int{{1, 2, 3}, {4, 5}} + expected := []int{1, 2, 3, 4, 5} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := make(chan any, 5) + out := make(chan any, 5) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + + if tt.ptr { + ingestSlice(ptrInnerSlice(input), in) + } else { + ingestSlice(input, in) + } + close(in) -func TestMapNonPositiveParallelism(t *testing.T) { - assert.Panics(t, func() { - flow.NewMap(strings.ToUpper, 0) - }) - assert.Panics(t, func() { - flow.NewMap(strings.ToUpper, -1) - }) + source. + Via(tt.flattenFlow). + To(sink) + + if tt.ptr { + output := readSlicePtr[int](out) + assert.Equal(t, ptrSlice(expected), output) + } else { + output := readSlice[int](out) + assert.Equal(t, expected, output) + } + }) + } } diff --git a/flow/map.go b/flow/map.go index 8f91df3..9e6667e 100644 --- a/flow/map.go +++ b/flow/map.go @@ -43,6 +43,7 @@ func NewMap[T, R any](mapFunction MapFunction[T, R], parallelism int) *Map[T, R] parallelism: parallelism, } go mapFlow.doStream() + return mapFlow } diff --git a/flow/map_test.go b/flow/map_test.go new file mode 100644 index 0000000..c2623ec --- /dev/null +++ b/flow/map_test.go @@ -0,0 +1,92 @@ +package flow_test + +import ( + "strings" + "testing" + + "github.com/reugn/go-streams" + ext "github.com/reugn/go-streams/extension" + "github.com/reugn/go-streams/flow" + "github.com/reugn/go-streams/internal/assert" +) + +func TestMap(t *testing.T) { + tests := []struct { + name string + mapFlow streams.Flow + inPtr bool + outPtr bool + }{ + { + name: "val-val", + inPtr: false, + mapFlow: flow.NewMap(strings.ToUpper, 1), + outPtr: false, + }, + { + name: "ptr-val", + inPtr: true, + mapFlow: flow.NewMap(func(in *string) string { + return strings.ToUpper(*in) + }, 1), + outPtr: false, + }, + { + name: "ptr-ptr", + inPtr: true, + mapFlow: flow.NewMap(func(in *string) *string { + result := strings.ToUpper(*in) + return &result + }, 1), + outPtr: true, + }, + { + name: "val-ptr", + inPtr: false, + mapFlow: flow.NewMap(func(in string) *string { + result := strings.ToUpper(in) + return &result + }, 1), + outPtr: true, + }, + } + input := []string{"a", "b", "c"} + expected := []string{"A", "B", "C"} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := make(chan any, 3) + out := make(chan any, 3) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + + if tt.inPtr { + ingestSlice(ptrSlice(input), in) + } else { + ingestSlice(input, in) + } + close(in) + + source. + Via(tt.mapFlow). + To(sink) + + if tt.outPtr { + output := readSlicePtr[string](out) + assert.Equal(t, ptrSlice(expected), output) + } else { + output := readSlice[string](out) + assert.Equal(t, expected, output) + } + }) + } +} + +func TestMap_NonPositiveParallelism(t *testing.T) { + assert.Panics(t, func() { + flow.NewMap(strings.ToUpper, 0) + }) + assert.Panics(t, func() { + flow.NewMap(strings.ToUpper, -1) + }) +} diff --git a/flow/reduce.go b/flow/reduce.go index 7bf030e..276d0e3 100644 --- a/flow/reduce.go +++ b/flow/reduce.go @@ -36,6 +36,7 @@ func NewReduce[T any](reduceFunction ReduceFunction[T]) *Reduce[T] { out: make(chan any), } go reduce.doStream() + return reduce } @@ -72,7 +73,9 @@ func (r *Reduce[T]) doStream() { if r.lastReduced == nil { r.lastReduced = element } else { - r.lastReduced = r.reduceFunction(r.lastReduced.(T), element.(T)) + r.lastReduced = r.reduceFunction( + r.lastReduced.(T), + element.(T)) } r.out <- r.lastReduced } diff --git a/flow/reduce_test.go b/flow/reduce_test.go new file mode 100644 index 0000000..57b9e8a --- /dev/null +++ b/flow/reduce_test.go @@ -0,0 +1,64 @@ +package flow_test + +import ( + "testing" + + "github.com/reugn/go-streams" + ext "github.com/reugn/go-streams/extension" + "github.com/reugn/go-streams/flow" + "github.com/reugn/go-streams/internal/assert" +) + +func TestReduce(t *testing.T) { + tests := []struct { + name string + reduceFlow streams.Flow + ptr bool + }{ + { + name: "values", + reduceFlow: flow.NewReduce(func(a int, b int) int { + return a + b + }), + ptr: false, + }, + { + name: "pointers", + reduceFlow: flow.NewReduce(func(a *int, b *int) *int { + result := *a + *b + return &result + }), + ptr: true, + }, + } + input := []int{1, 2, 3, 4, 5} + expected := []int{1, 3, 6, 10, 15} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := make(chan any, 5) + out := make(chan any, 5) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + + if tt.ptr { + ingestSlice(ptrSlice(input), in) + } else { + ingestSlice(input, in) + } + close(in) + + source. + Via(tt.reduceFlow). + To(sink) + + if tt.ptr { + output := readSlicePtr[int](out) + assert.Equal(t, ptrSlice(expected), output) + } else { + output := readSlice[int](out) + assert.Equal(t, expected, output) + } + }) + } +} diff --git a/flow/session_window_test.go b/flow/session_window_test.go index f7cce58..9783706 100644 --- a/flow/session_window_test.go +++ b/flow/session_window_test.go @@ -28,14 +28,10 @@ func TestSessionWindow(t *testing.T) { go func() { source. Via(sessionWindow). - Via(flow.NewMap(retransmitStringSlice, 1)). // test generic return type To(sink) }() - var outputValues [][]string - for e := range sink.Out { - outputValues = append(outputValues, e.([]string)) - } + outputValues := readSlice[[]string](sink.Out) fmt.Println(outputValues) assert.Equal(t, 3, len(outputValues)) // [[a b c] [d] [e]] @@ -45,7 +41,7 @@ func TestSessionWindow(t *testing.T) { assert.Equal(t, []string{"e"}, outputValues[2]) } -func TestLongSessionWindow(t *testing.T) { +func TestSessionWindow_Long(t *testing.T) { in := make(chan any) out := make(chan any) @@ -68,10 +64,7 @@ func TestLongSessionWindow(t *testing.T) { To(sink) }() - var outputValues [][]string - for e := range sink.Out { - outputValues = append(outputValues, e.([]string)) - } + outputValues := readSlice[[]string](sink.Out) fmt.Println(outputValues) assert.Equal(t, 2, len(outputValues)) // [[a b c d e f g] [h]] @@ -79,3 +72,34 @@ func TestLongSessionWindow(t *testing.T) { assert.Equal(t, []string{"a", "b", "c", "d", "e", "f", "g"}, outputValues[0]) assert.Equal(t, []string{"h"}, outputValues[1]) } + +func TestSessionWindow_Ptr(t *testing.T) { + in := make(chan any) + out := make(chan any) + + source := ext.NewChanSource(in) + sessionWindow := flow.NewSessionWindow[*string](20 * time.Millisecond) + sink := ext.NewChanSink(out) + assert.NotEqual(t, sessionWindow.Out(), nil) + + inputValues := ptrSlice([]string{"a", "b", "c"}) + go ingestSlice(inputValues, in) + go ingestDeferred(ptr("d"), in, 30*time.Millisecond) + go ingestDeferred(ptr("e"), in, 70*time.Millisecond) + go closeDeferred(in, 100*time.Millisecond) + + go func() { + source. + Via(sessionWindow). + To(sink) + }() + + outputValues := readSlice[[]*string](sink.Out) + fmt.Println(outputValues) + + assert.Equal(t, 3, len(outputValues)) // [[a b c] [d] [e]] + + assert.Equal(t, ptrSlice([]string{"a", "b", "c"}), outputValues[0]) + assert.Equal(t, ptrSlice([]string{"d"}), outputValues[1]) + assert.Equal(t, ptrSlice([]string{"e"}), outputValues[2]) +} diff --git a/flow/sliding_window_test.go b/flow/sliding_window_test.go index 35c5a64..9ea416d 100644 --- a/flow/sliding_window_test.go +++ b/flow/sliding_window_test.go @@ -30,14 +30,10 @@ func TestSlidingWindow(t *testing.T) { go func() { source. Via(slidingWindow). - Via(flow.NewMap(retransmitStringSlice, 1)). // test generic return type To(sink) }() - var outputValues [][]string - for e := range sink.Out { - outputValues = append(outputValues, e.([]string)) - } + outputValues := readSlice[[]string](sink.Out) fmt.Println(outputValues) assert.Equal(t, 6, len(outputValues)) // [[a b c] [b c d] [c d e] [d e f g] [f g] [g]] @@ -55,7 +51,7 @@ type element struct { ts int64 } -func TestSlidingWindowWithExtractor(t *testing.T) { +func TestSlidingWindow_WithExtractor(t *testing.T) { in := make(chan any) out := make(chan any) @@ -94,7 +90,68 @@ func TestSlidingWindowWithExtractor(t *testing.T) { var outputValues [][]string for e := range sink.Out { - outputValues = append(outputValues, stringValues(e.([]element))) + outputValues = append(outputValues, elementValues(e.([]element))) + } + fmt.Println(outputValues) + + assert.Equal(t, 6, len(outputValues)) // [[a b c d] [c d] [e] [e f] [f g] [i h g]] + + assert.Equal(t, []string{"a", "b", "c", "d"}, outputValues[0]) + assert.Equal(t, []string{"c", "d"}, outputValues[1]) + assert.Equal(t, []string{"e"}, outputValues[2]) + assert.Equal(t, []string{"e", "f"}, outputValues[3]) + assert.Equal(t, []string{"f", "g"}, outputValues[4]) + assert.Equal(t, []string{"i", "h", "g"}, outputValues[5]) +} + +func elementValues(elements []element) []string { + values := make([]string, len(elements)) + for i, e := range elements { + values[i] = e.value + } + return values +} + +func TestSlidingWindow_WithExtractorPtr(t *testing.T) { + in := make(chan any) + out := make(chan any) + + source := ext.NewChanSource(in) + slidingWindow := flow.NewSlidingWindowWithExtractor( + 50*time.Millisecond, + 20*time.Millisecond, + func(e *element) int64 { + return e.ts + }) + sink := ext.NewChanSink(out) + + now := time.Now() + inputValues := []*element{ + {"c", now.Add(29 * time.Millisecond).UnixNano()}, + {"a", now.Add(2 * time.Millisecond).UnixNano()}, + {"b", now.Add(17 * time.Millisecond).UnixNano()}, + {"d", now.Add(35 * time.Millisecond).UnixNano()}, + {"f", now.Add(93 * time.Millisecond).UnixNano()}, + {"e", now.Add(77 * time.Millisecond).UnixNano()}, + {"g", now.Add(120 * time.Millisecond).UnixNano()}, + } + go ingestSlice(inputValues, in) + go closeDeferred(in, 250*time.Millisecond) + // send some out-of-order events + go ingestDeferred(&element{"h", now.Add(5 * time.Millisecond).UnixNano()}, + in, 145*time.Millisecond) + go ingestDeferred(&element{"i", now.Add(3 * time.Millisecond).UnixNano()}, + in, 145*time.Millisecond) + + go func() { + source. + Via(slidingWindow). + To(sink) + }() + + var outputValues [][]string + for e := range sink.Out { + outputValues = append(outputValues, elementValuesPtr(e.([]*element))) } fmt.Println(outputValues) @@ -108,7 +165,7 @@ func TestSlidingWindowWithExtractor(t *testing.T) { assert.Equal(t, []string{"i", "h", "g"}, outputValues[5]) } -func stringValues(elements []element) []string { +func elementValuesPtr(elements []*element) []string { values := make([]string, len(elements)) for i, e := range elements { values[i] = e.value @@ -116,7 +173,7 @@ func stringValues(elements []element) []string { return values } -func TestSlidingWindowInvalidArguments(t *testing.T) { +func TestSlidingWindow_InvalidArguments(t *testing.T) { assert.Panics(t, func() { flow.NewSlidingWindow[string](10*time.Millisecond, 20*time.Millisecond) }) diff --git a/flow/tumbling_window_test.go b/flow/tumbling_window_test.go index fa13f6b..be05462 100644 --- a/flow/tumbling_window_test.go +++ b/flow/tumbling_window_test.go @@ -30,14 +30,10 @@ func TestTumblingWindow(t *testing.T) { go func() { source. Via(tumblingWindow). - Via(flow.NewMap(retransmitStringSlice, 1)). // test generic return type To(sink) }() - var outputValues [][]string - for e := range sink.Out { - outputValues = append(outputValues, e.([]string)) - } + outputValues := readSlice[[]string](sink.Out) fmt.Println(outputValues) assert.Equal(t, 3, len(outputValues)) // [[a b c] [d e f] [g]] @@ -46,3 +42,36 @@ func TestTumblingWindow(t *testing.T) { assert.Equal(t, []string{"d", "e", "f"}, outputValues[1]) assert.Equal(t, []string{"g"}, outputValues[2]) } + +func TestTumblingWindow_Ptr(t *testing.T) { + in := make(chan any) + out := make(chan any) + + source := ext.NewChanSource(in) + tumblingWindow := flow.NewTumblingWindow[*string](50 * time.Millisecond) + sink := ext.NewChanSink(out) + assert.NotEqual(t, tumblingWindow.Out(), nil) + + go func() { + inputValues := ptrSlice([]string{"a", "b", "c", "d", "e", "f", "g"}) + for _, v := range inputValues { + ingestDeferred(v, in, 15*time.Millisecond) + } + closeDeferred(in, 160*time.Millisecond) + }() + + go func() { + source. + Via(tumblingWindow). + To(sink) + }() + + outputValues := readSlice[[]*string](sink.Out) + fmt.Println(outputValues) + + assert.Equal(t, 3, len(outputValues)) // [[a b c] [d e f] [g]] + + assert.Equal(t, ptrSlice([]string{"a", "b", "c"}), outputValues[0]) + assert.Equal(t, ptrSlice([]string{"d", "e", "f"}), outputValues[1]) + assert.Equal(t, ptrSlice([]string{"g"}), outputValues[2]) +} diff --git a/flow/util.go b/flow/util.go index f2c97bf..350b123 100644 --- a/flow/util.go +++ b/flow/util.go @@ -31,6 +31,7 @@ func Split[T any](outlet streams.Outlet, predicate func(T) bool) [2]streams.Flow condFalse.In() <- element } } + close(condTrue.In()) close(condFalse.In()) }() @@ -48,12 +49,12 @@ func FanOut(outlet streams.Outlet, magnitude int) []streams.Flow { go func() { for element := range outlet.Out() { - for _, socket := range out { - socket.In() <- element + for _, flow := range out { + flow.In() <- element } } - for i := 0; i < magnitude; i++ { - close(out[i].In()) + for _, flow := range out { + close(flow.In()) } }() @@ -78,6 +79,7 @@ func RoundRobin(outlet streams.Outlet, magnitude int) []streams.Flow { } // Merge merges multiple flows into a single flow. +// When all specified outlets are closed, the resulting flow will close. func Merge(outlets ...streams.Flow) streams.Flow { merged := NewPassThrough() var wg sync.WaitGroup @@ -93,10 +95,10 @@ func Merge(outlets ...streams.Flow) streams.Flow { } // close the in channel on the last outlet close. - go func(wg *sync.WaitGroup) { + go func() { wg.Wait() close(merged.In()) - }(&wg) + }() return merged }