Skip to content

Commit

Permalink
fix(taskgroup): failed tasks not metered
Browse files Browse the repository at this point in the history
  • Loading branch information
alitto committed Nov 13, 2024
1 parent 543ed3a commit 688b02c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
6 changes: 4 additions & 2 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) {

g.taskWaitGroup.Add(1)

err := g.pool.Go(func() {
err := g.pool.dispatcher.Write(func() error {
defer g.taskWaitGroup.Done()

// Check if the context has been cancelled to prevent running tasks that are not needed
if err := g.future.Context().Err(); err != nil {
g.futureResolver(index, &result[O]{
Err: err,
}, err)
return
return err
}

// Invoke the task
Expand All @@ -122,6 +122,8 @@ func (g *abstractTaskGroup[T, E, O]) submit(task any) {
Output: output,
Err: err,
}, err)

return err
})

if err != nil {
Expand Down
55 changes: 55 additions & 0 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,58 @@ func TestTaskGroupDone(t *testing.T) {

assert.Equal(t, int32(5), executedCount.Load())
}

func TestTaskGroupMetrics(t *testing.T) {
pool := NewPool(1)

group := pool.NewGroup()

for i := 0; i < 9; i++ {
group.Submit(func() {
time.Sleep(1 * time.Millisecond)
})
}

// The last task will return an error
sampleErr := errors.New("sample error")
group.SubmitErr(func() error {
time.Sleep(1 * time.Millisecond)
return sampleErr
})

err := group.Wait()

time.Sleep(10 * time.Millisecond)

assert.Equal(t, sampleErr, err)
assert.Equal(t, uint64(10), pool.SubmittedTasks())
assert.Equal(t, uint64(9), pool.SuccessfulTasks())
assert.Equal(t, uint64(1), pool.FailedTasks())
}

func TestTaskGroupMetricsWithCancelledContext(t *testing.T) {
pool := NewPool(1)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

group := pool.NewGroupContext(ctx)

for i := 0; i < 10; i++ {
i := i
group.Submit(func() {
time.Sleep(20 * time.Millisecond)
if i == 4 {
cancel()
}
})
}
err := group.Wait()

time.Sleep(10 * time.Millisecond)

assert.Equal(t, err, context.Canceled)
assert.Equal(t, uint64(10), pool.SubmittedTasks())
assert.Equal(t, uint64(5), pool.SuccessfulTasks())
assert.Equal(t, uint64(5), pool.FailedTasks())
}

0 comments on commit 688b02c

Please sign in to comment.