From 484570e605432ab87b31e7e0748dd511357afa76 Mon Sep 17 00:00:00 2001 From: rian Date: Tue, 7 Jan 2025 16:07:52 +0200 Subject: [PATCH] wip - fix tests --- mempool/mempool.go | 31 +++++++++++++++++++++---------- mempool/mempool_test.go | 23 +++++++++++++++++------ 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/mempool/mempool.go b/mempool/mempool.go index 364d8111e5..a2b6556dc3 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -91,7 +91,7 @@ func (p *Pool) dbWriter() { // loadFromDB restores the in-memory transaction pool from the database func (p *Pool) loadFromDB() error { return p.db.View(func(txn db.Transaction) error { - var headValue *felt.Felt + headValue := new(felt.Felt) err := p.headHash(txn, headValue) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { @@ -102,7 +102,7 @@ func (p *Pool) loadFromDB() error { currentHash := headValue for currentHash != nil { - curElem, err := p.elem(txn, currentHash) + curElem, err := p.dbElem(txn, currentHash) if err != nil { return err } @@ -112,7 +112,7 @@ func (p *Pool) loadFromDB() error { } if curElem.NextHash != nil { - nxtElem, err := p.elem(txn, curElem.NextHash) + nxtElem, err := p.dbElem(txn, curElem.NextHash) if err != nil { return err } @@ -149,7 +149,7 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { tailValue = nil } - if err := p.putElem(dbTxn, userTxn.Transaction.Hash(), &storageElem{ + if err := p.putdbElem(dbTxn, userTxn.Transaction.Hash(), &storageElem{ Txn: *userTxn, }); err != nil { return err @@ -158,12 +158,12 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error { if tailValue != nil { // Update old tail to point to the new item var oldTailElem storageElem - oldTailElem, err := p.elem(dbTxn, tailValue) + oldTailElem, err := p.dbElem(dbTxn, tailValue) if err != nil { return err } oldTailElem.NextHash = userTxn.Transaction.Hash() - if err = p.putElem(dbTxn, tailValue, &oldTailElem); err != nil { + if err = p.putdbElem(dbTxn, tailValue, &oldTailElem); err != nil { return err } } else { @@ -193,8 +193,6 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error { // todo(rian this PR): validation - // p.handleTransaction(userTxn) - // todo: should db overloading block the in-memory mempool?? select { case p.dbWriteChan <- userTxn: @@ -295,6 +293,19 @@ func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error { }) } +func (p *Pool) HeadHash() (*felt.Felt, error) { + txn, err := p.db.NewTransaction(false) + if err != nil { + return nil, err + } + var head *felt.Felt + err = txn.Get([]byte(headKey), func(b []byte) error { + head = new(felt.Felt).SetBytes(b) + return nil + }) + return head, err +} + func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error { return txn.Set([]byte(headKey), head.Marshal()) } @@ -310,7 +321,7 @@ func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error { return txn.Set([]byte(tailKey), tail.Marshal()) } -func (p *Pool) elem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { +func (p *Pool) dbElem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) { var item storageElem err := txn.Get(itemKey.Marshal(), func(b []byte) error { return encoder.Unmarshal(b, &item) @@ -318,7 +329,7 @@ func (p *Pool) elem(txn db.Transaction, itemKey *felt.Felt) (storageElem, error) return item, err } -func (p *Pool) putElem(txn db.Transaction, itemKey *felt.Felt, item *storageElem) error { +func (p *Pool) putdbElem(txn db.Transaction, itemKey *felt.Felt, item *storageElem) error { itemBytes, err := encoder.Marshal(item) if err != nil { return err diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index c27c92a3df..1560bdb16b 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -120,24 +120,35 @@ func TestRestoreMempool(t *testing.T) { })) assert.Equal(t, uint16(i), pool.Len()) } + // Todo: reads should block?? + fmt.Println(pool.HeadHash()) time.Sleep(100 * time.Millisecond) + fmt.Println(pool.HeadHash()) lenDB, err = pool.LenDB() require.NoError(t, err) assert.Equal(t, uint16(3), lenDB) // Close the mempool require.NoError(t, closer()) - fmt.Println("sdfsdfdf") - _, closer2, err := mempool.New(*testDB, 1024) + + poolRestored, closer2, err := mempool.New(*testDB, 1024) require.NoError(t, err) - fmt.Println("sdfsdfdsssf") - lenDB, err = pool.LenDB() + lenDB, err = poolRestored.LenDB() + require.NoError(t, err) + assert.Equal(t, uint16(3), lenDB) + assert.Equal(t, uint16(3), poolRestored.Len()) + fmt.Println(poolRestored.HeadHash()) + + // Remove transactions + _, err = poolRestored.Pop() + require.NoError(t, err) + _, err = poolRestored.Pop() require.NoError(t, err) + lenDB, err = poolRestored.LenDB() assert.Equal(t, uint16(3), lenDB) - assert.Equal(t, uint16(3), pool.Len()) + assert.Equal(t, uint16(1), poolRestored.Len()) - // // Remove transaction closer2() }