From 2417b460acf5d993751cd1f4d5cb78f74d578814 Mon Sep 17 00:00:00 2001 From: rian Date: Thu, 19 Dec 2024 10:45:55 +0200 Subject: [PATCH 01/22] implement a mempool for the sequencer --- mempool/mempool.go | 231 ++++++++++++++++++++++++++++++++++++++++ mempool/mempool_test.go | 120 +++++++++++++++++++++ 2 files changed, 351 insertions(+) create mode 100644 mempool/mempool.go create mode 100644 mempool/mempool_test.go diff --git a/mempool/mempool.go b/mempool/mempool.go new file mode 100644 index 0000000000..24ca8f85a6 --- /dev/null +++ b/mempool/mempool.go @@ -0,0 +1,231 @@ +package mempool + +import ( + "encoding/binary" + "errors" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/encoder" +) + +type ValidatorFunc func(*BroadcastedTransaction) error + +type BroadcastedTransaction struct { + Transaction core.Transaction + DeclaredClass core.Class +} + +const ( + poolLengthKey = "poolLength" + headKey = "headKey" + tailKey = "tailKey" +) + +// Pool stores the transactions in a linked list for its inherent FCFS behaviour +type storageElem struct { + Txn BroadcastedTransaction + NextHash *felt.Felt +} + +type Pool struct { + db db.DB + validator ValidatorFunc + txPushed chan struct{} +} + +func New(poolDB db.DB) *Pool { + return &Pool{ + db: poolDB, + validator: func(_ *BroadcastedTransaction) error { return nil }, + txPushed: make(chan struct{}, 1), + } +} + +// WithValidator adds a validation step to be triggered before adding a +// BroadcastedTransaction to the pool +func (p *Pool) WithValidator(validator ValidatorFunc) *Pool { + p.validator = validator + return p +} + +// Push queues a transaction to the pool +func (p *Pool) Push(userTxn *BroadcastedTransaction) error { + err := p.validator(userTxn) + if err != nil { + return err + } + + if err := p.db.Update(func(txn db.Transaction) error { + tail, err := p.tailHash(txn) + if err != nil && !errors.Is(err, db.ErrKeyNotFound) { + return err + } + + if err = p.putElem(txn, userTxn.Transaction.Hash(), &storageElem{ + Txn: *userTxn, + }); err != nil { + return err + } + + if tail != nil { + var oldTail storageElem + oldTail, err = p.elem(txn, tail) + if err != nil { + return err + } + + // update old tail to point to the new item + oldTail.NextHash = userTxn.Transaction.Hash() + if err = p.putElem(txn, tail, &oldTail); err != nil { + return err + } + } else { + // empty list, make new item both the head and the tail + if err = p.updateHead(txn, userTxn.Transaction.Hash()); err != nil { + return err + } + } + + if err = p.updateTail(txn, userTxn.Transaction.Hash()); err != nil { + return err + } + + pLen, err := p.len(txn) + if err != nil { + return err + } + return p.updateLen(txn, pLen+1) // don't worry about overflows, highly unlikely + }); err != nil { + return err + } + + select { + case p.txPushed <- struct{}{}: + default: + } + + return nil +} + +// Pop returns the transaction with the highest priority from the pool +func (p *Pool) Pop() (BroadcastedTransaction, error) { + var nextTxn BroadcastedTransaction + return nextTxn, p.db.Update(func(txn db.Transaction) error { + headHash, err := p.headHash(txn) + if err != nil { + return err + } + + headElem, err := p.elem(txn, headHash) + if err != nil { + return err + } + + if err = txn.Delete(headHash.Marshal()); err != nil { + return err + } + + if headElem.NextHash == nil { + // the list is empty now + if err = txn.Delete([]byte(headKey)); err != nil { + return err + } + if err = txn.Delete([]byte(tailKey)); err != nil { + return err + } + } else { + if err = p.updateHead(txn, headElem.NextHash); err != nil { + return err + } + } + + pLen, err := p.len(txn) + if err != nil { + return err + } + + if err = p.updateLen(txn, pLen-1); err != nil { + return err + } + nextTxn = headElem.Txn + return nil + }) +} + +// Remove removes a set of transactions from the pool +func (p *Pool) Remove(hash ...*felt.Felt) error { + return errors.New("not implemented") +} + +// Len returns the number of transactions in the pool +func (p *Pool) Len() (uint64, error) { + var l uint64 + return l, p.db.View(func(txn db.Transaction) error { + var err error + l, err = p.len(txn) + return err + }) +} + +func (p *Pool) Wait() <-chan struct{} { + return p.txPushed +} + +func (p *Pool) len(txn db.Transaction) (uint64, error) { + var l uint64 + err := txn.Get([]byte(poolLengthKey), func(b []byte) error { + l = binary.BigEndian.Uint64(b) + return nil + }) + + if err != nil && errors.Is(err, db.ErrKeyNotFound) { + return 0, nil + } + return l, err +} + +func (p *Pool) updateLen(txn db.Transaction, l uint64) error { + return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint64(nil, l)) +} + +func (p *Pool) headHash(txn db.Transaction) (*felt.Felt, error) { + var head *felt.Felt + return head, txn.Get([]byte(headKey), func(b []byte) error { + head = new(felt.Felt).SetBytes(b) + return nil + }) +} + +func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { + return txn.Set([]byte(headKey), head.Marshal()) +} + +func (p *Pool) tailHash(txn db.Transaction) (*felt.Felt, error) { + var tail *felt.Felt + return tail, txn.Get([]byte(tailKey), func(b []byte) error { + tail = new(felt.Felt).SetBytes(b) + return nil + }) +} + +func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { + return txn.Set([]byte(tailKey), tail.Marshal()) +} + +func (p *Pool) elem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { + var item storageElem + err := txn.Get(itemKey.Marshal(), func(b []byte) error { + return encoder.Unmarshal(b, &item) + }) + return item, err +} + +func (p *Pool) putElem(txn db.Transaction, itemKey *felt.Felt, item *storageElem) error { + itemBytes, err := encoder.Marshal(item) + if err != nil { + return err + } + return txn.Set(itemKey.Marshal(), itemBytes) +} diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go new file mode 100644 index 0000000000..68408be459 --- /dev/null +++ b/mempool/mempool_test.go @@ -0,0 +1,120 @@ +package mempool_test + +import ( + "errors" + "testing" + + "github.com/NethermindEth/juno/blockchain" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/mempool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMempool(t *testing.T) { + testDB := pebble.NewMemTest(t) + pool := mempool.New(testDB) + blockchain.RegisterCoreTypesToEncoder() + + t.Run("empty pool", func(t *testing.T) { + l, err := pool.Len() + require.NoError(t, err) + assert.Equal(t, uint64(0), l) + + _, err = pool.Pop() + require.ErrorIs(t, err, db.ErrKeyNotFound) + }) + + // push multiple to empty + for i := uint64(0); i < 3; i++ { + assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(i), + }, + })) + + l, err := pool.Len() + require.NoError(t, err) + assert.Equal(t, i+1, l) + } + + // consume some + for i := uint64(0); i < 2; i++ { + txn, err := pool.Pop() + require.NoError(t, err) + assert.Equal(t, i, txn.Transaction.Hash().Uint64()) + + l, err := pool.Len() + require.NoError(t, err) + assert.Equal(t, 3-i-1, l) + } + + // push multiple to non empty + for i := uint64(3); i < 5; i++ { + assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(i), + }, + })) + + l, err := pool.Len() + require.NoError(t, err) + assert.Equal(t, i-1, l) + } + + // consume all + for i := uint64(2); i < 5; i++ { + txn, err := pool.Pop() + require.NoError(t, err) + assert.Equal(t, i, txn.Transaction.Hash().Uint64()) + } + + l, err := pool.Len() + require.NoError(t, err) + assert.Equal(t, uint64(0), l) + + _, err = pool.Pop() + require.ErrorIs(t, err, db.ErrKeyNotFound) + + // validation error + pool = pool.WithValidator(func(bt *mempool.BroadcastedTransaction) error { + return errors.New("some error") + }) + require.EqualError(t, pool.Push(&mempool.BroadcastedTransaction{}), "some error") +} + +func TestWait(t *testing.T) { + testDB := pebble.NewMemTest(t) + pool := mempool.New(testDB) + blockchain.RegisterCoreTypesToEncoder() + + select { + case <-pool.Wait(): + require.Fail(t, "wait channel should not be signalled on empty mempool") + default: + } + + // One transaction. + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt), + }, + })) + <-pool.Wait() + + // Two transactions. + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt), + }, + })) + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt), + }, + })) + <-pool.Wait() +} From 9de880fcf0528664fcd39135d8f878bd6630a297 Mon Sep 17 00:00:00 2001 From: rian Date: Thu, 19 Dec 2024 12:24:44 +0200 Subject: [PATCH 02/22] reject duplicate txns --- mempool/mempool.go | 24 +++++++++++++++++++++++- mempool/mempool_test.go | 15 ++++++++++++--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 24ca8f85a6..6367fb0fb1 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -3,6 +3,7 @@ package mempool import ( "encoding/binary" "errors" + "fmt" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" @@ -50,6 +51,22 @@ func (p *Pool) WithValidator(validator ValidatorFunc) *Pool { return p } +func (p *Pool) rejectDuplicateTxn(userTxn *BroadcastedTransaction) error { + txHash := userTxn.Transaction.Hash().Marshal() + err := p.db.View(func(txn db.Transaction) error { + return txn.Get(txHash, func(val []byte) error { + if val != nil { + return fmt.Errorf("transaction already exists in the mempool: %x", txHash) + } + return nil + }) + }) + if errors.Is(err, db.ErrKeyNotFound) { + return nil + } + return err +} + // Push queues a transaction to the pool func (p *Pool) Push(userTxn *BroadcastedTransaction) error { err := p.validator(userTxn) @@ -57,6 +74,11 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { return err } + err = p.rejectDuplicateTxn(userTxn) + if err != nil { + return err + } + if err := p.db.Update(func(txn db.Transaction) error { tail, err := p.tailHash(txn) if err != nil && !errors.Is(err, db.ErrKeyNotFound) { @@ -68,7 +90,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { }); err != nil { return err } - + fmt.Println("tail", tail) if tail != nil { var oldTail storageElem oldTail, err = p.elem(txn, tail) diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 68408be459..da7e96afe4 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -79,6 +79,15 @@ func TestMempool(t *testing.T) { _, err = pool.Pop() require.ErrorIs(t, err, db.ErrKeyNotFound) + // reject duplicate txn + txn := mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(2), + }, + } + require.NoError(t, pool.Push(&txn)) + require.Error(t, pool.Push(&txn)) + // validation error pool = pool.WithValidator(func(bt *mempool.BroadcastedTransaction) error { return errors.New("some error") @@ -100,7 +109,7 @@ func TestWait(t *testing.T) { // One transaction. require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ - TransactionHash: new(felt.Felt), + TransactionHash: new(felt.Felt).SetUint64(1), }, })) <-pool.Wait() @@ -108,12 +117,12 @@ func TestWait(t *testing.T) { // Two transactions. require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ - TransactionHash: new(felt.Felt), + TransactionHash: new(felt.Felt).SetUint64(2), }, })) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ - TransactionHash: new(felt.Felt), + TransactionHash: new(felt.Felt).SetUint64(3), }, })) <-pool.Wait() From 61af100000766b420152a52898561c011942b716 Mon Sep 17 00:00:00 2001 From: rian Date: Thu, 19 Dec 2024 16:44:17 +0200 Subject: [PATCH 03/22] some heap optimisations --- mempool/mempool.go | 41 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 6367fb0fb1..04d5fcca3e 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -3,7 +3,6 @@ package mempool import ( "encoding/binary" "errors" - "fmt" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" @@ -51,22 +50,6 @@ func (p *Pool) WithValidator(validator ValidatorFunc) *Pool { return p } -func (p *Pool) rejectDuplicateTxn(userTxn *BroadcastedTransaction) error { - txHash := userTxn.Transaction.Hash().Marshal() - err := p.db.View(func(txn db.Transaction) error { - return txn.Get(txHash, func(val []byte) error { - if val != nil { - return fmt.Errorf("transaction already exists in the mempool: %x", txHash) - } - return nil - }) - }) - if errors.Is(err, db.ErrKeyNotFound) { - return nil - } - return err -} - // Push queues a transaction to the pool func (p *Pool) Push(userTxn *BroadcastedTransaction) error { err := p.validator(userTxn) @@ -74,13 +57,9 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { return err } - err = p.rejectDuplicateTxn(userTxn) - if err != nil { - return err - } - if err := p.db.Update(func(txn db.Transaction) error { - tail, err := p.tailHash(txn) + var tail *felt.Felt + tail, err := p.tailHash(txn, tail) if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return err } @@ -90,7 +69,6 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { }); err != nil { return err } - fmt.Println("tail", tail) if tail != nil { var oldTail storageElem oldTail, err = p.elem(txn, tail) @@ -135,7 +113,8 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { func (p *Pool) Pop() (BroadcastedTransaction, error) { var nextTxn BroadcastedTransaction return nextTxn, p.db.Update(func(txn db.Transaction) error { - headHash, err := p.headHash(txn) + var headHash *felt.Felt + headHash, err := p.headHash(txn, headHash) if err != nil { return err } @@ -212,10 +191,9 @@ func (p *Pool) updateLen(txn db.Transaction, l uint64) error { return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint64(nil, l)) } -func (p *Pool) headHash(txn db.Transaction) (*felt.Felt, error) { - var head *felt.Felt - return head, txn.Get([]byte(headKey), func(b []byte) error { - head = new(felt.Felt).SetBytes(b) +func (p *Pool) headHash(txn db.Transaction, headHash *felt.Felt) (*felt.Felt, error) { + return headHash, txn.Get([]byte(headKey), func(b []byte) error { + headHash = headHash.SetBytes(b) return nil }) } @@ -224,10 +202,9 @@ func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { return txn.Set([]byte(headKey), head.Marshal()) } -func (p *Pool) tailHash(txn db.Transaction) (*felt.Felt, error) { - var tail *felt.Felt +func (p *Pool) tailHash(txn db.Transaction, tail *felt.Felt) (*felt.Felt, error) { return tail, txn.Get([]byte(tailKey), func(b []byte) error { - tail = new(felt.Felt).SetBytes(b) + tail = tail.SetBytes(b) return nil }) } From d2ee2cc8296d153cdf13ff9d9965de17e290ee37 Mon Sep 17 00:00:00 2001 From: rian Date: Thu, 19 Dec 2024 16:50:34 +0200 Subject: [PATCH 04/22] Revert "some heap optimisations" This reverts commit 0518efb4c39cd62fbc67f60037ac22a63932eed2. --- mempool/mempool.go | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 04d5fcca3e..6367fb0fb1 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -3,6 +3,7 @@ package mempool import ( "encoding/binary" "errors" + "fmt" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" @@ -50,6 +51,22 @@ func (p *Pool) WithValidator(validator ValidatorFunc) *Pool { return p } +func (p *Pool) rejectDuplicateTxn(userTxn *BroadcastedTransaction) error { + txHash := userTxn.Transaction.Hash().Marshal() + err := p.db.View(func(txn db.Transaction) error { + return txn.Get(txHash, func(val []byte) error { + if val != nil { + return fmt.Errorf("transaction already exists in the mempool: %x", txHash) + } + return nil + }) + }) + if errors.Is(err, db.ErrKeyNotFound) { + return nil + } + return err +} + // Push queues a transaction to the pool func (p *Pool) Push(userTxn *BroadcastedTransaction) error { err := p.validator(userTxn) @@ -57,9 +74,13 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { return err } + err = p.rejectDuplicateTxn(userTxn) + if err != nil { + return err + } + if err := p.db.Update(func(txn db.Transaction) error { - var tail *felt.Felt - tail, err := p.tailHash(txn, tail) + tail, err := p.tailHash(txn) if err != nil && !errors.Is(err, db.ErrKeyNotFound) { return err } @@ -69,6 +90,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { }); err != nil { return err } + fmt.Println("tail", tail) if tail != nil { var oldTail storageElem oldTail, err = p.elem(txn, tail) @@ -113,8 +135,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { func (p *Pool) Pop() (BroadcastedTransaction, error) { var nextTxn BroadcastedTransaction return nextTxn, p.db.Update(func(txn db.Transaction) error { - var headHash *felt.Felt - headHash, err := p.headHash(txn, headHash) + headHash, err := p.headHash(txn) if err != nil { return err } @@ -191,9 +212,10 @@ func (p *Pool) updateLen(txn db.Transaction, l uint64) error { return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint64(nil, l)) } -func (p *Pool) headHash(txn db.Transaction, headHash *felt.Felt) (*felt.Felt, error) { - return headHash, txn.Get([]byte(headKey), func(b []byte) error { - headHash = headHash.SetBytes(b) +func (p *Pool) headHash(txn db.Transaction) (*felt.Felt, error) { + var head *felt.Felt + return head, txn.Get([]byte(headKey), func(b []byte) error { + head = new(felt.Felt).SetBytes(b) return nil }) } @@ -202,9 +224,10 @@ func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { return txn.Set([]byte(headKey), head.Marshal()) } -func (p *Pool) tailHash(txn db.Transaction, tail *felt.Felt) (*felt.Felt, error) { +func (p *Pool) tailHash(txn db.Transaction) (*felt.Felt, error) { + var tail *felt.Felt return tail, txn.Get([]byte(tailKey), func(b []byte) error { - tail = tail.SetBytes(b) + tail = new(felt.Felt).SetBytes(b) return nil }) } From cb0ad26d47e59bb2f9e8cfce1f5d3b3e159bf3b0 Mon Sep 17 00:00:00 2001 From: rian Date: Fri, 3 Jan 2025 14:43:32 +0200 Subject: [PATCH 05/22] comments: doc string, move rejectDup fn, rogue print --- mempool/mempool.go | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 6367fb0fb1..2ca38409b3 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -24,12 +24,12 @@ const ( tailKey = "tailKey" ) -// Pool stores the transactions in a linked list for its inherent FCFS behaviour type storageElem struct { Txn BroadcastedTransaction NextHash *felt.Felt } +// Pool stores the transactions in a linked list for its inherent FCFS behaviour type Pool struct { db db.DB validator ValidatorFunc @@ -51,22 +51,6 @@ func (p *Pool) WithValidator(validator ValidatorFunc) *Pool { return p } -func (p *Pool) rejectDuplicateTxn(userTxn *BroadcastedTransaction) error { - txHash := userTxn.Transaction.Hash().Marshal() - err := p.db.View(func(txn db.Transaction) error { - return txn.Get(txHash, func(val []byte) error { - if val != nil { - return fmt.Errorf("transaction already exists in the mempool: %x", txHash) - } - return nil - }) - }) - if errors.Is(err, db.ErrKeyNotFound) { - return nil - } - return err -} - // Push queues a transaction to the pool func (p *Pool) Push(userTxn *BroadcastedTransaction) error { err := p.validator(userTxn) @@ -90,7 +74,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { }); err != nil { return err } - fmt.Println("tail", tail) + if tail != nil { var oldTail storageElem oldTail, err = p.elem(txn, tail) @@ -251,3 +235,19 @@ func (p *Pool) putElem(txn db.Transaction, itemKey *felt.Felt, item *storageElem } return txn.Set(itemKey.Marshal(), itemBytes) } + +func (p *Pool) rejectDuplicateTxn(userTxn *BroadcastedTransaction) error { + txHash := userTxn.Transaction.Hash().Marshal() + err := p.db.View(func(txn db.Transaction) error { + return txn.Get(txHash, func(val []byte) error { + if val != nil { + return fmt.Errorf("transaction already exists in the mempool: %x", txHash) + } + return nil + }) + }) + if errors.Is(err, db.ErrKeyNotFound) { + return nil + } + return err +} From d7a7dc8ff29ff189384200ea7d43e93e3d550e67 Mon Sep 17 00:00:00 2001 From: rian Date: Fri, 3 Jan 2025 16:11:29 +0200 Subject: [PATCH 06/22] move tail to stack --- mempool/mempool.go | 17 +++++++++-------- mempool/mempool_test.go | 24 +++++++++++++----------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 2ca38409b3..ae52d19788 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -64,11 +64,13 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { } if err := p.db.Update(func(txn db.Transaction) error { - tail, err := p.tailHash(txn) - if err != nil && !errors.Is(err, db.ErrKeyNotFound) { - return err + tail := new(felt.Felt) + if err := p.tailHash(txn, tail); err != nil { + if !errors.Is(err, db.ErrKeyNotFound) { + return err + } + tail = nil } - if err = p.putElem(txn, userTxn.Transaction.Hash(), &storageElem{ Txn: *userTxn, }); err != nil { @@ -208,10 +210,9 @@ func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { return txn.Set([]byte(headKey), head.Marshal()) } -func (p *Pool) tailHash(txn db.Transaction) (*felt.Felt, error) { - var tail *felt.Felt - return tail, txn.Get([]byte(tailKey), func(b []byte) error { - tail = new(felt.Felt).SetBytes(b) +func (p *Pool) tailHash(txn db.Transaction, tail *felt.Felt) error { + return txn.Get([]byte(tailKey), func(b []byte) error { + tail.SetBytes(b) return nil }) } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index da7e96afe4..138213bdd7 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -2,6 +2,7 @@ package mempool_test import ( "errors" + "fmt" "testing" "github.com/NethermindEth/juno/blockchain" @@ -28,8 +29,8 @@ func TestMempool(t *testing.T) { require.ErrorIs(t, err, db.ErrKeyNotFound) }) - // push multiple to empty - for i := uint64(0); i < 3; i++ { + // push multiple to empty (1,2,3) + for i := uint64(1); i < 4; i++ { assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), @@ -38,22 +39,23 @@ func TestMempool(t *testing.T) { l, err := pool.Len() require.NoError(t, err) - assert.Equal(t, i+1, l) + assert.Equal(t, i, l) } - // consume some - for i := uint64(0); i < 2; i++ { + // consume some (remove 1,2, keep 3) + for i := uint64(1); i < 3; i++ { txn, err := pool.Pop() + fmt.Println("txn", txn.Transaction.Hash().String()) require.NoError(t, err) assert.Equal(t, i, txn.Transaction.Hash().Uint64()) l, err := pool.Len() require.NoError(t, err) - assert.Equal(t, 3-i-1, l) + assert.Equal(t, 3-i, l) } - // push multiple to non empty - for i := uint64(3); i < 5; i++ { + // push multiple to non empty (push 4,5. now have 3,4,5) + for i := uint64(4); i < 6; i++ { assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), @@ -62,11 +64,11 @@ func TestMempool(t *testing.T) { l, err := pool.Len() require.NoError(t, err) - assert.Equal(t, i-1, l) + assert.Equal(t, i-2, l) } - // consume all - for i := uint64(2); i < 5; i++ { + // consume all (remove 3,4,5) + for i := uint64(3); i < 6; i++ { txn, err := pool.Pop() require.NoError(t, err) assert.Equal(t, i, txn.Transaction.Hash().Uint64()) From ab1aae608a850cebb951651ad3ffc7e389c57431 Mon Sep 17 00:00:00 2001 From: rian Date: Fri, 3 Jan 2025 16:15:01 +0200 Subject: [PATCH 07/22] move headHash to stack --- mempool/mempool.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index ae52d19788..d9104cad62 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -121,7 +121,8 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { func (p *Pool) Pop() (BroadcastedTransaction, error) { var nextTxn BroadcastedTransaction return nextTxn, p.db.Update(func(txn db.Transaction) error { - headHash, err := p.headHash(txn) + headHash := new(felt.Felt) + err := p.headHash(txn, headHash) if err != nil { return err } @@ -198,10 +199,9 @@ func (p *Pool) updateLen(txn db.Transaction, l uint64) error { return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint64(nil, l)) } -func (p *Pool) headHash(txn db.Transaction) (*felt.Felt, error) { - var head *felt.Felt - return head, txn.Get([]byte(headKey), func(b []byte) error { - head = new(felt.Felt).SetBytes(b) +func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { + return txn.Get([]byte(headKey), func(b []byte) error { + head.SetBytes(b) return nil }) } From f4fb79e82c9192a4b35a28659c6bb36e22af164d Mon Sep 17 00:00:00 2001 From: rian Date: Mon, 6 Jan 2025 17:48:48 +0200 Subject: [PATCH 08/22] implement in-memory mempool --- mempool/mempool.go | 342 +++++++++++++++++++++++++--------------- mempool/mempool_test.go | 143 ++++++++++++----- 2 files changed, 322 insertions(+), 163 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index d9104cad62..1eeb492823 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "sync" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" @@ -11,104 +12,222 @@ import ( "github.com/NethermindEth/juno/encoder" ) -type ValidatorFunc func(*BroadcastedTransaction) error - -type BroadcastedTransaction struct { - Transaction core.Transaction - DeclaredClass core.Class -} - const ( poolLengthKey = "poolLength" headKey = "headKey" tailKey = "tailKey" ) +var ErrTxnPoolFull = errors.New("transaction pool is full") + type storageElem struct { Txn BroadcastedTransaction - NextHash *felt.Felt + NextHash *felt.Felt // persistent db + Next *storageElem // in-memory +} + +type BroadcastedTransaction struct { + Transaction core.Transaction + DeclaredClass core.Class +} + +type txnList struct { + head *storageElem + tail *storageElem + len uint16 + mu sync.Mutex } // Pool stores the transactions in a linked list for its inherent FCFS behaviour type Pool struct { - db db.DB - validator ValidatorFunc - txPushed chan struct{} + db db.DB + txPushed chan struct{} + txnList *txnList // in-memory + maxNumTxns uint16 + dbWriteChan chan *BroadcastedTransaction + wg sync.WaitGroup } -func New(poolDB db.DB) *Pool { - return &Pool{ - db: poolDB, - validator: func(_ *BroadcastedTransaction) error { return nil }, - txPushed: make(chan struct{}, 1), +// New initializes the Pool and starts the database writer goroutine. +// It is the responsibility of the user to call the cancel function if the context is cancelled +func New(db db.DB, maxNumTxns uint16) (*Pool, func() error, error) { + pool := &Pool{ + db: db, // todo: txns should be deleted everytime a new block is stored (builder responsibility) + txPushed: make(chan struct{}, 1), + txnList: &txnList{}, + maxNumTxns: maxNumTxns, + dbWriteChan: make(chan *BroadcastedTransaction, maxNumTxns), } -} -// WithValidator adds a validation step to be triggered before adding a -// BroadcastedTransaction to the pool -func (p *Pool) WithValidator(validator ValidatorFunc) *Pool { - p.validator = validator - return p -} + if err := pool.loadFromDB(); err != nil { + return nil, nil, fmt.Errorf("failed to load transactions from database into the in-memory transaction list: %v\n", err) + } -// Push queues a transaction to the pool -func (p *Pool) Push(userTxn *BroadcastedTransaction) error { - err := p.validator(userTxn) - if err != nil { - return err + pool.wg.Add(1) + go pool.dbWriter() + closer := func() error { + close(pool.dbWriteChan) + pool.wg.Wait() + if err := pool.db.Close(); err != nil { + return fmt.Errorf("failed to close mempool database: %v", err) + } + return nil } + return pool, closer, nil +} - err = p.rejectDuplicateTxn(userTxn) - if err != nil { - return err +func (p *Pool) dbWriter() { + defer p.wg.Done() + for { + select { + case txn, ok := <-p.dbWriteChan: + if !ok { + return + } + p.handleTransaction(txn) + } } +} - if err := p.db.Update(func(txn db.Transaction) error { - tail := new(felt.Felt) - if err := p.tailHash(txn, tail); err != nil { +// loadFromDB restores the in-memory transaction pool from the database +func (p *Pool) loadFromDB() error { + return p.db.View(func(txn db.Transaction) error { + len, err := p.LenDB() + if err != nil { + return err + } + if len >= p.maxNumTxns { + return ErrTxnPoolFull + } + headValue := new(felt.Felt) + err = p.headHash(txn, headValue) + if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { + return nil + } + return err + } + + currentHash := headValue + for currentHash != nil { + curElem, err := p.dbElem(txn, currentHash) + if err != nil { + return err + } + + newNode := &storageElem{ + Txn: curElem.Txn, + } + + if curElem.NextHash != nil { + nxtElem, err := p.dbElem(txn, curElem.NextHash) + if err != nil { + return err + } + newNode.Next = &storageElem{ + Txn: nxtElem.Txn, + } + } + + p.txnList.mu.Lock() + if p.txnList.tail != nil { + p.txnList.tail.Next = newNode + p.txnList.tail = newNode + } else { + p.txnList.head = newNode + p.txnList.tail = newNode + } + p.txnList.len++ + p.txnList.mu.Unlock() + + currentHash = curElem.NextHash + } + + return nil + }) +} + +func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { + return p.db.Update(func(dbTxn db.Transaction) error { + tailValue := new(felt.Felt) + if err := p.tailValue(dbTxn, tailValue); err != nil { if !errors.Is(err, db.ErrKeyNotFound) { return err } - tail = nil + tailValue = nil } - if err = p.putElem(txn, userTxn.Transaction.Hash(), &storageElem{ + + if err := p.putdbElem(dbTxn, userTxn.Transaction.Hash(), &storageElem{ Txn: *userTxn, }); err != nil { return err } - if tail != nil { - var oldTail storageElem - oldTail, err = p.elem(txn, tail) + if tailValue != nil { + // Update old tail to point to the new item + var oldTailElem storageElem + oldTailElem, err := p.dbElem(dbTxn, tailValue) if err != nil { return err } - - // update old tail to point to the new item - oldTail.NextHash = userTxn.Transaction.Hash() - if err = p.putElem(txn, tail, &oldTail); err != nil { + oldTailElem.NextHash = userTxn.Transaction.Hash() + if err = p.putdbElem(dbTxn, tailValue, &oldTailElem); err != nil { return err } } else { - // empty list, make new item both the head and the tail - if err = p.updateHead(txn, userTxn.Transaction.Hash()); err != nil { + // Empty list, make new item both the head and the tail + if err := p.updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil { return err } } - if err = p.updateTail(txn, userTxn.Transaction.Hash()); err != nil { + if err := p.updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil { return err } - pLen, err := p.len(txn) + pLen, err := p.lenDB(dbTxn) if err != nil { return err } - return p.updateLen(txn, pLen+1) // don't worry about overflows, highly unlikely - }); err != nil { - return err + return p.updateLen(dbTxn, uint16(pLen+1)) + }) +} + +// Push queues a transaction to the pool and adds it to both the in-memory list and DB +func (p *Pool) Push(userTxn *BroadcastedTransaction) error { + if p.txnList.len >= uint16(p.maxNumTxns) { + return ErrTxnPoolFull } + // todo(rian this PR): validation + + // todo: should db overloading block the in-memory mempool?? + select { + case p.dbWriteChan <- userTxn: + default: + select { + case _, ok := <-p.dbWriteChan: + if !ok { + return errors.New("transaction pool database write channel is closed") + } + return ErrTxnPoolFull + default: + return ErrTxnPoolFull + } + } + + p.txnList.mu.Lock() + newNode := &storageElem{Txn: *userTxn, Next: nil} + if p.txnList.tail != nil { + p.txnList.tail.Next = newNode + p.txnList.tail = newNode + } else { + p.txnList.head = newNode + p.txnList.tail = newNode + } + p.txnList.len++ + p.txnList.mu.Unlock() + select { case p.txPushed <- struct{}{}: default: @@ -117,50 +236,23 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { return nil } -// Pop returns the transaction with the highest priority from the pool +// Pop returns the transaction with the highest priority from the in-memory pool func (p *Pool) Pop() (BroadcastedTransaction, error) { - var nextTxn BroadcastedTransaction - return nextTxn, p.db.Update(func(txn db.Transaction) error { - headHash := new(felt.Felt) - err := p.headHash(txn, headHash) - if err != nil { - return err - } - - headElem, err := p.elem(txn, headHash) - if err != nil { - return err - } - - if err = txn.Delete(headHash.Marshal()); err != nil { - return err - } + p.txnList.mu.Lock() + defer p.txnList.mu.Unlock() - if headElem.NextHash == nil { - // the list is empty now - if err = txn.Delete([]byte(headKey)); err != nil { - return err - } - if err = txn.Delete([]byte(tailKey)); err != nil { - return err - } - } else { - if err = p.updateHead(txn, headElem.NextHash); err != nil { - return err - } - } + if p.txnList.head == nil { + return BroadcastedTransaction{}, errors.New("transaction pool is empty") + } - pLen, err := p.len(txn) - if err != nil { - return err - } + headNode := p.txnList.head + p.txnList.head = headNode.Next + if p.txnList.head == nil { + p.txnList.tail = nil + } + p.txnList.len-- - if err = p.updateLen(txn, pLen-1); err != nil { - return err - } - nextTxn = headElem.Txn - return nil - }) + return headNode.Txn, nil } // Remove removes a set of transactions from the pool @@ -168,24 +260,25 @@ func (p *Pool) Remove(hash ...*felt.Felt) error { return errors.New("not implemented") } -// Len returns the number of transactions in the pool -func (p *Pool) Len() (uint64, error) { - var l uint64 - return l, p.db.View(func(txn db.Transaction) error { - var err error - l, err = p.len(txn) - return err - }) +// Len returns the number of transactions in the in-memory pool +func (p *Pool) Len() uint16 { + return p.txnList.len } -func (p *Pool) Wait() <-chan struct{} { - return p.txPushed +// Len returns the number of transactions in the persistent pool +func (p *Pool) LenDB() (uint16, error) { + txn, err := p.db.NewTransaction(false) + if err != nil { + return 0, err + } + defer txn.Discard() + return p.lenDB(txn) } -func (p *Pool) len(txn db.Transaction) (uint64, error) { - var l uint64 +func (p *Pool) lenDB(txn db.Transaction) (uint16, error) { + var l uint16 err := txn.Get([]byte(poolLengthKey), func(b []byte) error { - l = binary.BigEndian.Uint64(b) + l = binary.BigEndian.Uint16(b) return nil }) @@ -195,8 +288,12 @@ func (p *Pool) len(txn db.Transaction) (uint64, error) { return l, err } -func (p *Pool) updateLen(txn db.Transaction, l uint64) error { - return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint64(nil, l)) +func (p *Pool) updateLen(txn db.Transaction, l uint16) error { + return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint16(nil, l)) +} + +func (p *Pool) Wait() <-chan struct{} { + return p.txPushed } func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { @@ -206,11 +303,24 @@ func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { }) } +func (p *Pool) HeadHash() (*felt.Felt, error) { + txn, err := p.db.NewTransaction(false) + if err != nil { + return nil, err + } + var head *felt.Felt + err = txn.Get([]byte(headKey), func(b []byte) error { + head = new(felt.Felt).SetBytes(b) + return nil + }) + return head, err +} + func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { return txn.Set([]byte(headKey), head.Marshal()) } -func (p *Pool) tailHash(txn db.Transaction, tail *felt.Felt) error { +func (p *Pool) tailValue(txn db.Transaction, tail *felt.Felt) error { return txn.Get([]byte(tailKey), func(b []byte) error { tail.SetBytes(b) return nil @@ -221,7 +331,7 @@ func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { return txn.Set([]byte(tailKey), tail.Marshal()) } -func (p *Pool) elem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { +func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { var item storageElem err := txn.Get(itemKey.Marshal(), func(b []byte) error { return encoder.Unmarshal(b, &item) @@ -229,26 +339,10 @@ func (p *Pool) elem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) return item, err } -func (p *Pool) putElem(txn db.Transaction, itemKey *felt.Felt, item *storageElem) error { +func (p *Pool) putdbElem(txn db.Transaction, itemKey *felt.Felt, item *storageElem) error { itemBytes, err := encoder.Marshal(item) if err != nil { return err } return txn.Set(itemKey.Marshal(), itemBytes) } - -func (p *Pool) rejectDuplicateTxn(userTxn *BroadcastedTransaction) error { - txHash := userTxn.Transaction.Hash().Marshal() - err := p.db.View(func(txn db.Transaction) error { - return txn.Get(txHash, func(val []byte) error { - if val != nil { - return fmt.Errorf("transaction already exists in the mempool: %x", txHash) - } - return nil - }) - }) - if errors.Is(err, db.ErrKeyNotFound) { - return nil - } - return err -} diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 138213bdd7..5838a1f547 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -1,9 +1,9 @@ package mempool_test import ( - "errors" - "fmt" + "os" "testing" + "time" "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" @@ -15,19 +15,42 @@ import ( "github.com/stretchr/testify/require" ) +func setupDatabase(dltExisting bool) (*db.DB, func(), error) { + dbPath := "testmempool" + if _, err := os.Stat(dbPath); err == nil { + if dltExisting { + if err := os.RemoveAll(dbPath); err != nil { + return nil, nil, err + } + } + } else if !os.IsNotExist(err) { + return nil, nil, err + } + db, err := pebble.New(dbPath) + if err != nil { + return nil, nil, err + } + closer := func() { + // The db should be closed by the mempool closer function + os.RemoveAll(dbPath) + } + return &db, closer, nil +} + func TestMempool(t *testing.T) { - testDB := pebble.NewMemTest(t) - pool := mempool.New(testDB) + testDB, dbCloser, err := setupDatabase(true) + require.NoError(t, err) + defer dbCloser() + pool, closer, err := mempool.New(*testDB, 5) + defer closer() + require.NoError(t, err) blockchain.RegisterCoreTypesToEncoder() - t.Run("empty pool", func(t *testing.T) { - l, err := pool.Len() - require.NoError(t, err) - assert.Equal(t, uint64(0), l) + l := pool.Len() + assert.Equal(t, uint16(0), l) - _, err = pool.Pop() - require.ErrorIs(t, err, db.ErrKeyNotFound) - }) + _, err = pool.Pop() + require.Equal(t, err.Error(), "transaction pool is empty") // push multiple to empty (1,2,3) for i := uint64(1); i < 4; i++ { @@ -37,21 +60,18 @@ func TestMempool(t *testing.T) { }, })) - l, err := pool.Len() - require.NoError(t, err) - assert.Equal(t, i, l) + l := pool.Len() + assert.Equal(t, uint16(i), l) } // consume some (remove 1,2, keep 3) for i := uint64(1); i < 3; i++ { txn, err := pool.Pop() - fmt.Println("txn", txn.Transaction.Hash().String()) require.NoError(t, err) assert.Equal(t, i, txn.Transaction.Hash().Uint64()) - l, err := pool.Len() - require.NoError(t, err) - assert.Equal(t, 3-i, l) + l := pool.Len() + assert.Equal(t, uint16(3-i), l) } // push multiple to non empty (push 4,5. now have 3,4,5) @@ -62,44 +82,89 @@ func TestMempool(t *testing.T) { }, })) - l, err := pool.Len() - require.NoError(t, err) - assert.Equal(t, i-2, l) + l := pool.Len() + assert.Equal(t, uint16(i-2), l) } + // push more than max + assert.ErrorIs(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(123), + }, + }), mempool.ErrTxnPoolFull) + // consume all (remove 3,4,5) for i := uint64(3); i < 6; i++ { txn, err := pool.Pop() require.NoError(t, err) assert.Equal(t, i, txn.Transaction.Hash().Uint64()) } + assert.Equal(t, uint16(0), l) + + _, err = pool.Pop() + require.Equal(t, err.Error(), "transaction pool is empty") - l, err := pool.Len() +} + +func TestRestoreMempool(t *testing.T) { + blockchain.RegisterCoreTypesToEncoder() + + testDB, _, err := setupDatabase(true) + require.NoError(t, err) + pool, closer, err := mempool.New(*testDB, 1024) require.NoError(t, err) - assert.Equal(t, uint64(0), l) - _, err = pool.Pop() - require.ErrorIs(t, err, db.ErrKeyNotFound) + // Check both pools are empty + lenDB, err := pool.LenDB() + require.NoError(t, err) + assert.Equal(t, uint16(0), lenDB) + assert.Equal(t, uint16(0), pool.Len()) - // reject duplicate txn - txn := mempool.BroadcastedTransaction{ - Transaction: &core.InvokeTransaction{ - TransactionHash: new(felt.Felt).SetUint64(2), - }, + // push multiple transactions to empty mempool (1,2,3) + for i := uint64(1); i < 4; i++ { + assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(i), + }, + })) + assert.Equal(t, uint16(i), pool.Len()) } - require.NoError(t, pool.Push(&txn)) - require.Error(t, pool.Push(&txn)) - - // validation error - pool = pool.WithValidator(func(bt *mempool.BroadcastedTransaction) error { - return errors.New("some error") - }) - require.EqualError(t, pool.Push(&mempool.BroadcastedTransaction{}), "some error") + + // check the db has stored the transactions + time.Sleep(100 * time.Millisecond) + lenDB, err = pool.LenDB() + require.NoError(t, err) + assert.Equal(t, uint16(3), lenDB) + + // Close the mempool + require.NoError(t, closer()) + testDB, dbCloser, err := setupDatabase(false) + require.NoError(t, err) + defer dbCloser() + + poolRestored, closer2, err := mempool.New(*testDB, 1024) + require.NoError(t, err) + lenDB, err = poolRestored.LenDB() + require.NoError(t, err) + assert.Equal(t, uint16(3), lenDB) + assert.Equal(t, uint16(3), poolRestored.Len()) + + // Remove transactions + _, err = poolRestored.Pop() + require.NoError(t, err) + _, err = poolRestored.Pop() + require.NoError(t, err) + lenDB, err = poolRestored.LenDB() + assert.Equal(t, uint16(3), lenDB) + assert.Equal(t, uint16(1), poolRestored.Len()) + + closer2() } func TestWait(t *testing.T) { testDB := pebble.NewMemTest(t) - pool := mempool.New(testDB) + pool, _, err := mempool.New(testDB, 1024) + require.NoError(t, err) blockchain.RegisterCoreTypesToEncoder() select { From fc2396865113457fdab6e2debf031541ec34f300 Mon Sep 17 00:00:00 2001 From: rian Date: Wed, 8 Jan 2025 14:52:31 +0200 Subject: [PATCH 09/22] add nonce validation + tests --- mempool/init_test.go | 5 ++++ mempool/mempool.go | 53 +++++++++++++++++++++++++++++++++----- mempool/mempool_test.go | 56 +++++++++++++++++++++++++++++++---------- 3 files changed, 95 insertions(+), 19 deletions(-) create mode 100644 mempool/init_test.go diff --git a/mempool/init_test.go b/mempool/init_test.go new file mode 100644 index 0000000000..3104ea09e7 --- /dev/null +++ b/mempool/init_test.go @@ -0,0 +1,5 @@ +package mempool_test + +import ( + _ "github.com/NethermindEth/juno/encoder/registry" +) diff --git a/mempool/mempool.go b/mempool/mempool.go index 1eeb492823..db243f1f10 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -40,7 +40,8 @@ type txnList struct { // Pool stores the transactions in a linked list for its inherent FCFS behaviour type Pool struct { - db db.DB + state core.StateReader + db db.DB // persistent mempool txPushed chan struct{} txnList *txnList // in-memory maxNumTxns uint16 @@ -50,8 +51,9 @@ type Pool struct { // New initializes the Pool and starts the database writer goroutine. // It is the responsibility of the user to call the cancel function if the context is cancelled -func New(db db.DB, maxNumTxns uint16) (*Pool, func() error, error) { +func New(db db.DB, state core.StateReader, maxNumTxns uint16) (*Pool, func() error, error) { pool := &Pool{ + state: state, db: db, // todo: txns should be deleted everytime a new block is stored (builder responsibility) txPushed: make(chan struct{}, 1), txnList: &txnList{}, @@ -195,12 +197,11 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { // Push queues a transaction to the pool and adds it to both the in-memory list and DB func (p *Pool) Push(userTxn *BroadcastedTransaction) error { - if p.txnList.len >= uint16(p.maxNumTxns) { - return ErrTxnPoolFull + err := p.validate(userTxn) + if err != nil { + return err } - // todo(rian this PR): validation - // todo: should db overloading block the in-memory mempool?? select { case p.dbWriteChan <- userTxn: @@ -236,6 +237,44 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { return nil } +func (p *Pool) validate(userTxn *BroadcastedTransaction) error { + if p.txnList.len+1 >= uint16(p.maxNumTxns) { + return ErrTxnPoolFull + } + + switch t := userTxn.Transaction.(type) { + case *core.DeployTransaction: + return fmt.Errorf("deploy transactions are not supported") + case *core.DeployAccountTransaction: + if !t.Nonce.IsZero() { + return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce) + } + case *core.DeclareTransaction: + nonce, err := p.state.ContractNonce(t.SenderAddress) + if err != nil { + return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err) + } + if nonce.Cmp(t.Nonce) > 0 { + return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) + } + case *core.InvokeTransaction: + if t.TxVersion().Is(0) { // cant verify nonce since SenderAddress was only added in v1 + return fmt.Errorf("invoke v0 transactions not supported") + } + nonce, err := p.state.ContractNonce(t.SenderAddress) + if err != nil { + return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err) + } + if nonce.Cmp(t.Nonce) > 0 { + return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) + } + case *core.L1HandlerTransaction: + // todo: verification of the L1 handler nonce requires checking the + // message nonce on the L1 Core Contract. + } + return nil +} + // Pop returns the transaction with the highest priority from the in-memory pool func (p *Pool) Pop() (BroadcastedTransaction, error) { p.txnList.mu.Lock() @@ -256,6 +295,8 @@ func (p *Pool) Pop() (BroadcastedTransaction, error) { } // Remove removes a set of transactions from the pool +// todo: should be called by the builder to remove txns from the db everytime a new block is stored. +// todo: in the consensus+p2p world, the txns should also be removed from the in-memory pool. func (p *Pool) Remove(hash ...*felt.Felt) error { return errors.New("not implemented") } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 5838a1f547..fd8801faeb 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -5,14 +5,15 @@ import ( "testing" "time" - "github.com/NethermindEth/juno/blockchain" "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/mempool" + "github.com/NethermindEth/juno/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) func setupDatabase(dltExisting bool) (*db.DB, func(), error) { @@ -39,12 +40,14 @@ func setupDatabase(dltExisting bool) (*db.DB, func(), error) { func TestMempool(t *testing.T) { testDB, dbCloser, err := setupDatabase(true) + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) require.NoError(t, err) defer dbCloser() - pool, closer, err := mempool.New(*testDB, 5) + pool, closer, err := mempool.New(*testDB, state, 4) defer closer() require.NoError(t, err) - blockchain.RegisterCoreTypesToEncoder() l := pool.Len() assert.Equal(t, uint16(0), l) @@ -54,16 +57,19 @@ func TestMempool(t *testing.T) { // push multiple to empty (1,2,3) for i := uint64(1); i < 4; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: senderAddress, + Version: new(core.TransactionVersion).SetUint64(1), }, })) - l := pool.Len() assert.Equal(t, uint16(i), l) } - // consume some (remove 1,2, keep 3) for i := uint64(1); i < 3; i++ { txn, err := pool.Pop() @@ -76,12 +82,16 @@ func TestMempool(t *testing.T) { // push multiple to non empty (push 4,5. now have 3,4,5) for i := uint64(4); i < 6; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: senderAddress, + Version: new(core.TransactionVersion).SetUint64(1), }, })) - l := pool.Len() assert.Equal(t, uint16(i-2), l) } @@ -103,15 +113,16 @@ func TestMempool(t *testing.T) { _, err = pool.Pop() require.Equal(t, err.Error(), "transaction pool is empty") - } func TestRestoreMempool(t *testing.T) { - blockchain.RegisterCoreTypesToEncoder() - + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) testDB, _, err := setupDatabase(true) require.NoError(t, err) - pool, closer, err := mempool.New(*testDB, 1024) + + pool, closer, err := mempool.New(*testDB, state, 1024) require.NoError(t, err) // Check both pools are empty @@ -122,9 +133,14 @@ func TestRestoreMempool(t *testing.T) { // push multiple transactions to empty mempool (1,2,3) for i := uint64(1); i < 4; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), + Version: new(core.TransactionVersion).SetUint64(1), + SenderAddress: senderAddress, + Nonce: new(felt.Felt).SetUint64(0), }, })) assert.Equal(t, uint16(i), pool.Len()) @@ -142,7 +158,7 @@ func TestRestoreMempool(t *testing.T) { require.NoError(t, err) defer dbCloser() - poolRestored, closer2, err := mempool.New(*testDB, 1024) + poolRestored, closer2, err := mempool.New(*testDB, state, 1024) require.NoError(t, err) lenDB, err = poolRestored.LenDB() require.NoError(t, err) @@ -163,9 +179,11 @@ func TestRestoreMempool(t *testing.T) { func TestWait(t *testing.T) { testDB := pebble.NewMemTest(t) - pool, _, err := mempool.New(testDB, 1024) + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) + pool, _, err := mempool.New(testDB, state, 1024) require.NoError(t, err) - blockchain.RegisterCoreTypesToEncoder() select { case <-pool.Wait(): @@ -174,22 +192,34 @@ func TestWait(t *testing.T) { } // One transaction. + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(1)).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(1), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(1), + Version: new(core.TransactionVersion).SetUint64(1), }, })) <-pool.Wait() // Two transactions. + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(2)).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(2), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(2), + Version: new(core.TransactionVersion).SetUint64(1), }, })) + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(3)).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(3), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(3), + Version: new(core.TransactionVersion).SetUint64(1), }, })) <-pool.Wait() From a28a24f53e2a55e3fb132a5e63fb896c6d236bd4 Mon Sep 17 00:00:00 2001 From: rian Date: Wed, 8 Jan 2025 16:06:45 +0200 Subject: [PATCH 10/22] implement buckets for mempool db --- mempool/bucket.go | 21 +++++++++++++++++++++ mempool/mempool.go | 26 +++++++++++--------------- 2 files changed, 32 insertions(+), 15 deletions(-) create mode 100644 mempool/bucket.go diff --git a/mempool/bucket.go b/mempool/bucket.go new file mode 100644 index 0000000000..9d9ef10c01 --- /dev/null +++ b/mempool/bucket.go @@ -0,0 +1,21 @@ +package mempool + +import "slices" + +//go:generate go run github.com/dmarkham/enumer -type=Bucket -output=buckets_enumer.go +type Bucket byte + +// Pebble does not support buckets to differentiate between groups of +// keys like Bolt or MDBX does. We use a global prefix list as a poor +// man's bucket alternative. +const ( + Head Bucket = iota // key of the head node + Tail // key of the tail node + Length // number of transactions + Node +) + +// Key flattens a prefix and series of byte arrays into a single []byte. +func (b Bucket) Key(key ...[]byte) []byte { + return append([]byte{byte(b)}, slices.Concat(key...)...) +} diff --git a/mempool/mempool.go b/mempool/mempool.go index db243f1f10..738b471d70 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -12,12 +12,6 @@ import ( "github.com/NethermindEth/juno/encoder" ) -const ( - poolLengthKey = "poolLength" - headKey = "headKey" - tailKey = "tailKey" -) - var ErrTxnPoolFull = errors.New("transaction pool is full") type storageElem struct { @@ -318,7 +312,7 @@ func (p *Pool) LenDB() (uint16, error) { func (p *Pool) lenDB(txn db.Transaction) (uint16, error) { var l uint16 - err := txn.Get([]byte(poolLengthKey), func(b []byte) error { + err := txn.Get(Length.Key(), func(b []byte) error { l = binary.BigEndian.Uint16(b) return nil }) @@ -330,7 +324,7 @@ func (p *Pool) lenDB(txn db.Transaction) (uint16, error) { } func (p *Pool) updateLen(txn db.Transaction, l uint16) error { - return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint16(nil, l)) + return txn.Set(Length.Key(), binary.BigEndian.AppendUint16(nil, l)) } func (p *Pool) Wait() <-chan struct{} { @@ -338,7 +332,7 @@ func (p *Pool) Wait() <-chan struct{} { } func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { - return txn.Get([]byte(headKey), func(b []byte) error { + return txn.Get(Head.Key(), func(b []byte) error { head.SetBytes(b) return nil }) @@ -350,7 +344,7 @@ func (p *Pool) HeadHash() (*felt.Felt, error) { return nil, err } var head *felt.Felt - err = txn.Get([]byte(headKey), func(b []byte) error { + err = txn.Get(Head.Key(), func(b []byte) error { head = new(felt.Felt).SetBytes(b) return nil }) @@ -358,23 +352,24 @@ func (p *Pool) HeadHash() (*felt.Felt, error) { } func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { - return txn.Set([]byte(headKey), head.Marshal()) + return txn.Set(Head.Key(), head.Marshal()) } func (p *Pool) tailValue(txn db.Transaction, tail *felt.Felt) error { - return txn.Get([]byte(tailKey), func(b []byte) error { + return txn.Get(Tail.Key(), func(b []byte) error { tail.SetBytes(b) return nil }) } func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { - return txn.Set([]byte(tailKey), tail.Marshal()) + return txn.Set(Tail.Key(), tail.Marshal()) } func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { var item storageElem - err := txn.Get(itemKey.Marshal(), func(b []byte) error { + keyBytes := itemKey.Bytes() + err := txn.Get(Node.Key(keyBytes[:]), func(b []byte) error { return encoder.Unmarshal(b, &item) }) return item, err @@ -385,5 +380,6 @@ func (p *Pool) putdbElem(txn db.Transaction, itemKey *felt.Felt, item *storageEl if err != nil { return err } - return txn.Set(itemKey.Marshal(), itemBytes) + keyBytes := itemKey.Bytes() + return txn.Set(Node.Key(keyBytes[:]), itemBytes) } From 4f752fe90b0d2fdcb3e280857b701cfb2734761b Mon Sep 17 00:00:00 2001 From: rian Date: Thu, 9 Jan 2025 10:46:29 +0200 Subject: [PATCH 11/22] fix lint + tests --- mempool/init_test.go | 5 --- mempool/mempool.go | 69 +++++++++++++++--------------------- mempool/mempool_test.go | 78 ++++++++++++++++++++++------------------- 3 files changed, 70 insertions(+), 82 deletions(-) delete mode 100644 mempool/init_test.go diff --git a/mempool/init_test.go b/mempool/init_test.go deleted file mode 100644 index 3104ea09e7..0000000000 --- a/mempool/init_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package mempool_test - -import ( - _ "github.com/NethermindEth/juno/encoder/registry" -) diff --git a/mempool/mempool.go b/mempool/mempool.go index 738b471d70..39fdd02d3b 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/encoder" + "github.com/NethermindEth/juno/utils" ) var ErrTxnPoolFull = errors.New("transaction pool is full") @@ -34,6 +35,7 @@ type txnList struct { // Pool stores the transactions in a linked list for its inherent FCFS behaviour type Pool struct { + log utils.SimpleLogger state core.StateReader db db.DB // persistent mempool txPushed chan struct{} @@ -43,12 +45,13 @@ type Pool struct { wg sync.WaitGroup } -// New initializes the Pool and starts the database writer goroutine. +// New initialises the Pool and starts the database writer goroutine. // It is the responsibility of the user to call the cancel function if the context is cancelled -func New(db db.DB, state core.StateReader, maxNumTxns uint16) (*Pool, func() error, error) { +func New(persistentPool db.DB, state core.StateReader, maxNumTxns uint16, log utils.SimpleLogger) (*Pool, func() error, error) { pool := &Pool{ + log: log, state: state, - db: db, // todo: txns should be deleted everytime a new block is stored (builder responsibility) + db: persistentPool, // todo: txns should be deleted everytime a new block is stored (builder responsibility) txPushed: make(chan struct{}, 1), txnList: &txnList{}, maxNumTxns: maxNumTxns, @@ -56,7 +59,7 @@ func New(db db.DB, state core.StateReader, maxNumTxns uint16) (*Pool, func() err } if err := pool.loadFromDB(); err != nil { - return nil, nil, fmt.Errorf("failed to load transactions from database into the in-memory transaction list: %v\n", err) + return nil, nil, fmt.Errorf("failed to load transactions from database into the in-memory transaction list: %v", err) } pool.wg.Add(1) @@ -74,25 +77,20 @@ func New(db db.DB, state core.StateReader, maxNumTxns uint16) (*Pool, func() err func (p *Pool) dbWriter() { defer p.wg.Done() - for { - select { - case txn, ok := <-p.dbWriteChan: - if !ok { - return - } - p.handleTransaction(txn) - } + for txn := range p.dbWriteChan { + err := p.handleTransaction(txn) + p.log.Errorw("error in handling user transaction in persistent mempool", "err", err) } } // loadFromDB restores the in-memory transaction pool from the database func (p *Pool) loadFromDB() error { return p.db.View(func(txn db.Transaction) error { - len, err := p.LenDB() + lenDB, err := p.LenDB() if err != nil { return err } - if len >= p.maxNumTxns { + if lenDB >= p.maxNumTxns { return ErrTxnPoolFull } headValue := new(felt.Felt) @@ -152,13 +150,11 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { } tailValue = nil } - if err := p.putdbElem(dbTxn, userTxn.Transaction.Hash(), &storageElem{ Txn: *userTxn, }); err != nil { return err } - if tailValue != nil { // Update old tail to point to the new item var oldTailElem storageElem @@ -176,16 +172,14 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { return err } } - if err := p.updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil { return err } - pLen, err := p.lenDB(dbTxn) if err != nil { return err } - return p.updateLen(dbTxn, uint16(pLen+1)) + return p.updateLen(dbTxn, pLen+1) }) } @@ -196,18 +190,17 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { return err } - // todo: should db overloading block the in-memory mempool?? select { case p.dbWriteChan <- userTxn: default: select { case _, ok := <-p.dbWriteChan: if !ok { - return errors.New("transaction pool database write channel is closed") + p.log.Errorw("cannot store user transasction in persistent pool, database write channel is closed") } - return ErrTxnPoolFull + p.log.Errorw("cannot store user transasction in persistent pool, database is full") default: - return ErrTxnPoolFull + p.log.Errorw("cannot store user transasction in persistent pool, database is full") } } @@ -232,7 +225,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { } func (p *Pool) validate(userTxn *BroadcastedTransaction) error { - if p.txnList.len+1 >= uint16(p.maxNumTxns) { + if p.txnList.len+1 >= p.maxNumTxns { return ErrTxnPoolFull } @@ -246,7 +239,7 @@ func (p *Pool) validate(userTxn *BroadcastedTransaction) error { case *core.DeclareTransaction: nonce, err := p.state.ContractNonce(t.SenderAddress) if err != nil { - return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err) + return fmt.Errorf("validation failed, error when retrieving nonce, %v", err) } if nonce.Cmp(t.Nonce) > 0 { return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) @@ -257,7 +250,7 @@ func (p *Pool) validate(userTxn *BroadcastedTransaction) error { } nonce, err := p.state.ContractNonce(t.SenderAddress) if err != nil { - return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err) + return fmt.Errorf("validation failed, error when retrieving nonce, %v", err) } if nonce.Cmp(t.Nonce) > 0 { return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) @@ -302,12 +295,17 @@ func (p *Pool) Len() uint16 { // Len returns the number of transactions in the persistent pool func (p *Pool) LenDB() (uint16, error) { + p.wg.Add(1) + defer p.wg.Done() txn, err := p.db.NewTransaction(false) if err != nil { return 0, err } - defer txn.Discard() - return p.lenDB(txn) + lenDB, err := p.lenDB(txn) + if err != nil { + return 0, err + } + return lenDB, txn.Discard() } func (p *Pool) lenDB(txn db.Transaction) (uint16, error) { @@ -338,19 +336,6 @@ func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { }) } -func (p *Pool) HeadHash() (*felt.Felt, error) { - txn, err := p.db.NewTransaction(false) - if err != nil { - return nil, err - } - var head *felt.Felt - err = txn.Get(Head.Key(), func(b []byte) error { - head = new(felt.Felt).SetBytes(b) - return nil - }) - return head, err -} - func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { return txn.Set(Head.Key(), head.Marshal()) } @@ -366,6 +351,8 @@ func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { return txn.Set(Tail.Key(), tail.Marshal()) } +// todo : error when unmarshalling the core.Transasction... +// but unmarshalling core.Transaction works fine in TransactionsByBlockNumber... func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { var item storageElem keyBytes := itemKey.Bytes() diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index fd8801faeb..72c6d0d9dd 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -9,15 +9,15 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" + _ "github.com/NethermindEth/juno/encoder/registry" "github.com/NethermindEth/juno/mempool" "github.com/NethermindEth/juno/mocks" - "github.com/stretchr/testify/assert" + "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -func setupDatabase(dltExisting bool) (*db.DB, func(), error) { - dbPath := "testmempool" +func setupDatabase(dbPath string, dltExisting bool) (db.DB, func(), error) { if _, err := os.Stat(dbPath); err == nil { if dltExisting { if err := os.RemoveAll(dbPath); err != nil { @@ -27,7 +27,7 @@ func setupDatabase(dltExisting bool) (*db.DB, func(), error) { } else if !os.IsNotExist(err) { return nil, nil, err } - db, err := pebble.New(dbPath) + persistentPool, err := pebble.New(dbPath) if err != nil { return nil, nil, err } @@ -35,22 +35,22 @@ func setupDatabase(dltExisting bool) (*db.DB, func(), error) { // The db should be closed by the mempool closer function os.RemoveAll(dbPath) } - return &db, closer, nil + return persistentPool, closer, nil } func TestMempool(t *testing.T) { - testDB, dbCloser, err := setupDatabase(true) + testDB, dbCloser, err := setupDatabase("testmempool", true) + log := utils.NewNopZapLogger() mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) state := mocks.NewMockStateHistoryReader(mockCtrl) require.NoError(t, err) defer dbCloser() - pool, closer, err := mempool.New(*testDB, state, 4) - defer closer() + pool, closer, err := mempool.New(testDB, state, 4, log) require.NoError(t, err) l := pool.Len() - assert.Equal(t, uint16(0), l) + require.Equal(t, uint16(0), l) _, err = pool.Pop() require.Equal(t, err.Error(), "transaction pool is empty") @@ -59,7 +59,7 @@ func TestMempool(t *testing.T) { for i := uint64(1); i < 4; i++ { senderAddress := new(felt.Felt).SetUint64(i) state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) - assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), Nonce: new(felt.Felt).SetUint64(1), @@ -68,23 +68,23 @@ func TestMempool(t *testing.T) { }, })) l := pool.Len() - assert.Equal(t, uint16(i), l) + require.Equal(t, uint16(i), l) } // consume some (remove 1,2, keep 3) for i := uint64(1); i < 3; i++ { txn, err := pool.Pop() require.NoError(t, err) - assert.Equal(t, i, txn.Transaction.Hash().Uint64()) + require.Equal(t, i, txn.Transaction.Hash().Uint64()) l := pool.Len() - assert.Equal(t, uint16(3-i), l) + require.Equal(t, uint16(3-i), l) } // push multiple to non empty (push 4,5. now have 3,4,5) for i := uint64(4); i < 6; i++ { senderAddress := new(felt.Felt).SetUint64(i) state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) - assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), Nonce: new(felt.Felt).SetUint64(1), @@ -93,11 +93,11 @@ func TestMempool(t *testing.T) { }, })) l := pool.Len() - assert.Equal(t, uint16(i-2), l) + require.Equal(t, uint16(i-2), l) } // push more than max - assert.ErrorIs(t, pool.Push(&mempool.BroadcastedTransaction{ + require.ErrorIs(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(123), }, @@ -107,35 +107,39 @@ func TestMempool(t *testing.T) { for i := uint64(3); i < 6; i++ { txn, err := pool.Pop() require.NoError(t, err) - assert.Equal(t, i, txn.Transaction.Hash().Uint64()) + require.Equal(t, i, txn.Transaction.Hash().Uint64()) } - assert.Equal(t, uint16(0), l) + require.Equal(t, uint16(0), l) _, err = pool.Pop() require.Equal(t, err.Error(), "transaction pool is empty") + require.NoError(t, closer()) } func TestRestoreMempool(t *testing.T) { + log := utils.NewNopZapLogger() + mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) state := mocks.NewMockStateHistoryReader(mockCtrl) - testDB, _, err := setupDatabase(true) + testDB, dbCloser, err := setupDatabase("testrestoremempool", true) require.NoError(t, err) + defer dbCloser() - pool, closer, err := mempool.New(*testDB, state, 1024) + pool, closer, err := mempool.New(testDB, state, 1024, log) require.NoError(t, err) // Check both pools are empty lenDB, err := pool.LenDB() require.NoError(t, err) - assert.Equal(t, uint16(0), lenDB) - assert.Equal(t, uint16(0), pool.Len()) + require.Equal(t, uint16(0), lenDB) + require.Equal(t, uint16(0), pool.Len()) // push multiple transactions to empty mempool (1,2,3) for i := uint64(1); i < 4; i++ { senderAddress := new(felt.Felt).SetUint64(i) state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) - assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), Version: new(core.TransactionVersion).SetUint64(1), @@ -143,27 +147,25 @@ func TestRestoreMempool(t *testing.T) { Nonce: new(felt.Felt).SetUint64(0), }, })) - assert.Equal(t, uint16(i), pool.Len()) + require.Equal(t, uint16(i), pool.Len()) } // check the db has stored the transactions time.Sleep(100 * time.Millisecond) lenDB, err = pool.LenDB() require.NoError(t, err) - assert.Equal(t, uint16(3), lenDB) - + require.Equal(t, uint16(3), lenDB) // Close the mempool require.NoError(t, closer()) - testDB, dbCloser, err := setupDatabase(false) + testDB, _, err = setupDatabase("testrestoremempool", false) require.NoError(t, err) - defer dbCloser() - poolRestored, closer2, err := mempool.New(*testDB, state, 1024) + poolRestored, closer2, err := mempool.New(testDB, state, 1024, log) require.NoError(t, err) lenDB, err = poolRestored.LenDB() require.NoError(t, err) - assert.Equal(t, uint16(3), lenDB) - assert.Equal(t, uint16(3), poolRestored.Len()) + require.Equal(t, uint16(3), lenDB) + require.Equal(t, uint16(3), poolRestored.Len()) // Remove transactions _, err = poolRestored.Pop() @@ -171,18 +173,22 @@ func TestRestoreMempool(t *testing.T) { _, err = poolRestored.Pop() require.NoError(t, err) lenDB, err = poolRestored.LenDB() - assert.Equal(t, uint16(3), lenDB) - assert.Equal(t, uint16(1), poolRestored.Len()) + require.NoError(t, err) + require.Equal(t, uint16(3), lenDB) + require.Equal(t, uint16(1), poolRestored.Len()) - closer2() + require.NoError(t, closer2()) } func TestWait(t *testing.T) { - testDB := pebble.NewMemTest(t) + log := utils.NewNopZapLogger() + testDB, dbCloser, err := setupDatabase("testwait", true) + require.NoError(t, err) + defer dbCloser() mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) state := mocks.NewMockStateHistoryReader(mockCtrl) - pool, _, err := mempool.New(testDB, state, 1024) + pool, _, err := mempool.New(testDB, state, 1024, log) require.NoError(t, err) select { From 21aae71766fd12f8ac95d30fe3454179b164c97b Mon Sep 17 00:00:00 2001 From: rian Date: Thu, 9 Jan 2025 16:11:07 +0200 Subject: [PATCH 12/22] comment: len to int --- mempool/mempool.go | 31 ++++++++++++------------------- mempool/mempool_test.go | 31 +++++++++++++------------------ 2 files changed, 25 insertions(+), 37 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 39fdd02d3b..9b3a42360e 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -1,9 +1,9 @@ package mempool import ( - "encoding/binary" "errors" "fmt" + "math/big" "sync" "github.com/NethermindEth/juno/core" @@ -29,7 +29,7 @@ type BroadcastedTransaction struct { type txnList struct { head *storageElem tail *storageElem - len uint16 + len int mu sync.Mutex } @@ -40,14 +40,14 @@ type Pool struct { db db.DB // persistent mempool txPushed chan struct{} txnList *txnList // in-memory - maxNumTxns uint16 + maxNumTxns int dbWriteChan chan *BroadcastedTransaction wg sync.WaitGroup } // New initialises the Pool and starts the database writer goroutine. // It is the responsibility of the user to call the cancel function if the context is cancelled -func New(persistentPool db.DB, state core.StateReader, maxNumTxns uint16, log utils.SimpleLogger) (*Pool, func() error, error) { +func New(persistentPool db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleLogger) (*Pool, func() error, error) { pool := &Pool{ log: log, state: state, @@ -86,15 +86,8 @@ func (p *Pool) dbWriter() { // loadFromDB restores the in-memory transaction pool from the database func (p *Pool) loadFromDB() error { return p.db.View(func(txn db.Transaction) error { - lenDB, err := p.LenDB() - if err != nil { - return err - } - if lenDB >= p.maxNumTxns { - return ErrTxnPoolFull - } headValue := new(felt.Felt) - err = p.headHash(txn, headValue) + err := p.headHash(txn, headValue) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { return nil @@ -289,12 +282,12 @@ func (p *Pool) Remove(hash ...*felt.Felt) error { } // Len returns the number of transactions in the in-memory pool -func (p *Pool) Len() uint16 { +func (p *Pool) Len() int { return p.txnList.len } // Len returns the number of transactions in the persistent pool -func (p *Pool) LenDB() (uint16, error) { +func (p *Pool) LenDB() (int, error) { p.wg.Add(1) defer p.wg.Done() txn, err := p.db.NewTransaction(false) @@ -308,10 +301,10 @@ func (p *Pool) LenDB() (uint16, error) { return lenDB, txn.Discard() } -func (p *Pool) lenDB(txn db.Transaction) (uint16, error) { - var l uint16 +func (p *Pool) lenDB(txn db.Transaction) (int, error) { + var l int err := txn.Get(Length.Key(), func(b []byte) error { - l = binary.BigEndian.Uint16(b) + l = int(new(big.Int).SetBytes(b).Int64()) return nil }) @@ -321,8 +314,8 @@ func (p *Pool) lenDB(txn db.Transaction) (uint16, error) { return l, err } -func (p *Pool) updateLen(txn db.Transaction, l uint16) error { - return txn.Set(Length.Key(), binary.BigEndian.AppendUint16(nil, l)) +func (p *Pool) updateLen(txn db.Transaction, l int) error { + return txn.Set(Length.Key(), new(big.Int).SetInt64(int64(l)).Bytes()) } func (p *Pool) Wait() <-chan struct{} { diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 72c6d0d9dd..d67869fd88 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -49,8 +49,7 @@ func TestMempool(t *testing.T) { pool, closer, err := mempool.New(testDB, state, 4, log) require.NoError(t, err) - l := pool.Len() - require.Equal(t, uint16(0), l) + require.Equal(t, 0, pool.Len()) _, err = pool.Pop() require.Equal(t, err.Error(), "transaction pool is empty") @@ -67,17 +66,14 @@ func TestMempool(t *testing.T) { Version: new(core.TransactionVersion).SetUint64(1), }, })) - l := pool.Len() - require.Equal(t, uint16(i), l) + require.Equal(t, int(i), pool.Len()) } // consume some (remove 1,2, keep 3) for i := uint64(1); i < 3; i++ { txn, err := pool.Pop() require.NoError(t, err) require.Equal(t, i, txn.Transaction.Hash().Uint64()) - - l := pool.Len() - require.Equal(t, uint16(3-i), l) + require.Equal(t, int(3-i), pool.Len()) } // push multiple to non empty (push 4,5. now have 3,4,5) @@ -92,8 +88,7 @@ func TestMempool(t *testing.T) { Version: new(core.TransactionVersion).SetUint64(1), }, })) - l := pool.Len() - require.Equal(t, uint16(i-2), l) + require.Equal(t, int(i-2), pool.Len()) } // push more than max @@ -109,7 +104,7 @@ func TestMempool(t *testing.T) { require.NoError(t, err) require.Equal(t, i, txn.Transaction.Hash().Uint64()) } - require.Equal(t, uint16(0), l) + require.Equal(t, 0, pool.Len()) _, err = pool.Pop() require.Equal(t, err.Error(), "transaction pool is empty") @@ -132,8 +127,8 @@ func TestRestoreMempool(t *testing.T) { // Check both pools are empty lenDB, err := pool.LenDB() require.NoError(t, err) - require.Equal(t, uint16(0), lenDB) - require.Equal(t, uint16(0), pool.Len()) + require.Equal(t, 0, lenDB) + require.Equal(t, 0, pool.Len()) // push multiple transactions to empty mempool (1,2,3) for i := uint64(1); i < 4; i++ { @@ -147,14 +142,14 @@ func TestRestoreMempool(t *testing.T) { Nonce: new(felt.Felt).SetUint64(0), }, })) - require.Equal(t, uint16(i), pool.Len()) + require.Equal(t, int(i), pool.Len()) } // check the db has stored the transactions time.Sleep(100 * time.Millisecond) lenDB, err = pool.LenDB() require.NoError(t, err) - require.Equal(t, uint16(3), lenDB) + require.Equal(t, 3, lenDB) // Close the mempool require.NoError(t, closer()) testDB, _, err = setupDatabase("testrestoremempool", false) @@ -164,8 +159,8 @@ func TestRestoreMempool(t *testing.T) { require.NoError(t, err) lenDB, err = poolRestored.LenDB() require.NoError(t, err) - require.Equal(t, uint16(3), lenDB) - require.Equal(t, uint16(3), poolRestored.Len()) + require.Equal(t, 3, lenDB) + require.Equal(t, 3, poolRestored.Len()) // Remove transactions _, err = poolRestored.Pop() @@ -174,8 +169,8 @@ func TestRestoreMempool(t *testing.T) { require.NoError(t, err) lenDB, err = poolRestored.LenDB() require.NoError(t, err) - require.Equal(t, uint16(3), lenDB) - require.Equal(t, uint16(1), poolRestored.Len()) + require.Equal(t, 3, lenDB) + require.Equal(t, 1, poolRestored.Len()) require.NoError(t, closer2()) } From 69ff0641459964e6aa4800e7f1dff8e7056db7ec Mon Sep 17 00:00:00 2001 From: rian Date: Thu, 9 Jan 2025 16:30:17 +0200 Subject: [PATCH 13/22] comment: store persistent mempool txns in the main db --- db/buckets.go | 4 ++++ mempool/bucket.go | 21 --------------------- mempool/mempool.go | 22 +++++++++++----------- mempool/mempool_test.go | 4 ++-- 4 files changed, 17 insertions(+), 34 deletions(-) delete mode 100644 mempool/bucket.go diff --git a/db/buckets.go b/db/buckets.go index 2773773f5a..e5037378a3 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -34,6 +34,10 @@ const ( Temporary // used temporarily for migrations SchemaIntermediateState L1HandlerTxnHashByMsgHash // maps l1 handler msg hash to l1 handler txn hash + MempoolHead // key of the head node + MempoolTail // key of the tail node + MempoolLength // number of transactions + MempoolNode ) // Key flattens a prefix and series of byte arrays into a single []byte. diff --git a/mempool/bucket.go b/mempool/bucket.go deleted file mode 100644 index 9d9ef10c01..0000000000 --- a/mempool/bucket.go +++ /dev/null @@ -1,21 +0,0 @@ -package mempool - -import "slices" - -//go:generate go run github.com/dmarkham/enumer -type=Bucket -output=buckets_enumer.go -type Bucket byte - -// Pebble does not support buckets to differentiate between groups of -// keys like Bolt or MDBX does. We use a global prefix list as a poor -// man's bucket alternative. -const ( - Head Bucket = iota // key of the head node - Tail // key of the tail node - Length // number of transactions - Node -) - -// Key flattens a prefix and series of byte arrays into a single []byte. -func (b Bucket) Key(key ...[]byte) []byte { - return append([]byte{byte(b)}, slices.Concat(key...)...) -} diff --git a/mempool/mempool.go b/mempool/mempool.go index 9b3a42360e..9aec92abff 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -37,7 +37,7 @@ type txnList struct { type Pool struct { log utils.SimpleLogger state core.StateReader - db db.DB // persistent mempool + db db.DB // to store the persistent mempool txPushed chan struct{} txnList *txnList // in-memory maxNumTxns int @@ -47,11 +47,11 @@ type Pool struct { // New initialises the Pool and starts the database writer goroutine. // It is the responsibility of the user to call the cancel function if the context is cancelled -func New(persistentPool db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleLogger) (*Pool, func() error, error) { +func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleLogger) (*Pool, func() error, error) { pool := &Pool{ log: log, state: state, - db: persistentPool, // todo: txns should be deleted everytime a new block is stored (builder responsibility) + db: mainDB, // todo: txns should be deleted everytime a new block is stored (builder responsibility) txPushed: make(chan struct{}, 1), txnList: &txnList{}, maxNumTxns: maxNumTxns, @@ -303,7 +303,7 @@ func (p *Pool) LenDB() (int, error) { func (p *Pool) lenDB(txn db.Transaction) (int, error) { var l int - err := txn.Get(Length.Key(), func(b []byte) error { + err := txn.Get(db.MempoolLength.Key(), func(b []byte) error { l = int(new(big.Int).SetBytes(b).Int64()) return nil }) @@ -315,7 +315,7 @@ func (p *Pool) lenDB(txn db.Transaction) (int, error) { } func (p *Pool) updateLen(txn db.Transaction, l int) error { - return txn.Set(Length.Key(), new(big.Int).SetInt64(int64(l)).Bytes()) + return txn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(l)).Bytes()) } func (p *Pool) Wait() <-chan struct{} { @@ -323,25 +323,25 @@ func (p *Pool) Wait() <-chan struct{} { } func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { - return txn.Get(Head.Key(), func(b []byte) error { + return txn.Get(db.MempoolHead.Key(), func(b []byte) error { head.SetBytes(b) return nil }) } func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { - return txn.Set(Head.Key(), head.Marshal()) + return txn.Set(db.MempoolHead.Key(), head.Marshal()) } func (p *Pool) tailValue(txn db.Transaction, tail *felt.Felt) error { - return txn.Get(Tail.Key(), func(b []byte) error { + return txn.Get(db.MempoolTail.Key(), func(b []byte) error { tail.SetBytes(b) return nil }) } func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { - return txn.Set(Tail.Key(), tail.Marshal()) + return txn.Set(db.MempoolTail.Key(), tail.Marshal()) } // todo : error when unmarshalling the core.Transasction... @@ -349,7 +349,7 @@ func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { var item storageElem keyBytes := itemKey.Bytes() - err := txn.Get(Node.Key(keyBytes[:]), func(b []byte) error { + err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { return encoder.Unmarshal(b, &item) }) return item, err @@ -361,5 +361,5 @@ func (p *Pool) putdbElem(txn db.Transaction, itemKey *felt.Felt, item *storageEl return err } keyBytes := itemKey.Bytes() - return txn.Set(Node.Key(keyBytes[:]), itemBytes) + return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes) } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index d67869fd88..a61de88d03 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -55,7 +55,7 @@ func TestMempool(t *testing.T) { require.Equal(t, err.Error(), "transaction pool is empty") // push multiple to empty (1,2,3) - for i := uint64(1); i < 4; i++ { + for i := uint64(1); i < 4; i++ { //nolint:dupl senderAddress := new(felt.Felt).SetUint64(i) state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ @@ -131,7 +131,7 @@ func TestRestoreMempool(t *testing.T) { require.Equal(t, 0, pool.Len()) // push multiple transactions to empty mempool (1,2,3) - for i := uint64(1); i < 4; i++ { + for i := uint64(1); i < 4; i++ { //nolint:dupl senderAddress := new(felt.Felt).SetUint64(i) state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ From aa93b4b99c86057aebe24a07abc641db1f7769dc Mon Sep 17 00:00:00 2001 From: rian Date: Fri, 10 Jan 2025 14:41:25 +0200 Subject: [PATCH 14/22] comments - docstrings --- mempool/mempool.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 9aec92abff..668e655255 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -15,6 +15,8 @@ import ( var ErrTxnPoolFull = errors.New("transaction pool is full") +// storageElem defines a node for both the +// in-memory and persistent linked-list type storageElem struct { Txn BroadcastedTransaction NextHash *felt.Felt // persistent db @@ -26,6 +28,7 @@ type BroadcastedTransaction struct { DeclaredClass core.Class } +// txnList is the in-memory mempool type txnList struct { head *storageElem tail *storageElem @@ -33,7 +36,8 @@ type txnList struct { mu sync.Mutex } -// Pool stores the transactions in a linked list for its inherent FCFS behaviour +// Pool represents a blockchain mempool, managing transactions using both an +// in-memory and persistent database. type Pool struct { log utils.SimpleLogger state core.StateReader @@ -134,6 +138,7 @@ func (p *Pool) loadFromDB() error { }) } +// handleTransaction adds the transaction to the persistent linked-list db func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { return p.db.Update(func(dbTxn db.Transaction) error { tailValue := new(felt.Felt) @@ -176,7 +181,7 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { }) } -// Push queues a transaction to the pool and adds it to both the in-memory list and DB +// Push queues a transaction to the pool func (p *Pool) Push(userTxn *BroadcastedTransaction) error { err := p.validate(userTxn) if err != nil { @@ -344,8 +349,6 @@ func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { return txn.Set(db.MempoolTail.Key(), tail.Marshal()) } -// todo : error when unmarshalling the core.Transasction... -// but unmarshalling core.Transaction works fine in TransactionsByBlockNumber... func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { var item storageElem keyBytes := itemKey.Bytes() From 39c22573f8ae1305535d16d05bef98e274033e1c Mon Sep 17 00:00:00 2001 From: rian Date: Fri, 10 Jan 2025 14:59:51 +0200 Subject: [PATCH 15/22] comment: update New function signature --- mempool/mempool.go | 31 ++++++++++++++----------------- mempool/mempool_test.go | 26 +++++++++++++++----------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 668e655255..1c07d5e7d0 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -50,8 +50,8 @@ type Pool struct { } // New initialises the Pool and starts the database writer goroutine. -// It is the responsibility of the user to call the cancel function if the context is cancelled -func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleLogger) (*Pool, func() error, error) { +// It is the responsibility of the caller to execute the closer function. +func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleLogger) (*Pool, func() error) { pool := &Pool{ log: log, state: state, @@ -61,13 +61,6 @@ func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleL maxNumTxns: maxNumTxns, dbWriteChan: make(chan *BroadcastedTransaction, maxNumTxns), } - - if err := pool.loadFromDB(); err != nil { - return nil, nil, fmt.Errorf("failed to load transactions from database into the in-memory transaction list: %v", err) - } - - pool.wg.Add(1) - go pool.dbWriter() closer := func() error { close(pool.dbWriteChan) pool.wg.Wait() @@ -76,19 +69,23 @@ func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleL } return nil } - return pool, closer, nil + pool.dbWriter() + return pool, closer } func (p *Pool) dbWriter() { - defer p.wg.Done() - for txn := range p.dbWriteChan { - err := p.handleTransaction(txn) - p.log.Errorw("error in handling user transaction in persistent mempool", "err", err) - } + p.wg.Add(1) + go func() { + defer p.wg.Done() + for txn := range p.dbWriteChan { + err := p.handleTransaction(txn) + p.log.Errorw("error in handling user transaction in persistent mempool", "err", err) + } + }() } -// loadFromDB restores the in-memory transaction pool from the database -func (p *Pool) loadFromDB() error { +// LoadFromDB restores the in-memory transaction pool from the database +func (p *Pool) LoadFromDB() error { return p.db.View(func(txn db.Transaction) error { headValue := new(felt.Felt) err := p.headHash(txn, headValue) diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index a61de88d03..497d37ed46 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -1,6 +1,7 @@ package mempool_test import ( + "fmt" "os" "testing" "time" @@ -46,8 +47,8 @@ func TestMempool(t *testing.T) { state := mocks.NewMockStateHistoryReader(mockCtrl) require.NoError(t, err) defer dbCloser() - pool, closer, err := mempool.New(testDB, state, 4, log) - require.NoError(t, err) + pool, closer := mempool.New(testDB, state, 4, log) + require.NoError(t, pool.LoadFromDB()) require.Equal(t, 0, pool.Len()) @@ -121,9 +122,10 @@ func TestRestoreMempool(t *testing.T) { require.NoError(t, err) defer dbCloser() - pool, closer, err := mempool.New(testDB, state, 1024, log) - require.NoError(t, err) - + pool, closer := mempool.New(testDB, state, 1024, log) + fmt.Println("============") + require.NoError(t, pool.LoadFromDB()) + fmt.Println("============") // Check both pools are empty lenDB, err := pool.LenDB() require.NoError(t, err) @@ -144,7 +146,7 @@ func TestRestoreMempool(t *testing.T) { })) require.Equal(t, int(i), pool.Len()) } - + fmt.Println("============") // check the db has stored the transactions time.Sleep(100 * time.Millisecond) lenDB, err = pool.LenDB() @@ -152,11 +154,13 @@ func TestRestoreMempool(t *testing.T) { require.Equal(t, 3, lenDB) // Close the mempool require.NoError(t, closer()) + testDB, _, err = setupDatabase("testrestoremempool", false) require.NoError(t, err) - poolRestored, closer2, err := mempool.New(testDB, state, 1024, log) - require.NoError(t, err) + poolRestored, closer2 := mempool.New(testDB, state, 1024, log) + time.Sleep(100 * time.Millisecond) + require.NoError(t, pool.LoadFromDB()) lenDB, err = poolRestored.LenDB() require.NoError(t, err) require.Equal(t, 3, lenDB) @@ -171,7 +175,7 @@ func TestRestoreMempool(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, lenDB) require.Equal(t, 1, poolRestored.Len()) - + fmt.Println("-------------------") require.NoError(t, closer2()) } @@ -183,8 +187,8 @@ func TestWait(t *testing.T) { mockCtrl := gomock.NewController(t) t.Cleanup(mockCtrl.Finish) state := mocks.NewMockStateHistoryReader(mockCtrl) - pool, _, err := mempool.New(testDB, state, 1024, log) - require.NoError(t, err) + pool, _ := mempool.New(testDB, state, 1024, log) + require.NoError(t, pool.LoadFromDB()) select { case <-pool.Wait(): From 5f927df9c907446cf8bf47080fdb2c63bc86ee9d Mon Sep 17 00:00:00 2001 From: rian Date: Fri, 10 Jan 2025 15:52:04 +0200 Subject: [PATCH 16/22] comment: push method for txlist, rename handleTxn --- mempool/mempool.go | 52 ++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 1c07d5e7d0..09c6aa83f1 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -36,6 +36,19 @@ type txnList struct { mu sync.Mutex } +func (t *txnList) push(newNode *storageElem) { + t.mu.Lock() + if t.tail != nil { + t.tail.Next = newNode + t.tail = newNode + } else { + t.head = newNode + t.tail = newNode + } + t.len++ + t.mu.Unlock() +} + // Pool represents a blockchain mempool, managing transactions using both an // in-memory and persistent database. type Pool struct { @@ -78,8 +91,10 @@ func (p *Pool) dbWriter() { go func() { defer p.wg.Done() for txn := range p.dbWriteChan { - err := p.handleTransaction(txn) - p.log.Errorw("error in handling user transaction in persistent mempool", "err", err) + err := p.writeToDB(txn) + if err != nil { + p.log.Errorw("error in handling user transaction in persistent mempool", "err", err) + } } }() } @@ -95,18 +110,16 @@ func (p *Pool) LoadFromDB() error { } return err } - + // loop through the persistent pool and push nodes to the in-memory pool currentHash := headValue for currentHash != nil { curElem, err := p.dbElem(txn, currentHash) if err != nil { return err } - newNode := &storageElem{ Txn: curElem.Txn, } - if curElem.NextHash != nil { nxtElem, err := p.dbElem(txn, curElem.NextHash) if err != nil { @@ -116,27 +129,15 @@ func (p *Pool) LoadFromDB() error { Txn: nxtElem.Txn, } } - - p.txnList.mu.Lock() - if p.txnList.tail != nil { - p.txnList.tail.Next = newNode - p.txnList.tail = newNode - } else { - p.txnList.head = newNode - p.txnList.tail = newNode - } - p.txnList.len++ - p.txnList.mu.Unlock() - + p.txnList.push(newNode) currentHash = curElem.NextHash } - return nil }) } -// handleTransaction adds the transaction to the persistent linked-list db -func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { +// writeToDB adds the transaction to the persistent pool db +func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error { return p.db.Update(func(dbTxn db.Transaction) error { tailValue := new(felt.Felt) if err := p.tailValue(dbTxn, tailValue); err != nil { @@ -199,17 +200,8 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { } } - p.txnList.mu.Lock() newNode := &storageElem{Txn: *userTxn, Next: nil} - if p.txnList.tail != nil { - p.txnList.tail.Next = newNode - p.txnList.tail = newNode - } else { - p.txnList.head = newNode - p.txnList.tail = newNode - } - p.txnList.len++ - p.txnList.mu.Unlock() + p.txnList.push(newNode) select { case p.txPushed <- struct{}{}: From d4fbf0a1db018f744288c3b4acb3d320929a236f Mon Sep 17 00:00:00 2001 From: rian Date: Fri, 10 Jan 2025 16:12:34 +0200 Subject: [PATCH 17/22] comment: ordering + fix test --- mempool/mempool.go | 12 ++++++------ mempool/mempool_test.go | 7 +------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 09c6aa83f1..2fa623feed 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -15,6 +15,11 @@ import ( var ErrTxnPoolFull = errors.New("transaction pool is full") +type BroadcastedTransaction struct { + Transaction core.Transaction + DeclaredClass core.Class +} + // storageElem defines a node for both the // in-memory and persistent linked-list type storageElem struct { @@ -23,11 +28,6 @@ type storageElem struct { Next *storageElem // in-memory } -type BroadcastedTransaction struct { - Transaction core.Transaction - DeclaredClass core.Class -} - // txnList is the in-memory mempool type txnList struct { head *storageElem @@ -38,6 +38,7 @@ type txnList struct { func (t *txnList) push(newNode *storageElem) { t.mu.Lock() + defer t.mu.Unlock() if t.tail != nil { t.tail.Next = newNode t.tail = newNode @@ -46,7 +47,6 @@ func (t *txnList) push(newNode *storageElem) { t.tail = newNode } t.len++ - t.mu.Unlock() } // Pool represents a blockchain mempool, managing transactions using both an diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 497d37ed46..39a71cf644 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -1,7 +1,6 @@ package mempool_test import ( - "fmt" "os" "testing" "time" @@ -123,9 +122,7 @@ func TestRestoreMempool(t *testing.T) { defer dbCloser() pool, closer := mempool.New(testDB, state, 1024, log) - fmt.Println("============") require.NoError(t, pool.LoadFromDB()) - fmt.Println("============") // Check both pools are empty lenDB, err := pool.LenDB() require.NoError(t, err) @@ -146,7 +143,6 @@ func TestRestoreMempool(t *testing.T) { })) require.Equal(t, int(i), pool.Len()) } - fmt.Println("============") // check the db has stored the transactions time.Sleep(100 * time.Millisecond) lenDB, err = pool.LenDB() @@ -160,7 +156,7 @@ func TestRestoreMempool(t *testing.T) { poolRestored, closer2 := mempool.New(testDB, state, 1024, log) time.Sleep(100 * time.Millisecond) - require.NoError(t, pool.LoadFromDB()) + require.NoError(t, poolRestored.LoadFromDB()) lenDB, err = poolRestored.LenDB() require.NoError(t, err) require.Equal(t, 3, lenDB) @@ -175,7 +171,6 @@ func TestRestoreMempool(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, lenDB) require.Equal(t, 1, poolRestored.Len()) - fmt.Println("-------------------") require.NoError(t, closer2()) } From d5efc7780d582dddc5729c338ba6b199cc937b46 Mon Sep 17 00:00:00 2001 From: rian Date: Mon, 13 Jan 2025 17:11:52 +0200 Subject: [PATCH 18/22] comments: rename set/get fns, txnlist, add txnlist.pop --- mempool/mempool.go | 69 +++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 2fa623feed..ffd9624f31 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -28,15 +28,15 @@ type storageElem struct { Next *storageElem // in-memory } -// txnList is the in-memory mempool -type txnList struct { +// memTxnList represents a linked list of user transactions at runtime" +type memTxnList struct { head *storageElem tail *storageElem len int mu sync.Mutex } -func (t *txnList) push(newNode *storageElem) { +func (t *memTxnList) push(newNode *storageElem) { t.mu.Lock() defer t.mu.Unlock() if t.tail != nil { @@ -49,6 +49,23 @@ func (t *txnList) push(newNode *storageElem) { t.len++ } +func (t *memTxnList) pop() (BroadcastedTransaction, error) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.head == nil { + return BroadcastedTransaction{}, errors.New("transaction pool is empty") + } + + headNode := t.head + t.head = headNode.Next + if t.head == nil { + t.tail = nil + } + t.len-- + return headNode.Txn, nil +} + // Pool represents a blockchain mempool, managing transactions using both an // in-memory and persistent database. type Pool struct { @@ -56,7 +73,7 @@ type Pool struct { state core.StateReader db db.DB // to store the persistent mempool txPushed chan struct{} - txnList *txnList // in-memory + memTxnList *memTxnList maxNumTxns int dbWriteChan chan *BroadcastedTransaction wg sync.WaitGroup @@ -70,7 +87,7 @@ func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleL state: state, db: mainDB, // todo: txns should be deleted everytime a new block is stored (builder responsibility) txPushed: make(chan struct{}, 1), - txnList: &txnList{}, + memTxnList: &memTxnList{}, maxNumTxns: maxNumTxns, dbWriteChan: make(chan *BroadcastedTransaction, maxNumTxns), } @@ -113,7 +130,7 @@ func (p *Pool) LoadFromDB() error { // loop through the persistent pool and push nodes to the in-memory pool currentHash := headValue for currentHash != nil { - curElem, err := p.dbElem(txn, currentHash) + curElem, err := p.readDBElem(txn, currentHash) if err != nil { return err } @@ -121,7 +138,7 @@ func (p *Pool) LoadFromDB() error { Txn: curElem.Txn, } if curElem.NextHash != nil { - nxtElem, err := p.dbElem(txn, curElem.NextHash) + nxtElem, err := p.readDBElem(txn, curElem.NextHash) if err != nil { return err } @@ -129,7 +146,7 @@ func (p *Pool) LoadFromDB() error { Txn: nxtElem.Txn, } } - p.txnList.push(newNode) + p.memTxnList.push(newNode) currentHash = curElem.NextHash } return nil @@ -146,20 +163,18 @@ func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error { } tailValue = nil } - if err := p.putdbElem(dbTxn, userTxn.Transaction.Hash(), &storageElem{ - Txn: *userTxn, - }); err != nil { + if err := p.setDBElem(dbTxn, &storageElem{Txn: *userTxn}); err != nil { return err } if tailValue != nil { // Update old tail to point to the new item var oldTailElem storageElem - oldTailElem, err := p.dbElem(dbTxn, tailValue) + oldTailElem, err := p.readDBElem(dbTxn, tailValue) if err != nil { return err } oldTailElem.NextHash = userTxn.Transaction.Hash() - if err = p.putdbElem(dbTxn, tailValue, &oldTailElem); err != nil { + if err = p.setDBElem(dbTxn, &oldTailElem); err != nil { return err } } else { @@ -201,7 +216,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { } newNode := &storageElem{Txn: *userTxn, Next: nil} - p.txnList.push(newNode) + p.memTxnList.push(newNode) select { case p.txPushed <- struct{}{}: @@ -212,7 +227,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { } func (p *Pool) validate(userTxn *BroadcastedTransaction) error { - if p.txnList.len+1 >= p.maxNumTxns { + if p.memTxnList.len+1 >= p.maxNumTxns { return ErrTxnPoolFull } @@ -251,21 +266,7 @@ func (p *Pool) validate(userTxn *BroadcastedTransaction) error { // Pop returns the transaction with the highest priority from the in-memory pool func (p *Pool) Pop() (BroadcastedTransaction, error) { - p.txnList.mu.Lock() - defer p.txnList.mu.Unlock() - - if p.txnList.head == nil { - return BroadcastedTransaction{}, errors.New("transaction pool is empty") - } - - headNode := p.txnList.head - p.txnList.head = headNode.Next - if p.txnList.head == nil { - p.txnList.tail = nil - } - p.txnList.len-- - - return headNode.Txn, nil + return p.memTxnList.pop() } // Remove removes a set of transactions from the pool @@ -277,7 +278,7 @@ func (p *Pool) Remove(hash ...*felt.Felt) error { // Len returns the number of transactions in the in-memory pool func (p *Pool) Len() int { - return p.txnList.len + return p.memTxnList.len } // Len returns the number of transactions in the persistent pool @@ -338,7 +339,7 @@ func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { return txn.Set(db.MempoolTail.Key(), tail.Marshal()) } -func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { +func (p *Pool) readDBElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { var item storageElem keyBytes := itemKey.Bytes() err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { @@ -347,11 +348,11 @@ func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, erro return item, err } -func (p *Pool) putdbElem(txn db.Transaction, itemKey *felt.Felt, item *storageElem) error { +func (p *Pool) setDBElem(txn db.Transaction, item *storageElem) error { itemBytes, err := encoder.Marshal(item) if err != nil { return err } - keyBytes := itemKey.Bytes() + keyBytes := item.Txn.Transaction.Hash().Bytes() return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes) } From 1f5a50dd462647a22255b879ed69175888ee36f9 Mon Sep 17 00:00:00 2001 From: rian Date: Tue, 14 Jan 2025 13:51:44 +0200 Subject: [PATCH 19/22] comment: split runtime and persistent types --- mempool/mempool.go | 52 +++++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index ffd9624f31..5a615b8270 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -20,23 +20,27 @@ type BroadcastedTransaction struct { DeclaredClass core.Class } -// storageElem defines a node for both the -// in-memory and persistent linked-list -type storageElem struct { +// runtime mempool txn +type memPoolTxn struct { + Txn BroadcastedTransaction + Next *memPoolTxn +} + +// persistent db txn value +type dbPoolTxn struct { Txn BroadcastedTransaction - NextHash *felt.Felt // persistent db - Next *storageElem // in-memory + NextHash *felt.Felt } -// memTxnList represents a linked list of user transactions at runtime" +// memTxnList represents a linked list of user transactions at runtime type memTxnList struct { - head *storageElem - tail *storageElem + head *memPoolTxn + tail *memPoolTxn len int mu sync.Mutex } -func (t *memTxnList) push(newNode *storageElem) { +func (t *memTxnList) push(newNode *memPoolTxn) { t.mu.Lock() defer t.mu.Unlock() if t.tail != nil { @@ -130,24 +134,24 @@ func (p *Pool) LoadFromDB() error { // loop through the persistent pool and push nodes to the in-memory pool currentHash := headValue for currentHash != nil { - curElem, err := p.readDBElem(txn, currentHash) + curDBElem, err := p.readDBElem(txn, currentHash) if err != nil { return err } - newNode := &storageElem{ - Txn: curElem.Txn, + newMemPoolTxn := &memPoolTxn{ + Txn: curDBElem.Txn, } - if curElem.NextHash != nil { - nxtElem, err := p.readDBElem(txn, curElem.NextHash) + if curDBElem.NextHash != nil { + nextDBTxn, err := p.readDBElem(txn, curDBElem.NextHash) if err != nil { return err } - newNode.Next = &storageElem{ - Txn: nxtElem.Txn, + newMemPoolTxn.Next = &memPoolTxn{ + Txn: nextDBTxn.Txn, } } - p.memTxnList.push(newNode) - currentHash = curElem.NextHash + p.memTxnList.push(newMemPoolTxn) + currentHash = curDBElem.NextHash } return nil }) @@ -163,12 +167,12 @@ func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error { } tailValue = nil } - if err := p.setDBElem(dbTxn, &storageElem{Txn: *userTxn}); err != nil { + if err := p.setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil { return err } if tailValue != nil { // Update old tail to point to the new item - var oldTailElem storageElem + var oldTailElem dbPoolTxn oldTailElem, err := p.readDBElem(dbTxn, tailValue) if err != nil { return err @@ -215,7 +219,7 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { } } - newNode := &storageElem{Txn: *userTxn, Next: nil} + newNode := &memPoolTxn{Txn: *userTxn, Next: nil} p.memTxnList.push(newNode) select { @@ -339,8 +343,8 @@ func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { return txn.Set(db.MempoolTail.Key(), tail.Marshal()) } -func (p *Pool) readDBElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { - var item storageElem +func (p *Pool) readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { + var item dbPoolTxn keyBytes := itemKey.Bytes() err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { return encoder.Unmarshal(b, &item) @@ -348,7 +352,7 @@ func (p *Pool) readDBElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, return item, err } -func (p *Pool) setDBElem(txn db.Transaction, item *storageElem) error { +func (p *Pool) setDBElem(txn db.Transaction, item *dbPoolTxn) error { itemBytes, err := encoder.Marshal(item) if err != nil { return err From 6ff70d93aa21edc11a80278782f537d1d77704fb Mon Sep 17 00:00:00 2001 From: rian Date: Wed, 15 Jan 2025 11:59:26 +0200 Subject: [PATCH 20/22] comments: db_utils.go, inline, felt.Zero --- mempool/db_utils.go | 63 +++++++++++++++++++++++++ mempool/mempool.go | 102 +++++++++------------------------------- mempool/mempool_test.go | 4 +- 3 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 mempool/db_utils.go diff --git a/mempool/db_utils.go b/mempool/db_utils.go new file mode 100644 index 0000000000..1638910464 --- /dev/null +++ b/mempool/db_utils.go @@ -0,0 +1,63 @@ +package mempool + +import ( + "errors" + "math/big" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/encoder" +) + +func headValue(txn db.Transaction, head *felt.Felt) error { + return txn.Get(db.MempoolHead.Key(), func(b []byte) error { + head.SetBytes(b) + return nil + }) +} + +func tailValue(txn db.Transaction, tail *felt.Felt) error { + return txn.Get(db.MempoolTail.Key(), func(b []byte) error { + tail.SetBytes(b) + return nil + }) +} + +func updateHead(txn db.Transaction, head *felt.Felt) error { + return txn.Set(db.MempoolHead.Key(), head.Marshal()) +} + +func updateTail(txn db.Transaction, tail *felt.Felt) error { + return txn.Set(db.MempoolTail.Key(), tail.Marshal()) +} + +func readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { + var item dbPoolTxn + keyBytes := itemKey.Bytes() + err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { + return encoder.Unmarshal(b, &item) + }) + return item, err +} + +func setDBElem(txn db.Transaction, item *dbPoolTxn) error { + itemBytes, err := encoder.Marshal(item) + if err != nil { + return err + } + keyBytes := item.Txn.Transaction.Hash().Bytes() + return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes) +} + +func lenDB(txn db.Transaction) (int, error) { + var l int + err := txn.Get(db.MempoolLength.Key(), func(b []byte) error { + l = int(new(big.Int).SetBytes(b).Int64()) + return nil + }) + + if err != nil && errors.Is(err, db.ErrKeyNotFound) { + return 0, nil + } + return l, err +} diff --git a/mempool/mempool.go b/mempool/mempool.go index 5a615b8270..4ebb80f7c8 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -9,7 +9,6 @@ import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/encoder" "github.com/NethermindEth/juno/utils" ) @@ -123,8 +122,8 @@ func (p *Pool) dbWriter() { // LoadFromDB restores the in-memory transaction pool from the database func (p *Pool) LoadFromDB() error { return p.db.View(func(txn db.Transaction) error { - headValue := new(felt.Felt) - err := p.headHash(txn, headValue) + headVal := new(felt.Felt) + err := headValue(txn, headVal) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { return nil @@ -132,9 +131,9 @@ func (p *Pool) LoadFromDB() error { return err } // loop through the persistent pool and push nodes to the in-memory pool - currentHash := headValue + currentHash := headVal for currentHash != nil { - curDBElem, err := p.readDBElem(txn, currentHash) + curDBElem, err := readDBElem(txn, currentHash) if err != nil { return err } @@ -142,7 +141,7 @@ func (p *Pool) LoadFromDB() error { Txn: curDBElem.Txn, } if curDBElem.NextHash != nil { - nextDBTxn, err := p.readDBElem(txn, curDBElem.NextHash) + nextDBTxn, err := readDBElem(txn, curDBElem.NextHash) if err != nil { return err } @@ -160,41 +159,41 @@ func (p *Pool) LoadFromDB() error { // writeToDB adds the transaction to the persistent pool db func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error { return p.db.Update(func(dbTxn db.Transaction) error { - tailValue := new(felt.Felt) - if err := p.tailValue(dbTxn, tailValue); err != nil { + tailVal := new(felt.Felt) + if err := tailValue(dbTxn, tailVal); err != nil { if !errors.Is(err, db.ErrKeyNotFound) { return err } - tailValue = nil + tailVal = nil } - if err := p.setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil { + if err := setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil { return err } - if tailValue != nil { + if tailVal != nil { // Update old tail to point to the new item var oldTailElem dbPoolTxn - oldTailElem, err := p.readDBElem(dbTxn, tailValue) + oldTailElem, err := readDBElem(dbTxn, tailVal) if err != nil { return err } oldTailElem.NextHash = userTxn.Transaction.Hash() - if err = p.setDBElem(dbTxn, &oldTailElem); err != nil { + if err = setDBElem(dbTxn, &oldTailElem); err != nil { return err } } else { // Empty list, make new item both the head and the tail - if err := p.updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil { + if err := updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil { return err } } - if err := p.updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil { + if err := updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil { return err } - pLen, err := p.lenDB(dbTxn) + pLen, err := lenDB(dbTxn) if err != nil { return err } - return p.updateLen(dbTxn, pLen+1) + return dbTxn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(pLen+1)).Bytes()) }) } @@ -285,78 +284,19 @@ func (p *Pool) Len() int { return p.memTxnList.len } +func (p *Pool) Wait() <-chan struct{} { + return p.txPushed +} + // Len returns the number of transactions in the persistent pool func (p *Pool) LenDB() (int, error) { - p.wg.Add(1) - defer p.wg.Done() txn, err := p.db.NewTransaction(false) if err != nil { return 0, err } - lenDB, err := p.lenDB(txn) + lenDB, err := lenDB(txn) if err != nil { return 0, err } return lenDB, txn.Discard() } - -func (p *Pool) lenDB(txn db.Transaction) (int, error) { - var l int - err := txn.Get(db.MempoolLength.Key(), func(b []byte) error { - l = int(new(big.Int).SetBytes(b).Int64()) - return nil - }) - - if err != nil && errors.Is(err, db.ErrKeyNotFound) { - return 0, nil - } - return l, err -} - -func (p *Pool) updateLen(txn db.Transaction, l int) error { - return txn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(l)).Bytes()) -} - -func (p *Pool) Wait() <-chan struct{} { - return p.txPushed -} - -func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { - return txn.Get(db.MempoolHead.Key(), func(b []byte) error { - head.SetBytes(b) - return nil - }) -} - -func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { - return txn.Set(db.MempoolHead.Key(), head.Marshal()) -} - -func (p *Pool) tailValue(txn db.Transaction, tail *felt.Felt) error { - return txn.Get(db.MempoolTail.Key(), func(b []byte) error { - tail.SetBytes(b) - return nil - }) -} - -func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { - return txn.Set(db.MempoolTail.Key(), tail.Marshal()) -} - -func (p *Pool) readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { - var item dbPoolTxn - keyBytes := itemKey.Bytes() - err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { - return encoder.Unmarshal(b, &item) - }) - return item, err -} - -func (p *Pool) setDBElem(txn db.Transaction, item *dbPoolTxn) error { - itemBytes, err := encoder.Marshal(item) - if err != nil { - return err - } - keyBytes := item.Txn.Transaction.Hash().Bytes() - return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes) -} diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 39a71cf644..57ede7cd32 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -57,7 +57,7 @@ func TestMempool(t *testing.T) { // push multiple to empty (1,2,3) for i := uint64(1); i < 4; i++ { //nolint:dupl senderAddress := new(felt.Felt).SetUint64(i) - state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) + state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), @@ -79,7 +79,7 @@ func TestMempool(t *testing.T) { // push multiple to non empty (push 4,5. now have 3,4,5) for i := uint64(4); i < 6; i++ { senderAddress := new(felt.Felt).SetUint64(i) - state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) + state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ Transaction: &core.InvokeTransaction{ TransactionHash: new(felt.Felt).SetUint64(i), From 5286584a4a09221ee1534e965fefb4545db444ee Mon Sep 17 00:00:00 2001 From: rian Date: Wed, 15 Jan 2025 12:17:34 +0200 Subject: [PATCH 21/22] lint --- mempool/mempool_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 57ede7cd32..21e9ba89e1 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -55,7 +55,7 @@ func TestMempool(t *testing.T) { require.Equal(t, err.Error(), "transaction pool is empty") // push multiple to empty (1,2,3) - for i := uint64(1); i < 4; i++ { //nolint:dupl + for i := uint64(1); i < 4; i++ { senderAddress := new(felt.Felt).SetUint64(i) state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ @@ -130,7 +130,7 @@ func TestRestoreMempool(t *testing.T) { require.Equal(t, 0, pool.Len()) // push multiple transactions to empty mempool (1,2,3) - for i := uint64(1); i < 4; i++ { //nolint:dupl + for i := uint64(1); i < 4; i++ { senderAddress := new(felt.Felt).SetUint64(i) state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ From cbcfe67fa3dee8a057115090fd5a4612736bc88c Mon Sep 17 00:00:00 2001 From: rian Date: Tue, 21 Jan 2025 11:00:39 +0200 Subject: [PATCH 22/22] comment: rename readTxn, setTxn --- mempool/db_utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mempool/db_utils.go b/mempool/db_utils.go index 1638910464..d3b2003631 100644 --- a/mempool/db_utils.go +++ b/mempool/db_utils.go @@ -31,7 +31,7 @@ func updateTail(txn db.Transaction, tail *felt.Felt) error { return txn.Set(db.MempoolTail.Key(), tail.Marshal()) } -func readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { +func readTxn(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { var item dbPoolTxn keyBytes := itemKey.Bytes() err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { @@ -40,7 +40,7 @@ func readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { return item, err } -func setDBElem(txn db.Transaction, item *dbPoolTxn) error { +func setTxn(txn db.Transaction, item *dbPoolTxn) error { itemBytes, err := encoder.Marshal(item) if err != nil { return err