From 0cf2454dc45587f6fe5370aa24ec727331858707 Mon Sep 17 00:00:00 2001 From: rian Date: Wed, 8 Jan 2025 14:52:31 +0200 Subject: [PATCH] add nonce validation + tests --- mempool/init_test.go | 5 ++++ mempool/mempool.go | 53 +++++++++++++++++++++++++++++++++----- mempool/mempool_test.go | 56 +++++++++++++++++++++++++++++++---------- 3 files changed, 95 insertions(+), 19 deletions(-) create mode 100644 mempool/init_test.go diff --git a/mempool/init_test.go b/mempool/init_test.go new file mode 100644 index 0000000000..3104ea09e7 --- /dev/null +++ b/mempool/init_test.go @@ -0,0 +1,5 @@ +package mempool_test + +import ( + _ "github.com/NethermindEth/juno/encoder/registry" +) diff --git a/mempool/mempool.go b/mempool/mempool.go index 1eeb492823..db243f1f10 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -40,7 +40,8 @@ type txnList struct { // Pool stores the transactions in a linked list for its inherent FCFS behaviour type Pool struct { - db db.DB + state core.StateReader + db db.DB // persistent mempool txPushed chan struct{} txnList *txnList // in-memory maxNumTxns uint16 @@ -50,8 +51,9 @@ type Pool struct { // New initializes the Pool and starts the database writer goroutine. // It is the responsibility of the user to call the cancel function if the context is cancelled -func New(db db.DB, maxNumTxns uint16) (*Pool, func() error, error) { +func New(db db.DB, state core.StateReader, maxNumTxns uint16) (*Pool, func() error, error) { pool := &Pool{ + state: state, db: db, // todo: txns should be deleted everytime a new block is stored (builder responsibility) txPushed: make(chan struct{}, 1), txnList: &txnList{}, @@ -195,12 +197,11 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { // Push queues a transaction to the pool and adds it to both the in-memory list and DB func (p *Pool) Push(userTxn *BroadcastedTransaction) error { - if p.txnList.len >= uint16(p.maxNumTxns) { - return ErrTxnPoolFull + err := p.validate(userTxn) + if err != nil { + return err } - // todo(rian this PR): validation - // todo: should db overloading block the in-memory mempool?? select { case p.dbWriteChan <- userTxn: @@ -236,6 +237,44 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { return nil } +func (p *Pool) validate(userTxn *BroadcastedTransaction) error { + if p.txnList.len+1 >= uint16(p.maxNumTxns) { + return ErrTxnPoolFull + } + + switch t := userTxn.Transaction.(type) { + case *core.DeployTransaction: + return fmt.Errorf("deploy transactions are not supported") + case *core.DeployAccountTransaction: + if !t.Nonce.IsZero() { + return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce) + } + case *core.DeclareTransaction: + nonce, err := p.state.ContractNonce(t.SenderAddress) + if err != nil { + return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err) + } + if nonce.Cmp(t.Nonce) > 0 { + return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) + } + case *core.InvokeTransaction: + if t.TxVersion().Is(0) { // cant verify nonce since SenderAddress was only added in v1 + return fmt.Errorf("invoke v0 transactions not supported") + } + nonce, err := p.state.ContractNonce(t.SenderAddress) + if err != nil { + return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err) + } + if nonce.Cmp(t.Nonce) > 0 { + return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) + } + case *core.L1HandlerTransaction: + // todo: verification of the L1 handler nonce requires checking the + // message nonce on the L1 Core Contract. + } + return nil +} + // Pop returns the transaction with the highest priority from the in-memory pool func (p *Pool) Pop() (BroadcastedTransaction, error) { p.txnList.mu.Lock() @@ -256,6 +295,8 @@ func (p *Pool) Pop() (BroadcastedTransaction, error) { } // Remove removes a set of transactions from the pool +// todo: should be called by the builder to remove txns from the db everytime a new block is stored. +// todo: in the consensus+p2p world, the txns should also be removed from the in-memory pool. func (p *Pool) Remove(hash ...*felt.Felt) error { return errors.New("not implemented") } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 5838a1f547..fd8801faeb 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -5,14 +5,15 @@ import ( "testing" "time" - "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/mempool" + "github.com/NethermindEth/juno/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) func setupDatabase(dltExisting bool) (*db.DB, func(), error) { @@ -39,12 +40,14 @@ func setupDatabase(dltExisting bool) (*db.DB, func(), error) { func TestMempool(t *testing.T) { testDB, dbCloser, err := setupDatabase(true) + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) require.NoError(t, err) defer dbCloser() - pool, closer, err := mempool.New(*testDB, 5) + pool, closer, err := mempool.New(*testDB, state, 4) defer closer() require.NoError(t, err) - blockchain.RegisterCoreTypesToEncoder() l := pool.Len() assert.Equal(t, uint16(0), l) @@ -54,16 +57,19 @@ func TestMempool(t *testing.T) { // push multiple to empty (1,2,3) for i := uint64(1); i < 4; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: senderAddress, + Version: new(core.TransactionVersion).SetUint64(1), }, })) - l := pool.Len() assert.Equal(t, uint16(i), l) } - // consume some (remove 1,2, keep 3) for i := uint64(1); i < 3; i++ { txn, err := pool.Pop() @@ -76,12 +82,16 @@ func TestMempool(t *testing.T) { // push multiple to non empty (push 4,5. now have 3,4,5) for i := uint64(4); i < 6; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: senderAddress, + Version: new(core.TransactionVersion).SetUint64(1), }, })) - l := pool.Len() assert.Equal(t, uint16(i-2), l) } @@ -103,15 +113,16 @@ func TestMempool(t *testing.T) { _, err = pool.Pop() require.Equal(t, err.Error(), "transaction pool is empty") - } func TestRestoreMempool(t *testing.T) { - blockchain.RegisterCoreTypesToEncoder() - + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) testDB, _, err := setupDatabase(true) require.NoError(t, err) - pool, closer, err := mempool.New(*testDB, 1024) + + pool, closer, err := mempool.New(*testDB, state, 1024) require.NoError(t, err) // Check both pools are empty @@ -122,9 +133,14 @@ func TestRestoreMempool(t *testing.T) { // push multiple transactions to empty mempool (1,2,3) for i := uint64(1); i < 4; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), + Version: new(core.TransactionVersion).SetUint64(1), + SenderAddress: senderAddress, + Nonce: new(felt.Felt).SetUint64(0), }, })) assert.Equal(t, uint16(i), pool.Len()) @@ -142,7 +158,7 @@ func TestRestoreMempool(t *testing.T) { require.NoError(t, err) defer dbCloser() - poolRestored, closer2, err := mempool.New(*testDB, 1024) + poolRestored, closer2, err := mempool.New(*testDB, state, 1024) require.NoError(t, err) lenDB, err = poolRestored.LenDB() require.NoError(t, err) @@ -163,9 +179,11 @@ func TestRestoreMempool(t *testing.T) { func TestWait(t *testing.T) { testDB := pebble.NewMemTest(t) - pool, _, err := mempool.New(testDB, 1024) + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) + pool, _, err := mempool.New(testDB, state, 1024) require.NoError(t, err) - blockchain.RegisterCoreTypesToEncoder() select { case <-pool.Wait(): @@ -174,22 +192,34 @@ func TestWait(t *testing.T) { } // One transaction. + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(1)).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(1), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(1), + Version: new(core.TransactionVersion).SetUint64(1), }, })) <-pool.Wait() // Two transactions. + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(2)).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(2), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(2), + Version: new(core.TransactionVersion).SetUint64(1), }, })) + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(3)).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(3), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(3), + Version: new(core.TransactionVersion).SetUint64(1), }, })) <-pool.Wait()