diff --git a/rpc/handlers.go b/rpc/handlers.go index 073a06806..4ae56a9c0 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -6,11 +6,11 @@ import ( "encoding/binary" "encoding/json" "fmt" + "iter" "log" "maps" "math" "strings" - stdsync "sync" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/clients/feeder" @@ -100,8 +100,7 @@ type Handler struct { l1Heads *feed.Feed[*core.L1Head] idgen func() uint64 - mu stdsync.Mutex // protects subscriptions. - subscriptions map[uint64]*subscription + subscriptions subscriptions blockTraceCache *lru.Cache[traceCacheKey, []TracedBlockTransaction] @@ -112,6 +111,106 @@ type Handler struct { coreContractABI abi.ABI } +func newSubscriptions() subscriptions { + return subscriptions{ + id2sub: make(map[uint64]*subscription), + event: make(chan event), + valueChan: make(chan value), + valuesChan: make(chan iter.Seq[*subscription]), + } +} + +type subscriptions struct { + id2sub map[uint64]*subscription + event chan event + valueChan chan value + valuesChan chan iter.Seq[*subscription] +} + +type value struct { + sub *subscription + ok bool +} + +type event struct { + id uint64 + sub *subscription + action action +} + +type action uint8 + +const ( + add action = iota + remove + getValue + getValues +) + +func (s subscriptions) add(id uint64, sub *subscription) { + fmt.Println("add") + s.event <- event{id: id, sub: sub, action: add} + fmt.Println("added") +} + +func (s subscriptions) remove(id uint64) { + fmt.Println("remove") + s.event <- event{id: id, sub: nil, action: remove} + fmt.Println("removed") +} + +func (s subscriptions) getValue(id uint64) (*subscription, bool) { + fmt.Println("getValue") + s.event <- event{id: id, action: getValue} + v := <-s.valueChan + fmt.Println("gotValue") + return v.sub, v.ok +} + +func (s subscriptions) getValues() iter.Seq[*subscription] { + fmt.Println("getValues") + s.event <- event{action: getValues} + fmt.Println("gotValues") + return <-s.valuesChan +} + +func (s subscriptions) run(ctx context.Context) { + fmt.Println("run") + go func() { + defer func() { + fmt.Println("defer close valueChan") + close(s.valueChan) + fmt.Println("defer close valuesChan") + close(s.valuesChan) + }() + for { + fmt.Println("select") + select { + case e := <-s.event: + fmt.Println("event") + switch e.action { + case add: + fmt.Println("add in run") + s.id2sub[e.id] = e.sub + case remove: + fmt.Println("remove in run") + delete(s.id2sub, e.id) + case getValue: + fmt.Println("getValue in run") + sub, ok := s.id2sub[e.id] + s.valueChan <- value{sub: sub, ok: ok} + case getValues: + fmt.Println("getValues in run") + s.valuesChan <- maps.Values(s.id2sub) + } + case <-ctx.Done(): + fmt.Println("done") + return + } + } + }() +} + type subscription struct { cancel func() wg conc.WaitGroup @@ -141,7 +240,7 @@ func New(bcReader blockchain.Reader, syncReader sync.Reader, virtualMachine vm.V reorgs: feed.New[*sync.ReorgBlockRange](), pendingTxs: feed.New[[]core.Transaction](), l1Heads: feed.New[*core.L1Head](), - subscriptions: make(map[uint64]*subscription), + subscriptions: newSubscriptions(), blockTraceCache: lru.NewCache[traceCacheKey, []TracedBlockTransaction](traceCacheSize), filterLimit: math.MaxUint, @@ -181,6 +280,7 @@ func (h *Handler) WithGateway(gatewayClient Gateway) *Handler { } func (h *Handler) Run(ctx context.Context) error { + h.subscriptions.run(ctx) newHeadsSub := h.syncReader.SubscribeNewHeads().Subscription reorgsSub := h.syncReader.SubscribeReorg().Subscription pendingTxsSub := h.syncReader.SubscribePendingTxs().Subscription @@ -196,11 +296,7 @@ func (h *Handler) Run(ctx context.Context) error { <-ctx.Done() - h.mu.Lock() - subscriptions := maps.Values(h.subscriptions) - h.mu.Unlock() - - for sub := range subscriptions { + for sub := range h.subscriptions.getValues() { sub.wg.Wait() } diff --git a/rpc/subscriptions_test.go b/rpc/subscriptions_test.go index 6676d520f..469ad06a2 100644 --- a/rpc/subscriptions_test.go +++ b/rpc/subscriptions_test.go @@ -47,101 +47,79 @@ func (fc *fakeConn) Equal(other jsonrpc.Conn) bool { } func TestSubscribeEvents(t *testing.T) { - log := utils.NewNopZapLogger() - - t.Run("Return error if too many keys in filter", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockChain := mocks.NewMockReader(mockCtrl) - mockSyncer := mocks.NewMockSyncReader(mockCtrl) - handler := New(mockChain, mockSyncer, nil, "", log) - - keys := make([][]felt.Felt, 1024+1) - fromAddr := new(felt.Felt).SetBytes([]byte("from_address")) - - serverConn, _ := net.Pipe() - t.Cleanup(func() { - require.NoError(t, serverConn.Close()) - }) - - subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) - - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, nil) - assert.Zero(t, id) - assert.Equal(t, ErrTooManyKeysInFilter, rpcErr) - }) - - t.Run("Return error if called on pending block", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockChain := mocks.NewMockReader(mockCtrl) - mockSyncer := mocks.NewMockSyncReader(mockCtrl) - handler := New(mockChain, mockSyncer, nil, "", log) - - keys := make([][]felt.Felt, 1) - fromAddr := new(felt.Felt).SetBytes([]byte("from_address")) - blockID := &BlockID{Pending: true} - - serverConn, _ := net.Pipe() - t.Cleanup(func() { - require.NoError(t, serverConn.Close()) - }) - - subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) - - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID) - assert.Zero(t, id) - assert.Equal(t, ErrCallOnPending, rpcErr) - }) - - t.Run("Return error if block is too far back", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockChain := mocks.NewMockReader(mockCtrl) - mockSyncer := mocks.NewMockSyncReader(mockCtrl) - handler := New(mockChain, mockSyncer, nil, "", log) - - keys := make([][]felt.Felt, 1) - fromAddr := new(felt.Felt).SetBytes([]byte("from_address")) - blockID := &BlockID{Number: 0} - - serverConn, _ := net.Pipe() - t.Cleanup(func() { - require.NoError(t, serverConn.Close()) - }) + tests := []struct { + name string + keys [][]felt.Felt + blockID *BlockID + mockBehaviour func(mockChain *mocks.MockReader) + expectErr *jsonrpc.Error + }{ + { + name: "Return error if too many keys in filter", + keys: make([][]felt.Felt, 1024+1), + expectErr: ErrTooManyKeysInFilter, + }, + { + name: "Return error if called on pending block", + keys: make([][]felt.Felt, 1), + blockID: &BlockID{Pending: true}, + expectErr: ErrCallOnPending, + }, + { + name: "Return error if block is too far back, head is 1024", + keys: make([][]felt.Felt, 1), + blockID: &BlockID{Number: 0}, + mockBehaviour: func(mockChain *mocks.MockReader) { + mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 1024}, nil) + mockChain.EXPECT().BlockHeaderByNumber(uint64(0)).Return(&core.Header{Number: 0}, nil) + }, + expectErr: ErrTooManyBlocksBack, + }, + { + name: "Return error if block is too far back, head is more than 1024", + keys: make([][]felt.Felt, 1), + blockID: &BlockID{Number: 0}, + mockBehaviour: func(mockChain *mocks.MockReader) { + mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 2024}, nil) + mockChain.EXPECT().BlockHeaderByNumber(uint64(0)).Return(&core.Header{Number: 0}, nil) + }, + expectErr: ErrTooManyBlocksBack, + }, + } - subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) - // Note the end of the window doesn't need to be tested because if requested block number is more than the - // head, a block not found error will be returned. This behaviour has been tested in various other tests, and we - // don't need to test it here again. - t.Run("head is 1024", func(t *testing.T) { - mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 1024}, nil) - mockChain.EXPECT().BlockHeaderByNumber(blockID.Number).Return(&core.Header{Number: 0}, nil) + mockChain := mocks.NewMockReader(mockCtrl) + mockSyncer := mocks.NewMockSyncReader(mockCtrl) + if test.mockBehaviour != nil { + test.mockBehaviour(mockChain) + } + handler := New(mockChain, mockSyncer, nil, "", utils.NewNopZapLogger()) - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID) - assert.Zero(t, id) - assert.Equal(t, ErrTooManyBlocksBack, rpcErr) - }) + serverConn, _ := net.Pipe() + t.Cleanup(func() { + require.NoError(t, serverConn.Close()) + }) - t.Run("head is more than 1024", func(t *testing.T) { - mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 2024}, nil) - mockChain.EXPECT().BlockHeaderByNumber(blockID.Number).Return(&core.Header{Number: 0}, nil) + subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID) + id, rpcErr := handler.SubscribeEvents(subCtx, nil, test.keys, test.blockID) assert.Zero(t, id) - assert.Equal(t, ErrTooManyBlocksBack, rpcErr) + assert.Equal(t, test.expectErr, rpcErr) }) - }) + } +} - n := utils.Ptr(utils.Sepolia) - client := feeder.NewTestClient(t, n) +func TestSubscribeEventsWithClient(t *testing.T) { + log := utils.NewNopZapLogger() + network := utils.Ptr(utils.Sepolia) + client := feeder.NewTestClient(t, network) gw := adaptfeeder.New(client) - b1, err := gw.BlockByNumber(context.Background(), 56377) + block, err := gw.BlockByNumber(context.Background(), 56377) require.NoError(t, err) fromAddr := new(felt.Felt).SetBytes([]byte("some address")) @@ -149,181 +127,87 @@ func TestSubscribeEvents(t *testing.T) { filteredEvents := []*blockchain.FilteredEvent{ { - Event: b1.Receipts[0].Events[0], - BlockNumber: b1.Number, + Event: block.Receipts[0].Events[0], + BlockNumber: block.Number, BlockHash: new(felt.Felt).SetBytes([]byte("b1")), - TransactionHash: b1.Transactions[0].Hash(), + TransactionHash: block.Transactions[0].Hash(), }, { - Event: b1.Receipts[1].Events[0], - BlockNumber: b1.Number + 1, + Event: block.Receipts[1].Events[0], + BlockNumber: block.Number + 1, BlockHash: new(felt.Felt).SetBytes([]byte("b2")), - TransactionHash: b1.Transactions[1].Hash(), + TransactionHash: block.Transactions[1].Hash(), }, } - var emittedEvents []*EmittedEvent - for _, e := range filteredEvents { - emittedEvents = append(emittedEvents, &EmittedEvent{ - Event: &Event{ - From: e.From, - Keys: e.Keys, - Data: e.Data, - }, - BlockHash: e.BlockHash, - BlockNumber: &e.BlockNumber, - TransactionHash: e.TransactionHash, - }) - } - - t.Run("Events from old blocks", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockChain := mocks.NewMockReader(mockCtrl) - mockSyncer := mocks.NewMockSyncReader(mockCtrl) - mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl) - handler := New(mockChain, mockSyncer, nil, "", log) - - mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil) - mockChain.EXPECT().BlockHeaderByNumber(b1.Number).Return(b1.Header, nil) - mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil) - - mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2) - mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(filteredEvents, nil, nil) - mockEventFilterer.EXPECT().Close().AnyTimes() - - serverConn, clientConn := net.Pipe() - t.Cleanup(func() { - require.NoError(t, serverConn.Close()) - require.NoError(t, clientConn.Close()) - }) - - ctx, cancel := context.WithCancel(context.Background()) - subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &BlockID{Number: b1.Number}) - require.Nil(t, rpcErr) - - var marshalledResponses [][]byte - for _, e := range emittedEvents { - resp, err := marshalSubEventsResp(e, id.ID) - require.NoError(t, err) - marshalledResponses = append(marshalledResponses, resp) - } - - for _, m := range marshalledResponses { - got := make([]byte, len(m)) - _, err := clientConn.Read(got) - require.NoError(t, err) - assert.Equal(t, string(m), string(got)) + emittedEvents := make([]*EmittedEvent, len(filteredEvents)) + for i, e := range filteredEvents { + emittedEvents[i] = &EmittedEvent{ + Event: &Event{From: e.From, Keys: e.Keys, Data: e.Data}, + BlockHash: e.BlockHash, BlockNumber: &e.BlockNumber, TransactionHash: e.TransactionHash, } - cancel() - }) - - t.Run("Events when continuation token is not nil", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockChain := mocks.NewMockReader(mockCtrl) - mockSyncer := mocks.NewMockSyncReader(mockCtrl) - mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl) - handler := New(mockChain, mockSyncer, nil, "", log) + } - mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil) - mockChain.EXPECT().BlockHeaderByNumber(b1.Number).Return(b1.Header, nil) - mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil) + tests := []struct { + name string + continuation *blockchain.ContinuationToken + expectedEvents []*blockchain.FilteredEvent + }{ + { + name: "Events from old blocks", + expectedEvents: filteredEvents, + }, + { + name: "Events with continuation token", + continuation: new(blockchain.ContinuationToken), + expectedEvents: filteredEvents[0:1], + }, + { + name: "Events from new blocks", + expectedEvents: filteredEvents[1:], + }, + } - cToken := new(blockchain.ContinuationToken) - mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2) - mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return( - []*blockchain.FilteredEvent{filteredEvents[0]}, cToken, nil) - mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return( - []*blockchain.FilteredEvent{filteredEvents[1]}, nil, nil) - mockEventFilterer.EXPECT().Close().AnyTimes() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) - serverConn, clientConn := net.Pipe() - t.Cleanup(func() { - require.NoError(t, serverConn.Close()) - require.NoError(t, clientConn.Close()) - }) + mockChain := mocks.NewMockReader(mockCtrl) + mockSyncer := mocks.NewMockSyncReader(mockCtrl) + mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl) + handler := New(mockChain, mockSyncer, nil, "", log) - ctx, cancel := context.WithCancel(context.Background()) - subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &BlockID{Number: b1.Number}) - require.Nil(t, rpcErr) + mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: block.Number}, nil) + mockChain.EXPECT().BlockHeaderByNumber(block.Number).Return(block.Header, nil) + mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil) - var marshalledResponses [][]byte - for _, e := range emittedEvents { - resp, err := marshalSubEventsResp(e, id.ID) - require.NoError(t, err) - marshalledResponses = append(marshalledResponses, resp) - } + mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2) + mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return(tt.expectedEvents, tt.continuation, nil).AnyTimes() + mockEventFilterer.EXPECT().Close().AnyTimes() - for _, m := range marshalledResponses { - got := make([]byte, len(m)) - _, err := clientConn.Read(got) - require.NoError(t, err) - assert.Equal(t, string(m), string(got)) - } - cancel() - }) - - t.Run("Events from new blocks", func(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockChain := mocks.NewMockReader(mockCtrl) - mockSyncer := mocks.NewMockSyncReader(mockCtrl) - mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl) - - handler := New(mockChain, mockSyncer, nil, "", log) - headerFeed := feed.New[*core.Header]() - handler.newHeads = headerFeed - - mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil) - mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil) + serverConn, clientConn := net.Pipe() + t.Cleanup(func() { + require.NoError(t, serverConn.Close()) + require.NoError(t, clientConn.Close()) + }) - mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2) - mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return([]*blockchain.FilteredEvent{filteredEvents[0]}, nil, nil) - mockEventFilterer.EXPECT().Close().AnyTimes() + ctx, cancel := context.WithCancel(context.Background()) + subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) + id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &BlockID{Number: block.Number}) + require.Nil(t, rpcErr) - serverConn, clientConn := net.Pipe() - t.Cleanup(func() { - require.NoError(t, serverConn.Close()) - require.NoError(t, clientConn.Close()) + for _, e := range emittedEvents { + resp, err := marshalSubEventsResp(e, id.ID) + require.NoError(t, err) + got := make([]byte, len(resp)) + _, err = clientConn.Read(got) + require.NoError(t, err) + assert.Equal(t, string(resp), string(got)) + } + cancel() }) - - ctx, cancel := context.WithCancel(context.Background()) - subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, nil) - require.Nil(t, rpcErr) - - resp, err := marshalSubEventsResp(emittedEvents[0], id.ID) - require.NoError(t, err) - - got := make([]byte, len(resp)) - _, err = clientConn.Read(got) - require.NoError(t, err) - assert.Equal(t, string(resp), string(got)) - - mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil) - - mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2) - mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return([]*blockchain.FilteredEvent{filteredEvents[1]}, nil, nil) - - headerFeed.Send(&core.Header{Number: b1.Number + 1}) - - resp, err = marshalSubEventsResp(emittedEvents[1], id.ID) - require.NoError(t, err) - - got = make([]byte, len(resp)) - _, err = clientConn.Read(got) - require.NoError(t, err) - assert.Equal(t, string(resp), string(got)) - - cancel() - time.Sleep(100 * time.Millisecond) - }) + } } func TestSubscribeTxnStatus(t *testing.T) {