wip - testing: db leaks
rianhughes committed Jan 7, 2025
1 parent 6fa8dee commit 031de6c
Showing 2 changed files with 192 additions and 102 deletions.
mempool/mempool.go: 210 changes (122 additions & 88 deletions)
@@ -1,7 +1,6 @@
package mempool

import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -30,43 +29,63 @@ type BroadcastedTransaction struct {
DeclaredClass core.Class
}

type txnList struct {
head *storageElem
tail *storageElem
len uint16
mu sync.Mutex
}

// Pool stores the transactions in a linked list for its inherent FCFS behaviour
type Pool struct {
db db.DB
txPushed chan struct{}
txnList *txnList // in-memory
maxNumTxns uint16
dbWriteChan chan *BroadcastedTransaction
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}

type txnList struct {
head *storageElem
tail *storageElem
len uint16
mu sync.Mutex
}

// New initializes the Pool and starts the database writer goroutine
func New(ctx context.Context, db db.DB, maxNumTxns uint16) (*Pool, error) {
// New initializes the Pool and starts the database writer goroutine.
// It is the responsibility of the caller to invoke the returned closer function once the pool is no longer needed
func New(db db.DB, maxNumTxns uint16) (*Pool, func() error, error) {
pool := &Pool{
db: db, // todo: txns should be deleted every time a new block is stored (builder responsibility)
txPushed: make(chan struct{}, 1),
txnList: &txnList{},
maxNumTxns: maxNumTxns,
dbWriteChan: make(chan *BroadcastedTransaction, maxNumTxns),
ctx: ctx,
}

if err := pool.loadFromDB(); err != nil {
return nil, fmt.Errorf("failed to load transactions from database into the in-memory transaction list: %v\n", err)
return nil, nil, fmt.Errorf("failed to load transactions from database into the in-memory transaction list: %v\n", err)
}

pool.wg.Add(1)
go pool.dbWriter()
return pool, nil
closer := func() error {
close(pool.dbWriteChan)
pool.wg.Wait()
// Todo: should uncomment.
// if err := pool.db.Close(); err != nil {
// return fmt.Errorf("failed to close database: %v", err)
// }
return nil
}
return pool, closer, nil
}
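
As a usage sketch only (nothing below is part of this diff): the caller constructs the pool, keeps the returned closer, and invokes it on shutdown so the write channel is closed and the dbWriter goroutine drains. The package name, the `database` argument, the 1024 capacity, and the import paths are illustrative assumptions.

package example // illustrative only

import (
	"log"

	"github.com/NethermindEth/juno/db"
	"github.com/NethermindEth/juno/mempool"
)

// runPool sketches the intended lifecycle: construct, use, then call the closer.
func runPool(database db.DB) error {
	pool, closer, err := mempool.New(database, 1024) // capacity chosen arbitrarily
	if err != nil {
		return err
	}
	defer func() {
		// The closer closes dbWriteChan and waits for the dbWriter goroutine to exit.
		if cerr := closer(); cerr != nil {
			log.Printf("closing mempool: %v", cerr)
		}
	}()

	log.Printf("in-memory pool length: %d", pool.Len())
	return nil
}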

func (p *Pool) dbWriter() {
defer p.wg.Done()
for {
select {
case txn, ok := <-p.dbWriteChan:
if !ok {
return
}
p.handleTransaction(txn)
}
}
}

// loadFromDB restores the in-memory transaction pool from the database
@@ -83,85 +102,87 @@ func (p *Pool) loadFromDB() error {

currentHash := headValue
for currentHash != nil {
storedElem, err := p.elem(txn, currentHash)
curElem, err := p.elem(txn, currentHash)
if err != nil {
return err
}

// Add the transaction to the in-memory linked list
newNode := &storageElem{
Txn: curElem.Txn,
}

if curElem.NextHash != nil {
nxtElem, err := p.elem(txn, curElem.NextHash)
if err != nil {
return err
}
newNode.Next = &storageElem{
Txn: nxtElem.Txn,
}
}

p.txnList.mu.Lock()
newNode := &storageElem{Txn: storedElem.Txn}
if p.txnList.tail != nil {
p.txnList.tail.NextHash = newNode.Txn.Transaction.Hash()
p.txnList.tail.Next = newNode
p.txnList.tail = newNode
} else {
p.txnList.head = newNode
p.txnList.tail = newNode
}
p.txnList.len++
p.txnList.mu.Unlock()
currentHash = storedElem.NextHash

currentHash = curElem.NextHash
}

return nil
})
}

// dbWriter handles writing transactions to the database asynchronously
func (p *Pool) dbWriter() {
defer p.wg.Done()

for {
select {
case <-p.ctx.Done():
return
case txn := <-p.dbWriteChan:
_ = p.db.Update(func(dbTxn db.Transaction) error {
tailValue := new(felt.Felt)
if err := p.tailValue(dbTxn, tailValue); err != nil {
if !errors.Is(err, db.ErrKeyNotFound) {
return err
}
tailValue = nil
}

if err := p.putElem(dbTxn, txn.Transaction.Hash(), &storageElem{
Txn: *txn,
}); err != nil {
return err
}
func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error {
return p.db.Update(func(dbTxn db.Transaction) error {
tailValue := new(felt.Felt)
if err := p.tailValue(dbTxn, tailValue); err != nil {
if !errors.Is(err, db.ErrKeyNotFound) {
return err
}
tailValue = nil
}

if tailValue != nil {
// Update old tail to point to the new item
var oldTailElem storageElem
oldTailElem, err := p.elem(dbTxn, tailValue)
if err != nil {
return err
}
oldTailElem.NextHash = txn.Transaction.Hash()
if err = p.putElem(dbTxn, tailValue, &oldTailElem); err != nil {
return err
}
} else {
// Empty list, make new item both the head and the tail
if err := p.updateHead(dbTxn, txn.Transaction.Hash()); err != nil {
return err
}
}
if err := p.putElem(dbTxn, userTxn.Transaction.Hash(), &storageElem{
Txn: *userTxn,
}); err != nil {
return err
}

if err := p.updateTail(dbTxn, txn.Transaction.Hash()); err != nil {
return err
}
if tailValue != nil {
// Update old tail to point to the new item
var oldTailElem storageElem
oldTailElem, err := p.elem(dbTxn, tailValue)
if err != nil {
return err
}
oldTailElem.NextHash = userTxn.Transaction.Hash()
if err = p.putElem(dbTxn, tailValue, &oldTailElem); err != nil {
return err
}
} else {
// Empty list, make new item both the head and the tail
if err := p.updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil {
return err
}
}

pLen, err := p.len(dbTxn)
if err != nil {
return err
}
return p.updateLen(dbTxn, pLen+1)
})
if err := p.updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil {
return err
}

pLen, err := p.lenDB(dbTxn)
if err != nil {
return err
}
}
return p.updateLen(dbTxn, uint16(pLen+1))
})
}

// Push queues a transaction to the pool and adds it to both the in-memory list and DB
Expand All @@ -170,13 +191,25 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error {
return fmt.Errorf("transaction pool is full")
}

// Todo: push to validate?
// if err := p.rejectDuplicateTxn(userTxn); err != nil {
// return err
// }

// todo(rian this PR): validation

// p.handleTransaction(userTxn)

// todo: should db overloading block the in-memory mempool??
select {
case p.dbWriteChan <- userTxn:
default:
select {
case _, ok := <-p.dbWriteChan:
if !ok {
return errors.New("transaction pool database write channel is closed")
}
return errors.New("transaction pool database write channel is full")
default:
return errors.New("transaction pool database write channel is full")
}
}

p.txnList.mu.Lock()
newNode := &storageElem{Txn: *userTxn, Next: nil}
if p.txnList.tail != nil {
@@ -194,13 +227,6 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error {
default:
}

// user-txn may still be processed, but it will not make it into the persistent db
select {
case p.dbWriteChan <- userTxn:
default:
return errors.New("transaction pool database write channel is full")
}

return nil
}
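
A hedged companion sketch for Push (again, not part of this diff): it assumes the BroadcastedTransaction field being set is the core.Transaction interface and that Wait exposes the buffered txPushed channel, as the surrounding code suggests.

package example // illustrative only

import (
	"github.com/NethermindEth/juno/core"
	"github.com/NethermindEth/juno/mempool"
)

// pushThenWait queues one transaction and blocks on the pool's push signal.
func pushThenWait(pool *mempool.Pool, txn core.Transaction) error {
	if err := pool.Push(&mempool.BroadcastedTransaction{Transaction: txn}); err != nil {
		return err // pool full, or the DB write channel is full/closed
	}
	// txPushed is buffered with capacity 1, so the Push above has already
	// queued a signal and this receive returns promptly; signals can coalesce.
	<-pool.Wait()
	return nil
}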

@@ -228,15 +254,23 @@ func (p *Pool) Remove(hash ...*felt.Felt) error {
return errors.New("not implemented")
}

// Len returns the number of transactions in the persistent pool
// Len returns the number of transactions in the in-memory pool
func (p *Pool) Len() uint16 {
return p.txnList.len
}

func (p *Pool) len(txn db.Transaction) (uint64, error) {
var l uint64
func (p *Pool) LenDB() (uint16, error) {
txn, err := p.db.NewTransaction(false)
if err != nil {
return 0, err
}
return p.lenDB(txn)
}

func (p *Pool) lenDB(txn db.Transaction) (uint16, error) {
var l uint16
err := txn.Get([]byte(poolLengthKey), func(b []byte) error {
l = binary.BigEndian.Uint64(b)
l = binary.BigEndian.Uint16(b)
return nil
})

@@ -246,8 +280,8 @@ func (p *Pool) len(txn db.Transaction) (uint64, error) {
return l, err
}

func (p *Pool) updateLen(txn db.Transaction, l uint64) error {
return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint64(nil, l))
func (p *Pool) updateLen(txn db.Transaction, l uint16) error {
return txn.Set([]byte(poolLengthKey), binary.BigEndian.AppendUint16(nil, l))
}
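
For reference, a minimal round-trip of the 2-byte big-endian encoding now used for the pool-length key; this snippet only exercises encoding/binary and stands apart from the pool's storage code.

package example // illustrative only

import (
	"encoding/binary"
	"fmt"
)

// lengthRoundTrip shows how a uint16 pool length is serialized and read back.
func lengthRoundTrip() {
	encoded := binary.BigEndian.AppendUint16(nil, 42) // []byte{0x00, 0x2a}
	decoded := binary.BigEndian.Uint16(encoded)       // 42
	fmt.Println(len(encoded), decoded)                // prints: 2 42
}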

func (p *Pool) Wait() <-chan struct{} {