From ccd72b318d909a8b100f8a81f167e1edcc2eecff Mon Sep 17 00:00:00 2001 From: Aert van de Hulsbeek Date: Sat, 2 Nov 2024 22:42:33 +1100 Subject: [PATCH] Add context cancellation --- nice/nice.go | 65 ++++++++++++++++++++++------------ nice/nice_test.go | 90 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 110 insertions(+), 45 deletions(-) diff --git a/nice/nice.go b/nice/nice.go index 7a455c4..cdd429a 100644 --- a/nice/nice.go +++ b/nice/nice.go @@ -1,6 +1,7 @@ package nice import ( + "context" "runtime" "sync" @@ -8,8 +9,9 @@ import ( ) type entry struct { - priority int - waitChan chan func() + priority int + waitChan chan<- func() + cancelChan <-chan struct{} } type Scheduler struct { @@ -52,25 +54,32 @@ func (s *Scheduler) assessEntries() { s.lock.Lock() defer s.lock.Unlock() - if s.concurrency >= s.maxConcurrency { - return - } + for { + if s.concurrency >= s.maxConcurrency { + return + } - entry, has := s.entries.Pop() - if !has { - return - } + entry, has := s.entries.Pop() + if !has { + return + } - fnDone := func() { - close(entry.waitChan) - s.lock.Lock() - s.concurrency-- - s.lock.Unlock() - s.done <- entry - } + fnDone := func() { + s.lock.Lock() + s.concurrency-- + s.lock.Unlock() + s.done <- entry + } - entry.waitChan <- fnDone - s.concurrency++ + select { + case <-entry.cancelChan: + s.done <- entry + default: + entry.waitChan <- fnDone + close(entry.waitChan) + s.concurrency++ + } + } } func (s *Scheduler) schedule() { @@ -89,15 +98,27 @@ func (s *Scheduler) schedule() { }() } -func (s *Scheduler) Wait(priority int) chan func() { +func (s *Scheduler) WaitContext(ctx context.Context, priority int) func() { waitChan := make(chan func()) + cancelChan := make(chan struct{}) entry := entry{ - priority: priority, - waitChan: waitChan, + priority: priority, + waitChan: waitChan, + cancelChan: cancelChan, } s.incoming <- entry - return waitChan + select { + case <-ctx.Done(): + close(cancelChan) + return func() {} + case fnDone := <-waitChan: + return fnDone + } +} + +func (s *Scheduler) Wait(priority int) func() { + return s.WaitContext(context.Background(), priority) } diff --git a/nice/nice_test.go b/nice/nice_test.go index 8a97661..0f20de7 100644 --- a/nice/nice_test.go +++ b/nice/nice_test.go @@ -12,40 +12,84 @@ import ( func TestSimple(t *testing.T) { s := nice.NewScheduler() - fnDone := <-s.Wait(1) + fnDone := s.Wait(1) fnDone() } -func TestOrderSingleConcurrency(t *testing.T) { - s := nice.NewScheduler(nice.WithMaxConcurrency(1)) +func TestOrderConcurrency(t *testing.T) { + for _, tc := range []struct { + name string + maxConcurrency int + totalTasks int + expectedResult []int + }{ + { + name: "no concurrency", + maxConcurrency: 1, + totalTasks: 10, + expectedResult: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + { + name: "concurrency 2", + maxConcurrency: 2, + totalTasks: 10, + expectedResult: []int{1, 1, 2, 2, 3, 3, 4, 4, 5, 5}, + }, + { + name: "concurrency 8", + maxConcurrency: 8, + totalTasks: 16, + expectedResult: []int{1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + results := testOrderForConcurrency(tc.maxConcurrency, tc.totalTasks) + assert.Equal(t, tc.expectedResult, results) + }) + } +} - waitChan := make(chan struct{}) - go func() { - fnDone := <-s.Wait(0) - waitChan <- struct{}{} - time.Sleep(10 * time.Millisecond) - fnDone() - }() +func testOrderForConcurrency(maxConcurrency int, totalTasks int) []int { + s := nice.NewScheduler(nice.WithMaxConcurrency(maxConcurrency)) - results := make([]int, 0) - var lock sync.Mutex - var wg sync.WaitGroup - - <-waitChan - for i := 10; i > 0; i-- { - wg.Add(1) + // Saturate the scheduler otherwise subsequent tasks will be executed + // immediately in undefined order. + for i := 0; i < maxConcurrency; i++ { go func() { - defer wg.Done() - fnDone := <-s.Wait(i) + fnDone := s.Wait(0) time.Sleep(10 * time.Millisecond) - lock.Lock() - results = append(results, i) - defer lock.Unlock() fnDone() }() } + // Give the scheduler some time to start the goroutines. + time.Sleep(1 * time.Millisecond) + + results := make([]int, 0) + var lock sync.Mutex + var wg sync.WaitGroup + + for i := totalTasks / maxConcurrency; i > 0; i-- { + for j := 0; j < maxConcurrency; j++ { + priority := i + wg.Add(1) + go func() { + defer wg.Done() + + fnDone := s.Wait(priority) + + time.Sleep(10 * time.Millisecond) + + lock.Lock() + results = append(results, i) + defer lock.Unlock() + + fnDone() + }() + } + } + wg.Wait() - assert.Equal(t, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, results) + return results }