From cdc43f6db41fa9784133cd03200fc6005709bc69 Mon Sep 17 00:00:00 2001 From: Ashutosh Jha <66910385+ashu26jha@users.noreply.github.com> Date: Fri, 2 Aug 2024 17:34:40 +0000 Subject: [PATCH] Sort buffered transactions (#1882) Co-authored-by: Kirill Co-authored-by: IronGauntlets --- core/state.go | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/core/state.go b/core/state.go index 905446dc89..96b5c3548a 100644 --- a/core/state.go +++ b/core/state.go @@ -370,6 +370,11 @@ func (s *State) updateStorageBuffered(contractAddr *felt.Felt, updateDiff map[fe func (s *State) updateContractStorages(stateTrie *trie.Trie, diffs map[felt.Felt]map[felt.Felt]*felt.Felt, blockNumber uint64, logChanges bool, ) error { + type bufferedTransactionWithAddress struct { + txn *db.BufferedTransaction + addr *felt.Felt + } + // make sure all noClassContracts are deployed for addr := range diffs { if _, ok := noClassContracts[addr]; !ok { @@ -400,12 +405,15 @@ func (s *State) updateContractStorages(stateTrie *trie.Trie, diffs map[felt.Felt }) // update per-contract storage Tries concurrently - contractUpdaters := pool.NewWithResults[*db.BufferedTransaction]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) + contractUpdaters := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) for _, key := range keys { - conractAddr := key - updateDiff := diffs[conractAddr] - contractUpdaters.Go(func() (*db.BufferedTransaction, error) { - return s.updateStorageBuffered(&conractAddr, updateDiff, blockNumber, logChanges) + contractAddr := key + contractUpdaters.Go(func() (*bufferedTransactionWithAddress, error) { + bufferedTxn, err := s.updateStorageBuffered(&contractAddr, diffs[contractAddr], blockNumber, logChanges) + if err != nil { + return nil, err + } + return &bufferedTransactionWithAddress{txn: bufferedTxn, addr: &contractAddr}, nil }) } @@ -414,9 +422,14 @@ func (s *State) updateContractStorages(stateTrie *trie.Trie, diffs map[felt.Felt return err } + // we sort bufferedTxns in ascending contract address order to achieve an additional speedup + sort.Slice(bufferedTxns, func(i, j int) bool { + return bufferedTxns[i].addr.Cmp(bufferedTxns[j].addr) < 0 + }) + // flush buffered txns - for _, bufferedTxn := range bufferedTxns { - if err = bufferedTxn.Flush(); err != nil { + for _, txnWithAddress := range bufferedTxns { + if err := txnWithAddress.txn.Flush(); err != nil { return err } }