diff --git a/core/trie/proof.go b/core/trie/proof.go index 2b98f96d3a..df3b5d17b9 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -137,7 +137,8 @@ func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSe // - Any node's computed hash doesn't match its expected hash // - The path bits don't match the key bits // - The proof ends before processing all key bits -func VerifyProof(root *felt.Felt, key *Key, proof *ProofNodeSet, hash hashFunc) (*felt.Felt, error) { +func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash hashFunc) (*felt.Felt, error) { + key := FeltToKey(globalTrieHeight, keyFelt) expectedHash := root keyLen := key.Len() @@ -165,7 +166,7 @@ func VerifyProof(root *felt.Felt, key *Key, proof *ProofNodeSet, hash hashFunc) } curPos++ case *Edge: // Edge nodes represent paths between binary nodes - if !verifyEdgePath(key, node.Path, curPos) { + if !verifyEdgePath(&key, node.Path, curPos) { return &felt.Zero, nil } diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index cb6a9d3c5e..7fbdf97154 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -20,8 +20,6 @@ func TestProve(t *testing.T) { tempTrie, records := nonRandomTrie(t, n) for _, record := range records { - key := tempTrie.FeltToKey(record.key) - proofSet := trie.NewProofNodeSet() err := tempTrie.Prove(record.key, proofSet) require.NoError(t, err) @@ -29,9 +27,9 @@ func TestProve(t *testing.T) { root, err := tempTrie.Root() require.NoError(t, err) - val, err := trie.VerifyProof(root, &key, proofSet, crypto.Pedersen) + val, err := trie.VerifyProof(root, record.key, proofSet, crypto.Pedersen) if err != nil { - t.Fatalf("failed for key %s", key.String()) + t.Fatalf("failed for key %s", record.key.String()) } require.Equal(t, record.value, val) } @@ -45,7 +43,6 @@ func TestProveNonExistent(t *testing.T) { for i := 1; i < n+1; i++ { keyFelt := new(felt.Felt).SetUint64(uint64(i + n)) - key := tempTrie.FeltToKey(keyFelt) proofSet := trie.NewProofNodeSet() err := tempTrie.Prove(keyFelt, proofSet) @@ -54,9 +51,9 @@ func TestProveNonExistent(t *testing.T) { root, err := tempTrie.Root() require.NoError(t, err) - val, err := trie.VerifyProof(root, &key, proofSet, crypto.Pedersen) + val, err := trie.VerifyProof(root, keyFelt, proofSet, crypto.Pedersen) if err != nil { - t.Fatalf("failed for key %s", key.String()) + t.Fatalf("failed for key %s", keyFelt.String()) } require.Equal(t, &felt.Zero, val) } @@ -67,8 +64,6 @@ func TestProveRandom(t *testing.T) { tempTrie, records := randomTrie(t, 1000) for _, record := range records { - key := tempTrie.FeltToKey(record.key) - proofSet := trie.NewProofNodeSet() err := tempTrie.Prove(record.key, proofSet) require.NoError(t, err) @@ -76,7 +71,7 @@ func TestProveRandom(t *testing.T) { root, err := tempTrie.Root() require.NoError(t, err) - val, err := trie.VerifyProof(root, &key, proofSet, crypto.Pedersen) + val, err := trie.VerifyProof(root, record.key, proofSet, crypto.Pedersen) require.NoError(t, err) require.Equal(t, record.value, val) } @@ -237,8 +232,7 @@ func TestProveCustom(t *testing.T) { root, err := tr.Root() require.NoError(t, err) - key := tr.FeltToKey(tc.key) - val, err := trie.VerifyProof(root, &key, proofSet, crypto.Pedersen) + val, err := trie.VerifyProof(root, tc.key, proofSet, crypto.Pedersen) require.NoError(t, err) require.Equal(t, tc.expected, val) }) @@ -633,6 +627,67 @@ func TestBadRangeProof(t *testing.T) { } } +func BenchmarkProve(b *testing.B) { + tr, records := randomTrie(b, 1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + proof := trie.NewProofNodeSet() + key := records[i%len(records)].key + if err := tr.Prove(key, proof); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkVerifyProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root, err := tr.Root() + require.NoError(b, err) + + var proofs []*trie.ProofNodeSet + for _, record := range records { + proof := trie.NewProofNodeSet() + if err := tr.Prove(record.key, proof); err != nil { + b.Fatal(err) + } + proofs = append(proofs, proof) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + index := i % len(records) + if _, err := trie.VerifyProof(root, records[index].key, proofs[index], crypto.Pedersen); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkVerifyRangeProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root, err := tr.Root() + require.NoError(b, err) + + start := 2 + end := start + 500 + + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(b, err) + + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := trie.VerifyRangeProof(root, keys[0], keys, values, proof) + require.NoError(b, err) + } +} + func buildTrie(t *testing.T, height uint8, records []*keyValue) *trie.Trie { if len(records) == 0 { t.Fatal("records must have at least one element") @@ -772,7 +827,7 @@ func nonRandomTrie(t *testing.T, numKeys int) (*trie.Trie, []*keyValue) { return tempTrie, records } -func randomTrie(t *testing.T, n int) (*trie.Trie, []*keyValue) { +func randomTrie(t testing.TB, n int) (*trie.Trie, []*keyValue) { rrand := rand.New(rand.NewSource(3)) memdb := pebble.NewMemTest(t) diff --git a/db/pebble/db.go b/db/pebble/db.go index 5974edf720..77aed603d7 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -60,7 +60,7 @@ func NewMem() (db.DB, error) { } // NewMemTest opens a new in-memory database, panics on error -func NewMemTest(t *testing.T) db.DB { +func NewMemTest(t testing.TB) db.DB { memDB, err := NewMem() if err != nil { t.Fatalf("create in-memory db: %v", err)