diff --git a/client.go b/client.go index 7d95491..dde33a3 100644 --- a/client.go +++ b/client.go @@ -43,6 +43,7 @@ type Client struct { token string data protocol.Raw transport transport + disconnectedCh chan struct{} state State subs map[string]*Subscription serverSubs map[string]*serverSub @@ -138,7 +139,9 @@ func newClient(endpoint string, isProtobuf bool, config Config) *Client { } // Queue to run callbacks on. - client.cbQueue = &cbQueue{} + client.cbQueue = &cbQueue{ + closeCh: make(chan struct{}), + } client.cbQueue.cond = sync.NewCond(&client.cbQueue.mu) go client.cbQueue.dispatch() @@ -535,9 +538,15 @@ func (c *Client) moveToClosed() { } c.mu.Lock() + defer c.mu.Unlock() + // At this point connection close was issued, so we wait until the reader goroutine + // finishes its work, after that it's safe to close the callback queue. + if c.disconnectedCh != nil { + <-c.disconnectedCh + } + c.disconnectedCh = nil c.cbQueue.close() c.cbQueue = nil - c.mu.Unlock() } func (c *Client) handleError(err error) { @@ -959,6 +968,7 @@ func (c *Client) startReconnecting() error { disconnectCh := make(chan struct{}) c.receive = make(chan []byte, 64) c.transport = t + c.disconnectedCh = disconnectCh go c.reader(t, disconnectCh) diff --git a/queue.go b/queue.go index 04f79e2..20595e4 100644 --- a/queue.go +++ b/queue.go @@ -11,10 +11,12 @@ import ( // https://github.com/nats-io/nats.go client released under Apache 2.0 // license: see https://github.com/nats-io/nats.go/blob/master/LICENSE. type cbQueue struct { - mu sync.Mutex - cond *sync.Cond - head *asyncCB - tail *asyncCB + mu sync.Mutex + cond *sync.Cond + head *asyncCB + tail *asyncCB + closeCh chan struct{} + closed bool } type asyncCB struct { @@ -43,6 +45,7 @@ func (q *cbQueue) dispatch() { // This signals that the dispatcher has been closed and all // previous callbacks have been dispatched. if curr.fn == nil { + close(q.closeCh) return } curr.fn(time.Since(curr.tm)) @@ -56,13 +59,22 @@ func (q *cbQueue) push(f func(duration time.Duration)) { } // Close signals that async queue must be closed. +// Queue won't accept any more callbacks after that – ignoring them if pushed. func (q *cbQueue) close() { q.pushOrClose(nil, true) + q.waitClose() +} + +func (q *cbQueue) waitClose() { + <-q.closeCh } func (q *cbQueue) pushOrClose(f func(time.Duration), close bool) { q.mu.Lock() defer q.mu.Unlock() + if q.closed { + return + } // Make sure that library is not calling push with nil function, // since this is used to notify the dispatcher that it must stop. if !close && f == nil { @@ -76,6 +88,7 @@ func (q *cbQueue) pushOrClose(f func(time.Duration), close bool) { } q.tail = cb if close { + q.closed = true q.cond.Broadcast() } else { q.cond.Signal() diff --git a/queue_test.go b/queue_test.go new file mode 100644 index 0000000..eee5a11 --- /dev/null +++ b/queue_test.go @@ -0,0 +1,131 @@ +package centrifuge + +import ( + "sync" + "testing" + "time" +) + +func assertTrue(t *testing.T, condition bool, msg string) { + if !condition { + t.Fatalf("Assertion failed: %s", msg) + } +} + +func assertEqual(t *testing.T, expected, actual interface{}, msg string) { + if expected != actual { + t.Fatalf("Assertion failed: %s - expected: %v, got: %v", msg, expected, actual) + } +} + +func newTestQueue() *cbQueue { + q := &cbQueue{ + closeCh: make(chan struct{}), + } + q.cond = sync.NewCond(&q.mu) + return q +} + +func TestCbQueue_PushAndDispatch(t *testing.T) { + q := newTestQueue() + + var wg sync.WaitGroup + wg.Add(1) + + // Start the dispatcher in a separate goroutine. + go q.dispatch() + + startTime := time.Now() + q.push(func(d time.Duration) { + defer wg.Done() + assertTrue(t, d >= 0, "Callback duration should be positive") + }) + + // Wait for the callback to finish. + wg.Wait() + + // Ensure the callback executed quickly. + elapsed := time.Since(startTime) + assertTrue(t, elapsed < 100*time.Millisecond, "Callback should be dispatched immediately") +} + +func TestCbQueue_OrderPreservation(t *testing.T) { + q := newTestQueue() + + // Start the dispatcher in a separate goroutine. + go q.dispatch() + + var results []int + var mu sync.Mutex + expectedResults := []int{1, 2, 3} + + for _, i := range expectedResults { + i := i + q.push(func(d time.Duration) { + mu.Lock() + defer mu.Unlock() + results = append(results, i) + }) + } + + // Allow time for the queue to process. + time.Sleep(100 * time.Millisecond) + + mu.Lock() + defer mu.Unlock() + + for i, r := range results { + assertEqual(t, expectedResults[i], r, "unexpected result") + } +} + +func TestCbQueue_Close(t *testing.T) { + q := newTestQueue() + + go q.dispatch() + + var executed bool + q.push(func(d time.Duration) { + executed = true + }) + + q.close() + + // Ensure the closeCh channel is closed. + select { + case <-q.closeCh: + // Channel was closed as expected. + case <-time.After(1 * time.Second): + t.Fatal("closeCh was not closed after queue close") + } + + assertTrue(t, executed, "Callback should be executed before close") +} + +func TestCbQueue_IgnorePushAfterClose(t *testing.T) { + q := newTestQueue() + go q.dispatch() + q.close() + + var executed bool + q.push(func(d time.Duration) { + executed = true + }) + + // Allow some time to see if the callback is executed. + time.Sleep(100 * time.Millisecond) + + assertTrue(t, !executed, "Callback should not be executed after queue close") +} + +func TestCbQueue_PushNilCallbackPanics(t *testing.T) { + q := newTestQueue() + + defer func() { + if r := recover(); r == nil { + t.Fatal("Expected panic when pushing nil callback with close set to false") + } + }() + + q.pushOrClose(nil, false) +}