diff --git a/rpc/subscriptions.go b/rpc/subscriptions.go index 628ea2d342..8323206358 100644 --- a/rpc/subscriptions.go +++ b/rpc/subscriptions.go @@ -3,6 +3,8 @@ package rpc import ( "context" "encoding/json" + "errors" + "fmt" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" @@ -38,6 +40,10 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys return nil, ErrTooManyKeysInFilter } + if blockID != nil && blockID.Pending { + return nil, ErrCallOnPending + } + requestedHeader, headHeader, rpcErr := h.resolveBlockRange(blockID) if rpcErr != nil { return nil, rpcErr @@ -93,22 +99,26 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys // SubscribeTxnStatus subscribes to status changes of a transaction. It checks for updates each time a new block is added. // Subsequent updates are sent only when the transaction status changes. // The optional block_id parameter is ignored, as status changes are not stored and historical data cannot be sent. -func (h *Handler) SubscribeTxnStatus(ctx context.Context, txHash felt.Felt, _ *BlockID) (*SubscriptionID, *jsonrpc.Error) { - var ( - lastKnownStatus, lastSendStatus *TransactionStatus - wrapResult = func(s *TransactionStatus) *NewTransactionStatus { - return &NewTransactionStatus{ - TransactionHash: &txHash, - Status: s, - } - } - ) - +// +//nolint:funlen,gocyclo +func (h *Handler) SubscribeTxnStatus(ctx context.Context, txHash felt.Felt, blockID *BlockID) (*SubscriptionID, + *jsonrpc.Error, +) { w, ok := jsonrpc.ConnFromContext(ctx) if !ok { return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil) } + // resolveBlockRange is only used to make sure that the requested block id is not older than 1024 block and check + // if the requested block is found. The range is inconsequential since we assume the provided transaction hash + // of a transaction is included in the block range: latest/pending - 1024. + _, _, rpcErr := h.resolveBlockRange(blockID) + if rpcErr != nil { + return nil, rpcErr + } + + fmt.Println("Codeunder test ------ Inside subscribeTxnStatus") + id := h.idgen() subscriptionCtx, subscriptionCtxCancel := context.WithCancel(ctx) sub := &subscription{ @@ -116,59 +126,177 @@ func (h *Handler) SubscribeTxnStatus(ctx context.Context, txHash felt.Felt, _ *B conn: w, } - lastKnownStatus, rpcErr := h.TransactionStatus(subscriptionCtx, txHash) - if rpcErr != nil { - h.log.Errorw("Failed to get Tx status", "txHash", &txHash, "rpcErr", rpcErr) - return nil, rpcErr - } - h.mu.Lock() h.subscriptions[id] = sub h.mu.Unlock() - headerSub := h.newHeads.Subscribe() + l2HeadSub := h.newHeads.Subscribe() + l1HeadSub := h.l1Heads.Subscribe() + reorgSub := h.reorgs.Subscribe() + sub.wg.Go(func() { + fmt.Println("Codeunder test ------ inside big go routine") defer func() { h.unsubscribe(sub, id) - headerSub.Unsubscribe() + l2HeadSub.Unsubscribe() + l1HeadSub.Unsubscribe() + reorgSub.Unsubscribe() }() - if err := h.sendTxnStatus(sub.conn, wrapResult(lastKnownStatus), id); err != nil { - h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) - return - } - lastSendStatus = lastKnownStatus + var wg conc.WaitGroup + receipt, rpcErr := h.TransactionReceiptByHash(txHash) + + // Check if the requested transaction is already final. + // A transaction is considered to be final if it has been rejected or accepted on l1 + if rpcErr == nil { + if receipt.FinalityStatus == TxnAcceptedOnL1 { + s := &TransactionStatus{ + Finality: TxnStatus(receipt.FinalityStatus), + Execution: receipt.ExecutionStatus, + FailureReason: receipt.RevertReason, + } - for { - select { - case <-subscriptionCtx.Done(): - return - case <-headerSub.Recv(): - lastKnownStatus, rpcErr = h.TransactionStatus(subscriptionCtx, txHash) - if rpcErr != nil { - h.log.Errorw("Failed to get Tx status", "txHash", txHash, "rpcErr", rpcErr) - return + err := h.sendTxnStatus(w, SubscriptionTransactionStatus{&txHash, *s}, id) + if err != nil { + h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) } + return + } + } else if rpcErr == ErrTxnHashNotFound { + s, err := h.fetchTxnStatusFromFeeder(subscriptionCtx, txHash) + if err != nil { + h.log.Errorw("Error while fetching Txn status from feeder", "txHash", txHash, "err", err) + return + } - if *lastKnownStatus != *lastSendStatus { - if err := h.sendTxnStatus(sub.conn, wrapResult(lastKnownStatus), id); err != nil { - h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) - return - } - lastSendStatus = lastKnownStatus + if s.Finality == TxnStatusRejected { + err := h.sendTxnStatus(w, SubscriptionTransactionStatus{&txHash, *s}, id) + if err != nil { + h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) } + return + } + } else { + h.log.Errorw("Unexpected error occurred while fetching transaction receipt", "txHash", txHash, "err", + rpcErr.Message) + fmt.Println("Codeunder test ------ Returing because of an error") + return + } - // Stop when final status reached and notified - if isFinal(lastSendStatus) { + fmt.Println("Codeunder test ------ After initial checks") + // At this point, the transaction has not reached finality. + var curStatus *TransactionStatus + wg.Go(func() { + fmt.Println("Codeunder test ------ Inside long running go routine") + for { + select { + case <-subscriptionCtx.Done(): + fmt.Println("Codeunder test ------ Returing after context is done") return + case <-l2HeadSub.Recv(): + // A new block has been added to the DB, hence, check if transaction has reached l2 finality, + // if not, check feeder. + // We could use a separate timer to periodically check for the transaction status at feeder + // gateway, however, for the time being new l2 head update is sufficient. + fmt.Println("Codeunder test ------ if condigion check", curStatus == nil || curStatus. + Finality < TxnStatusAcceptedOnL2) + if curStatus == nil || curStatus.Finality < TxnStatusAcceptedOnL2 { + fmt.Println("Codeunder test ------ Inside l2Head if statement") + receipt, rpcErr := h.TransactionReceiptByHash(txHash) + if rpcErr == nil { + fmt.Println("Codeunder test ------ Inside rpcErr == nil") + fmt.Println("Codeunder test ------ Is curStatus nil", curStatus == nil, txHash.String()) + curStatus = &TransactionStatus{ + Finality: TxnStatus(receipt.FinalityStatus), + Execution: receipt.ExecutionStatus, + FailureReason: receipt.RevertReason, + } + + fmt.Println(curStatus.Finality) + + err := h.sendTxnStatus(w, SubscriptionTransactionStatus{&txHash, *curStatus}, id) + if err != nil { + h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) + return + } + if curStatus.Finality == TxnStatusAcceptedOnL1 { + return + } + fmt.Println("Codeunder test ------ about to exit rpcErr == nil") + } else if rpcErr == ErrTxnHashNotFound { + feederTxStatus, err := h.fetchTxnStatusFromFeeder(subscriptionCtx, txHash) + if err != nil { + h.log.Errorw("Error while fetching Txn status from feeder", "txHash", txHash, "err", err) + return + } + + if feederTxStatus.Finality == TxnStatusRejected { + err := h.sendTxnStatus(w, SubscriptionTransactionStatus{&txHash, *feederTxStatus}, id) + if err != nil { + h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) + } + return + } + + if feederTxStatus.Finality > curStatus.Finality { + curStatus = feederTxStatus + + err := h.sendTxnStatus(w, SubscriptionTransactionStatus{&txHash, *curStatus}, id) + if err != nil { + h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) + return + } + } + } + } + case <-l1HeadSub.Recv(): + fmt.Println("Codeunder test ------ Returing after context is done") + receipt, rpcErr := h.TransactionReceiptByHash(txHash) + if rpcErr == nil && receipt.FinalityStatus == TxnAcceptedOnL1 { + s := &TransactionStatus{ + Finality: TxnStatus(receipt.FinalityStatus), + Execution: receipt.ExecutionStatus, + FailureReason: receipt.RevertReason, + } + + err := h.sendTxnStatus(w, SubscriptionTransactionStatus{&txHash, *s}, id) + if err != nil { + h.log.Errorw("Error while sending Txn status", "txHash", txHash, "err", err) + } + return + } } } - } + }) + + wg.Go(func() { + h.processReorgs(subscriptionCtx, reorgSub, w, id) + }) }) + fmt.Println("Codeunder test ------ Just before returning") + return &SubscriptionID{ID: id}, nil } +func (h *Handler) fetchTxnStatusFromFeeder(ctx context.Context, txHash felt.Felt) (*TransactionStatus, error) { + if h.feederClient == nil { + return nil, errors.New("feedClient is nil") + } + + txStatus, err := h.feederClient.Transaction(ctx, &txHash) + if err != nil { + return nil, err + } + + status, err := adaptTransactionStatus(txStatus) + if err != nil { + return nil, err + } + + return status, nil +} + func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64, fromAddr *felt.Felt, keys [][]felt.Felt) { filter, err := h.bcReader.EventFilter(fromAddr, keys) if err != nil { @@ -255,6 +383,10 @@ func (h *Handler) SubscribeNewHeads(ctx context.Context, blockID *BlockID) (*Sub return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil) } + if blockID != nil && blockID.Pending { + return nil, ErrCallOnPending + } + startHeader, latestHeader, rpcErr := h.resolveBlockRange(blockID) if rpcErr != nil { return nil, rpcErr @@ -443,10 +575,6 @@ func (h *Handler) resolveBlockRange(blockID *BlockID) (*core.Header, *core.Heade return latestHeader, latestHeader, nil } - if blockID.Pending { - return nil, nil, ErrCallOnPending - } - startHeader, rpcErr := h.blockHeaderByID(blockID) if rpcErr != nil { return nil, nil, rpcErr @@ -581,13 +709,13 @@ func (h *Handler) Unsubscribe(ctx context.Context, id uint64) (bool, *jsonrpc.Er return true, nil } -type NewTransactionStatus struct { - TransactionHash *felt.Felt `json:"transaction_hash"` - Status *TransactionStatus `json:"status"` +type SubscriptionTransactionStatus struct { + TransactionHash *felt.Felt `json:"transaction_hash"` + Status TransactionStatus `json:"status"` } // sendTxnStatus creates a response and sends it to the client -func (h *Handler) sendTxnStatus(w jsonrpc.Conn, status *NewTransactionStatus, id uint64) error { +func (h *Handler) sendTxnStatus(w jsonrpc.Conn, status SubscriptionTransactionStatus, id uint64) error { resp, err := json.Marshal(SubscriptionResponse{ Version: "2.0", Method: "starknet_subscriptionTransactionsStatus", @@ -603,7 +731,3 @@ func (h *Handler) sendTxnStatus(w jsonrpc.Conn, status *NewTransactionStatus, id _, err = w.Write(resp) return err } - -func isFinal(status *TransactionStatus) bool { - return status.Finality == TxnStatusRejected || status.Finality == TxnStatusAcceptedOnL1 -} diff --git a/rpc/subscriptions_test.go b/rpc/subscriptions_test.go index b577b004b4..7b972ea4b9 100644 --- a/rpc/subscriptions_test.go +++ b/rpc/subscriptions_test.go @@ -103,8 +103,6 @@ func TestSubscribeEvents(t *testing.T) { subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn}) - mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: 1}, nil) - id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID) assert.Zero(t, id) assert.Equal(t, ErrCallOnPending, rpcErr) @@ -389,8 +387,6 @@ func TestSubscribeNewHeads(t *testing.T) { mockSyncer := mocks.NewMockSyncReader(mockCtrl) handler := New(mockChain, mockSyncer, nil, "", log) - mockChain.EXPECT().HeadsHeader().Return(&core.Header{}, nil) - serverConn, _ := net.Pipe() t.Cleanup(func() { require.NoError(t, serverConn.Close()) @@ -448,10 +444,12 @@ func TestSubscribeNewHeads(t *testing.T) { mockChain := mocks.NewMockReader(mockCtrl) syncer := newFakeSyncer() - handler, server := setupRPC(t, ctx, mockChain, syncer) + l1Feed := feed.New[*core.L1Head]() mockChain.EXPECT().HeadsHeader().Return(&core.Header{}, nil) + mockChain.EXPECT().SubscribeL1Head().Return(blockchain.L1HeadSubscription{Subscription: l1Feed.Subscribe()}) + handler, server := setupRPC(t, ctx, mockChain, syncer) conn := createWsConn(t, ctx, server) id := uint64(1) @@ -530,6 +528,10 @@ func TestMultipleSubscribeNewHeadsAndUnsubscribe(t *testing.T) { mockChain := mocks.NewMockReader(mockCtrl) syncer := newFakeSyncer() + + l1Feed := feed.New[*core.L1Head]() + mockChain.EXPECT().SubscribeL1Head().Return(blockchain.L1HeadSubscription{Subscription: l1Feed.Subscribe()}) + handler, server := setupRPC(t, ctx, mockChain, syncer) mockChain.EXPECT().HeadsHeader().Return(&core.Header{}, nil).Times(2) @@ -594,6 +596,9 @@ func TestSubscriptionReorg(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockChain := mocks.NewMockReader(mockCtrl) + l1Feed := feed.New[*core.L1Head]() + mockChain.EXPECT().SubscribeL1Head().Return(blockchain.L1HeadSubscription{Subscription: l1Feed.Subscribe()}) + syncer := newFakeSyncer() handler, server := setupRPC(t, ctx, mockChain, syncer) @@ -658,6 +663,9 @@ func TestSubscribePendingTxs(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockChain := mocks.NewMockReader(mockCtrl) + l1Feed := feed.New[*core.L1Head]() + mockChain.EXPECT().SubscribeL1Head().Return(blockchain.L1HeadSubscription{Subscription: l1Feed.Subscribe()}) + syncer := newFakeSyncer() handler, server := setupRPC(t, ctx, mockChain, syncer) @@ -852,6 +860,10 @@ func TestSubscribeTxStatusAndUnsubscribe(t *testing.T) { t.Cleanup(mockCtrl.Finish) mockReader := mocks.NewMockReader(mockCtrl) + l1Feed := feed.New[*core.L1Head]() + mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil).AnyTimes() + mockReader.EXPECT().SubscribeL1Head().Return(blockchain.L1HeadSubscription{Subscription: l1Feed.Subscribe()}) + handler, syncer, server := setupSubscriptionTest(t, ctx, mockReader) require.NoError(t, server.RegisterMethods(jsonrpc.Method{ @@ -882,40 +894,28 @@ func TestSubscribeTxStatusAndUnsubscribe(t *testing.T) { secondID := uint64(2) t.Run("simple subscribe and unsubscribe", func(t *testing.T) { - conn1, resp1, err := websocket.Dial(ctx, httpSrv.URL, nil) - require.NoError(t, err) - defer bodyCloser(t, resp1) - - conn2, resp2, err := websocket.Dial(ctx, httpSrv.URL, nil) - require.NoError(t, err) - defer bodyCloser(t, resp2) - - handler.WithIDGen(func() uint64 { return firstID }) - firstWant := txStatusNotFoundResponse - // Notice we subscribe for non-existing tx, we expect automatic unsubscribe - firstGot := sendAndReceiveMessage(t, ctx, conn1, fmt.Sprintf(subscribeTxStatus, felt.Zero.String())) + conn, resp, err := websocket.Dial(ctx, httpSrv.URL, nil) require.NoError(t, err) - require.Equal(t, firstWant, firstGot) + defer bodyCloser(t, resp) handler.WithIDGen(func() uint64 { return secondID }) secondWant := fmt.Sprintf(subscribeResponse, secondID) - secondGot := sendAndReceiveMessage(t, ctx, conn2, fmt.Sprintf(subscribeTxStatus, txnHash)) + secondGot := sendAndReceiveMessage(t, ctx, conn, fmt.Sprintf(subscribeTxStatus, txnHash)) require.NoError(t, err) require.Equal(t, secondWant, secondGot) - // as expected the subscription is gone - firstUnsubGot := sendAndReceiveMessage(t, ctx, conn1, fmt.Sprintf(unsubscribeMsg, firstID)) - require.Equal(t, unsubscribeNotFoundResponse, firstUnsubGot) - // Receive a block header. secondWant = formatTxStatusResponse(t, txnHash, TxnStatusAcceptedOnL2, TxnSuccess, secondID) - _, secondHeaderGot, err := conn2.Read(ctx) + time.Sleep(50 * time.Millisecond) + fmt.Println("Before reading from conn") + _, secondHeaderGot, err := conn.Read(ctx) + fmt.Println("After reading from conn") secondGot = string(secondHeaderGot) require.NoError(t, err) require.Equal(t, secondWant, secondGot) // Unsubscribe - require.NoError(t, conn2.Write(ctx, websocket.MessageBinary, []byte(fmt.Sprintf(unsubscribeMsg, secondID)))) + require.NoError(t, conn.Write(ctx, websocket.MessageBinary, []byte(fmt.Sprintf(unsubscribeMsg, secondID)))) }) t.Run("no update is sent when status has not changed", func(t *testing.T) { diff --git a/rpc/transaction.go b/rpc/transaction.go index f9416c8d57..358b7d23b3 100644 --- a/rpc/transaction.go +++ b/rpc/transaction.go @@ -72,10 +72,10 @@ func (t *TransactionType) UnmarshalJSON(data []byte) error { type TxnStatus uint8 const ( - TxnStatusAcceptedOnL1 TxnStatus = iota + 1 - TxnStatusAcceptedOnL2 - TxnStatusReceived + TxnStatusReceived TxnStatus = iota + 1 TxnStatusRejected + TxnStatusAcceptedOnL2 + TxnStatusAcceptedOnL1 ) func (s TxnStatus) MarshalText() ([]byte, error) { @@ -114,8 +114,8 @@ func (es TxnExecutionStatus) MarshalText() ([]byte, error) { type TxnFinalityStatus uint8 const ( - TxnAcceptedOnL1 TxnFinalityStatus = iota + 1 - TxnAcceptedOnL2 + TxnAcceptedOnL2 TxnFinalityStatus = iota + 3 + TxnAcceptedOnL1 ) func (fs TxnFinalityStatus) MarshalText() ([]byte, error) {