Skip to content

Commit

Permalink
Add context cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
aertje committed Nov 2, 2024
1 parent 5c1e40d commit ccd72b3
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 45 deletions.
65 changes: 43 additions & 22 deletions nice/nice.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package nice

import (
"context"
"runtime"
"sync"

"github.com/aertje/gonice/queue"
)

type entry struct {
priority int
waitChan chan func()
priority int
waitChan chan<- func()
cancelChan <-chan struct{}
}

type Scheduler struct {
Expand Down Expand Up @@ -52,25 +54,32 @@ func (s *Scheduler) assessEntries() {
s.lock.Lock()
defer s.lock.Unlock()

if s.concurrency >= s.maxConcurrency {
return
}
for {
if s.concurrency >= s.maxConcurrency {
return
}

entry, has := s.entries.Pop()
if !has {
return
}
entry, has := s.entries.Pop()
if !has {
return
}

fnDone := func() {
close(entry.waitChan)
s.lock.Lock()
s.concurrency--
s.lock.Unlock()
s.done <- entry
}
fnDone := func() {
s.lock.Lock()
s.concurrency--
s.lock.Unlock()
s.done <- entry
}

entry.waitChan <- fnDone
s.concurrency++
select {
case <-entry.cancelChan:
s.done <- entry
default:
entry.waitChan <- fnDone
close(entry.waitChan)
s.concurrency++
}
}
}

func (s *Scheduler) schedule() {
Expand All @@ -89,15 +98,27 @@ func (s *Scheduler) schedule() {
}()
}

func (s *Scheduler) Wait(priority int) chan func() {
func (s *Scheduler) WaitContext(ctx context.Context, priority int) func() {
waitChan := make(chan func())
cancelChan := make(chan struct{})

entry := entry{
priority: priority,
waitChan: waitChan,
priority: priority,
waitChan: waitChan,
cancelChan: cancelChan,
}

s.incoming <- entry

return waitChan
select {
case <-ctx.Done():
close(cancelChan)
return func() {}
case fnDone := <-waitChan:
return fnDone
}
}

func (s *Scheduler) Wait(priority int) func() {
return s.WaitContext(context.Background(), priority)
}
90 changes: 67 additions & 23 deletions nice/nice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,84 @@ import (
func TestSimple(t *testing.T) {
s := nice.NewScheduler()

fnDone := <-s.Wait(1)
fnDone := s.Wait(1)
fnDone()
}

func TestOrderSingleConcurrency(t *testing.T) {
s := nice.NewScheduler(nice.WithMaxConcurrency(1))
func TestOrderConcurrency(t *testing.T) {
for _, tc := range []struct {
name string
maxConcurrency int
totalTasks int
expectedResult []int
}{
{
name: "no concurrency",
maxConcurrency: 1,
totalTasks: 10,
expectedResult: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
{
name: "concurrency 2",
maxConcurrency: 2,
totalTasks: 10,
expectedResult: []int{1, 1, 2, 2, 3, 3, 4, 4, 5, 5},
},
{
name: "concurrency 8",
maxConcurrency: 8,
totalTasks: 16,
expectedResult: []int{1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2},
},
} {
t.Run(tc.name, func(t *testing.T) {
results := testOrderForConcurrency(tc.maxConcurrency, tc.totalTasks)
assert.Equal(t, tc.expectedResult, results)
})
}
}

waitChan := make(chan struct{})
go func() {
fnDone := <-s.Wait(0)
waitChan <- struct{}{}
time.Sleep(10 * time.Millisecond)
fnDone()
}()
func testOrderForConcurrency(maxConcurrency int, totalTasks int) []int {
s := nice.NewScheduler(nice.WithMaxConcurrency(maxConcurrency))

results := make([]int, 0)
var lock sync.Mutex
var wg sync.WaitGroup

<-waitChan
for i := 10; i > 0; i-- {
wg.Add(1)
// Saturate the scheduler otherwise subsequent tasks will be executed
// immediately in undefined order.
for i := 0; i < maxConcurrency; i++ {
go func() {
defer wg.Done()
fnDone := <-s.Wait(i)
fnDone := s.Wait(0)
time.Sleep(10 * time.Millisecond)
lock.Lock()
results = append(results, i)
defer lock.Unlock()
fnDone()
}()
}

// Give the scheduler some time to start the goroutines.
time.Sleep(1 * time.Millisecond)

results := make([]int, 0)
var lock sync.Mutex
var wg sync.WaitGroup

for i := totalTasks / maxConcurrency; i > 0; i-- {
for j := 0; j < maxConcurrency; j++ {
priority := i
wg.Add(1)
go func() {
defer wg.Done()

fnDone := s.Wait(priority)

time.Sleep(10 * time.Millisecond)

lock.Lock()
results = append(results, i)
defer lock.Unlock()

fnDone()
}()
}
}

wg.Wait()

assert.Equal(t, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, results)
return results
}

0 comments on commit ccd72b3

Please sign in to comment.