From 6b52bbee4e5db4706efa6741fc63d20209177023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Nowosielski?= Date: Thu, 3 Oct 2024 10:14:44 +0200 Subject: [PATCH] getProof rpc method --- blockchain/blockchain.go | 12 + core/state.go | 58 ++++- core/trie/key.go | 3 + core/trie/node.go | 4 +- core/trie/proof.go | 28 +-- core/trie/trie.go | 10 +- mocks/mock_blockchain.go | 16 ++ mocks/mock_trie.go | 104 +++++++++ rpc/contract.go | 19 -- rpc/contract_test.go | 82 ------- rpc/handlers.go | 12 + rpc/storage.go | 297 ++++++++++++++++++++++++ rpc/storage_test.go | 472 +++++++++++++++++++++++++++++++++++++++ 13 files changed, 996 insertions(+), 121 deletions(-) create mode 100644 mocks/mock_trie.go create mode 100644 rpc/storage.go create mode 100644 rpc/storage_test.go diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 4aa6659b98..c9b541de95 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -36,6 +36,7 @@ type Reader interface { StateUpdateByHash(hash *felt.Felt) (update *core.StateUpdate, err error) HeadState() (core.StateReader, StateCloser, error) + HeadTrie() (core.TrieReader, StateCloser, error) StateAtBlockHash(blockHash *felt.Felt) (core.StateReader, StateCloser, error) StateAtBlockNumber(blockNumber uint64) (core.StateReader, StateCloser, error) PendingState() (core.StateReader, StateCloser, error) @@ -769,6 +770,17 @@ func (b *Blockchain) HeadState() (core.StateReader, StateCloser, error) { return core.NewState(txn), txn.Discard, nil } +func (b *Blockchain) HeadTrie() (core.TrieReader, StateCloser, error) { + // Note: I'm not sure I should open a new db txn since the TrieReader is a State + // so the same instance of the state we create in HeadState will do job. + txn, err := b.database.NewTransaction(false) + if err != nil { + return nil, nil, err + } + + return core.NewState(txn), txn.Discard, nil +} + // StateAtBlockNumber returns a StateReader that provides a stable view to the state at the given block number func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (core.StateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockNumber") diff --git a/core/state.go b/core/state.go index effde8b518..d01f2dcd72 100644 --- a/core/state.go +++ b/core/state.go @@ -44,6 +44,17 @@ type StateReader interface { Class(classHash *felt.Felt) (*DeclaredClass, error) } +// TrieReader used for storage proofs, can only be supported by current state implementation (for now, we plan to add db snapshots) +var _ TrieReader = (*State)(nil) + +//go:generate mockgen -destination=../mocks/mock_trie.go -package=mocks github.com/NethermindEth/juno/core TrieReader +type TrieReader interface { + ClassTrie() (*trie.Trie, func() error, error) + StorageTrie() (*trie.Trie, func() error, error) + StorageTrieForAddr(addr *felt.Felt) (*trie.Trie, error) + StateAndClassRoot() (*felt.Felt, *felt.Felt, error) +} + type State struct { *history txn db.Transaction @@ -129,6 +140,18 @@ func (s *State) storage() (*trie.Trie, func() error, error) { return s.globalTrie(db.StateTrie, trie.NewTriePedersen) } +func (s *State) StorageTrie() (*trie.Trie, func() error, error) { + return s.storage() +} + +func (s *State) ClassTrie() (*trie.Trie, func() error, error) { + return s.classesTrie() +} + +func (s *State) StorageTrieForAddr(addr *felt.Felt) (*trie.Trie, error) { + return storage(addr, s.txn) +} + func (s *State) classesTrie() (*trie.Trie, func() error, error) { return s.globalTrie(db.ClassesTrie, trie.NewTriePoseidon) } @@ -547,7 +570,7 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { err = s.performStateDeletions(blockNumber, update.StateDiff) if err != nil { - return fmt.Errorf("error performing state deletions: %v", err) + return fmt.Errorf("build reverse diff: %v", err) } stateTrie, storageCloser, err := s.storage() @@ -581,6 +604,7 @@ func (s *State) purgeNoClassContracts() error { // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, // we can use the lack of key's existence as reason for purging noClassContracts. + for addr := range noClassContracts { noClassC, err := NewContractUpdater(&addr, s.txn) if err != nil { @@ -743,3 +767,35 @@ func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error return nil } + +func (s *State) StateAndClassRoot() (*felt.Felt, *felt.Felt, error) { + var storageRoot, classesRoot *felt.Felt + + sStorage, closer, err := s.storage() + if err != nil { + return nil, nil, err + } + + if storageRoot, err = sStorage.Root(); err != nil { + return nil, nil, err + } + + if err = closer(); err != nil { + return nil, nil, err + } + + classes, closer, err := s.classesTrie() + if err != nil { + return nil, nil, err + } + + if classesRoot, err = classes.Root(); err != nil { + return nil, nil, err + } + + if err = closer(); err != nil { + return nil, nil, err + } + + return storageRoot, classesRoot, nil +} diff --git a/core/trie/key.go b/core/trie/key.go index 7f0e6af609..7ca6c1cc32 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -28,6 +28,9 @@ func (k *Key) SubKey(n uint8) (*Key, error) { if n > k.len { return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len)) } + if n == k.len { + return &Key{}, nil + } newKey := &Key{len: n} copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:mnd diff --git a/core/trie/node.go b/core/trie/node.go index b62db62807..c56dde3603 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -18,7 +18,7 @@ type Node struct { } // Hash calculates the hash of a [Node] -func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) Hash(path *Key, hashFunc HashFunc) *felt.Felt { if path.Len() == 0 { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -33,7 +33,7 @@ func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { } // Hash calculates the hash of a [Node] -func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc HashFunc) *felt.Felt { path := path(nodeKey, parnetKey) return n.Hash(&path, hashFunc) } diff --git a/core/trie/proof.go b/core/trie/proof.go index 517ae60764..008e0d6c08 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -13,7 +13,7 @@ var ( ) type ProofNode interface { - Hash(hash hashFunc) *felt.Felt + Hash(hash HashFunc) *felt.Felt Len() uint8 PrettyPrint() } @@ -23,7 +23,7 @@ type Binary struct { RightHash *felt.Felt } -func (b *Binary) Hash(hash hashFunc) *felt.Felt { +func (b *Binary) Hash(hash HashFunc) *felt.Felt { return hash(b.LeftHash, b.RightHash) } @@ -42,7 +42,7 @@ type Edge struct { Path *Key // path from parent to child } -func (e *Edge) Hash(hash hashFunc) *felt.Felt { +func (e *Edge) Hash(hash HashFunc) *felt.Felt { length := make([]byte, len(e.Path.bitset)) length[len(e.Path.bitset)-1] = e.Path.len pathFelt := e.Path.Felt() @@ -199,7 +199,7 @@ func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Fe // merges paths in the specified order [commonNodes..., leftNodes..., rightNodes...] // ordering of the merged path is not important // since SplitProofPath can discover the left and right paths using the merged path and the rootHash -func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNode, *felt.Felt, error) { +func MergeProofPaths(leftPath, rightPath []ProofNode, hash HashFunc) ([]ProofNode, *felt.Felt, error) { merged := []ProofNode{} minLen := min(len(leftPath), len(rightPath)) @@ -236,7 +236,7 @@ func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNod // SplitProofPath splits the merged proof path into two paths (left and right), which were merged before // it first validates that the merged path is not circular, the split happens at most once and rootHash exists // then calls traverseNodes to split the path to left and right paths -func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash hashFunc) ([]ProofNode, []ProofNode, error) { +func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash HashFunc) ([]ProofNode, []ProofNode, error) { commonPath := []ProofNode{} leftPath := []ProofNode{} rightPath := []ProofNode{} @@ -316,7 +316,7 @@ func GetProof(key *Key, tri *Trie) ([]ProofNode, error) { // verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes` // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 -func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool { +func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash HashFunc) bool { expectedHash := root remainingPath := NewKey(key.len, key.bitset[:]) for i, proofNode := range proofs { @@ -340,12 +340,12 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode // Todo: // If we are verifying the key doesn't exist, then we should - // update subKey to point in the other direction + // update.Status subKey to point in the other direction if value == nil && i == len(proofs)-1 { return true } - if !proofNode.Path.Equal(subKey) { + if !proofNode.Path.Equal(subKey) && !subKey.Equal(&Key{}) { return false } expectedHash = proofNode.Child @@ -363,7 +363,7 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode // and therefore it's hash won't match the expected root. // ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484 func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*Key, proofValues [2]*felt.Felt, - proofs [2][]ProofNode, hash hashFunc, + proofs [2][]ProofNode, hash HashFunc, ) (bool, error) { // Step 0: checks if len(keys) != len(values) { @@ -440,7 +440,7 @@ func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error { } // compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key -func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) { +func compressNode(idx int, proofNodes []ProofNode, hashF HashFunc) (int, uint8, error) { parent := proofNodes[idx] if idx == len(proofNodes)-1 { @@ -474,7 +474,7 @@ func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, } func assignChild(i, compressedParent int, parentNode *Node, - nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF hashFunc, + nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF HashFunc, ) (*Key, error) { childInd := i + compressedParent + 1 childKey, err := getChildKey(childInd, parentKey, leafKey, nilKey, proofNodes, hashF) @@ -494,7 +494,7 @@ func assignChild(i, compressedParent int, parentNode *Node, // ProofToPath returns a set of storage nodes from the root to the end of the proof path. // The storage nodes will have the hashes of the children, but only the key of the child // along the path outlined by the proof. -func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]StorageNode, error) { +func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF HashFunc) ([]StorageNode, error) { pathNodes := []StorageNode{} // Child keys that can't be derived are set to nilKey, so that we can store the node @@ -552,7 +552,7 @@ func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]Storag return pathNodes, nil } -func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF hashFunc) bool { +func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF HashFunc) bool { lastNode := pathNodes[len(pathNodes)-1].node noLeftMatch, noRightMatch := false, false if lastNode.LeftHash != nil && !pNode.Hash(hashF).Equal(lastNode.LeftHash) { @@ -607,7 +607,7 @@ func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key, return crntKey, err } -func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF hashFunc) (*Key, error) { +func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF HashFunc) (*Key, error) { if childIdx > len(proofNodes)-1 { return nilKey, nil } diff --git a/core/trie/trie.go b/core/trie/trie.go index c03357d3af..28c75fd891 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -13,7 +13,7 @@ import ( "github.com/NethermindEth/juno/db" ) -type hashFunc func(*felt.Felt, *felt.Felt) *felt.Felt +type HashFunc func(*felt.Felt, *felt.Felt) *felt.Felt // Trie is a dense Merkle Patricia Trie (i.e., all internal nodes have two children). // @@ -37,7 +37,7 @@ type Trie struct { rootKey *Key maxKey *felt.Felt storage *Storage - hash hashFunc + hash HashFunc dirtyNodes []*Key rootKeyIsDirty bool @@ -53,7 +53,7 @@ func NewTriePoseidon(storage *Storage, height uint8) (*Trie, error) { return newTrie(storage, height, crypto.Poseidon) } -func newTrie(storage *Storage, height uint8, hash hashFunc) (*Trie, error) { +func newTrie(storage *Storage, height uint8, hash HashFunc) (*Trie, error) { if height > felt.Bits { return nil, fmt.Errorf("max trie height is %d, got: %d", felt.Bits, height) } @@ -668,6 +668,10 @@ func (t *Trie) RootKey() *Key { return t.rootKey } +func (t *Trie) HashFunc() HashFunc { + return t.hash +} + func (t *Trie) Dump() { t.dump(0, nil) } diff --git a/mocks/mock_blockchain.go b/mocks/mock_blockchain.go index 8d6bf6045d..b90e9eff84 100644 --- a/mocks/mock_blockchain.go +++ b/mocks/mock_blockchain.go @@ -163,6 +163,22 @@ func (mr *MockReaderMockRecorder) HeadState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadState", reflect.TypeOf((*MockReader)(nil).HeadState)) } +// HeadTrie mocks base method. +func (m *MockReader) HeadTrie() (core.TrieReader, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HeadTrie") + ret0, _ := ret[0].(core.TrieReader) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// HeadTrie indicates an expected call of HeadTrie. +func (mr *MockReaderMockRecorder) HeadTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadTrie", reflect.TypeOf((*MockReader)(nil).HeadTrie)) +} + // HeadsHeader mocks base method. func (m *MockReader) HeadsHeader() (*core.Header, error) { m.ctrl.T.Helper() diff --git a/mocks/mock_trie.go b/mocks/mock_trie.go new file mode 100644 index 0000000000..570a055c4e --- /dev/null +++ b/mocks/mock_trie.go @@ -0,0 +1,104 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/NethermindEth/juno/core (interfaces: TrieReader) +// +// Generated by this command: +// +// mockgen -destination=../mocks/mock_trie.go -package=mocks github.com/NethermindEth/juno/core TrieReader +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + felt "github.com/NethermindEth/juno/core/felt" + trie "github.com/NethermindEth/juno/core/trie" + gomock "go.uber.org/mock/gomock" +) + +// MockTrieReader is a mock of TrieReader interface. +type MockTrieReader struct { + ctrl *gomock.Controller + recorder *MockTrieReaderMockRecorder +} + +// MockTrieReaderMockRecorder is the mock recorder for MockTrieReader. +type MockTrieReaderMockRecorder struct { + mock *MockTrieReader +} + +// NewMockTrieReader creates a new mock instance. +func NewMockTrieReader(ctrl *gomock.Controller) *MockTrieReader { + mock := &MockTrieReader{ctrl: ctrl} + mock.recorder = &MockTrieReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrieReader) EXPECT() *MockTrieReaderMockRecorder { + return m.recorder +} + +// ClassTrie mocks base method. +func (m *MockTrieReader) ClassTrie() (*trie.Trie, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClassTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ClassTrie indicates an expected call of ClassTrie. +func (mr *MockTrieReaderMockRecorder) ClassTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClassTrie", reflect.TypeOf((*MockTrieReader)(nil).ClassTrie)) +} + +// StateAndClassRoot mocks base method. +func (m *MockTrieReader) StateAndClassRoot() (*felt.Felt, *felt.Felt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StateAndClassRoot") + ret0, _ := ret[0].(*felt.Felt) + ret1, _ := ret[1].(*felt.Felt) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// StateAndClassRoot indicates an expected call of StateAndClassRoot. +func (mr *MockTrieReaderMockRecorder) StateAndClassRoot() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateAndClassRoot", reflect.TypeOf((*MockTrieReader)(nil).StateAndClassRoot)) +} + +// StorageTrie mocks base method. +func (m *MockTrieReader) StorageTrie() (*trie.Trie, func() error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StorageTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(func() error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// StorageTrie indicates an expected call of StorageTrie. +func (mr *MockTrieReaderMockRecorder) StorageTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrie", reflect.TypeOf((*MockTrieReader)(nil).StorageTrie)) +} + +// StorageTrieForAddr mocks base method. +func (m *MockTrieReader) StorageTrieForAddr(arg0 *felt.Felt) (*trie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StorageTrieForAddr", arg0) + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StorageTrieForAddr indicates an expected call of StorageTrieForAddr. +func (mr *MockTrieReaderMockRecorder) StorageTrieForAddr(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrieForAddr", reflect.TypeOf((*MockTrieReader)(nil).StorageTrieForAddr), arg0) +} diff --git a/rpc/contract.go b/rpc/contract.go index e0eb3b7c2a..ba7de93029 100644 --- a/rpc/contract.go +++ b/rpc/contract.go @@ -27,22 +27,3 @@ func (h *Handler) Nonce(id BlockID, address felt.Felt) (*felt.Felt, *jsonrpc.Err return nonce, nil } - -// StorageAt gets the value of the storage at the given address and key. -// -// It follows the specification defined here: -// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L110 -func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *jsonrpc.Error) { - stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) - if rpcErr != nil { - return nil, rpcErr - } - defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageAt") - - value, err := stateReader.ContractStorage(&address, &key) - if err != nil { - return nil, ErrContractNotFound - } - - return value, nil -} diff --git a/rpc/contract_test.go b/rpc/contract_test.go index 8f9e100aa3..522ab0bb62 100644 --- a/rpc/contract_test.go +++ b/rpc/contract_test.go @@ -86,85 +86,3 @@ func TestNonce(t *testing.T) { assert.Equal(t, expectedNonce, nonce) }) } - -func TestStorageAt(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockReader := mocks.NewMockReader(mockCtrl) - log := utils.NewNopZapLogger() - handler := rpc.New(mockReader, nil, nil, "", log) - - t.Run("empty blockchain", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - mockState := mocks.NewMockStateHistoryReader(mockCtrl) - - t.Run("non-existent contract", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, errors.New("non-existent contract")) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrContractNotFound, rpcErr) - }) - - t.Run("non-existent key", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(&felt.Zero, errors.New("non-existent key")) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrContractNotFound, rpcErr) - }) - - expectedStorage := new(felt.Felt).SetUint64(1) - - t.Run("blockID - latest", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) - - t.Run("blockID - hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) - - t.Run("blockID - number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) -} diff --git a/rpc/handlers.go b/rpc/handlers.go index 3ae27684c7..5bb14f0a48 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -58,6 +58,11 @@ var ( // These errors can be only be returned by Juno-specific methods. ErrSubscriptionNotFound = &jsonrpc.Error{Code: 100, Message: "Subscription not found"} + + ErrStorageProofNotSupported = &jsonrpc.Error{ + Code: 42, + Message: "the node doesn't support storage proofs for blocks that are too far in the past. Use 'latest' as block id", + } ) const ( @@ -233,6 +238,13 @@ func (h *Handler) Methods() ([]jsonrpc.Method, string) { //nolint: funlen Params: []jsonrpc.Parameter{{Name: "contract_address"}, {Name: "key"}, {Name: "block_id"}}, Handler: h.StorageAt, }, + { + Name: "starknet_getStorageProof", + Params: []jsonrpc.Parameter{ + {Name: "block_id"}, {Name: "classes", Optional: true}, {Name: "contracts", Optional: true}, {Name: "storage_keys", Optional: true}, + }, + Handler: h.StorageProof, + }, { Name: "starknet_getClassHashAt", Params: []jsonrpc.Parameter{{Name: "block_id"}, {Name: "contract_address"}}, diff --git a/rpc/storage.go b/rpc/storage.go new file mode 100644 index 0000000000..1c9e1fa1cf --- /dev/null +++ b/rpc/storage.go @@ -0,0 +1,297 @@ +package rpc + +import ( + "errors" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/jsonrpc" +) + +/**************************************************** + Storage Handlers +*****************************************************/ + +// StorageAt gets the value of the storage at the given address and key. +// +// It follows the specification defined here: +// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L110 +func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *jsonrpc.Error) { + stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) + if rpcErr != nil { + return nil, rpcErr + } + defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageAt") + + value, err := stateReader.ContractStorage(&address, &key) + if err != nil { + return nil, ErrContractNotFound + } + + return value, nil +} + +// StorageProof returns the merkle paths in one of the state tries: global state, classes, individual contract +// +// It follows the specification defined here: +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L910 +func (h *Handler) StorageProof( + id BlockID, + classes, contracts []felt.Felt, + storageKeys []StorageKeys, +) (*StorageProofResult, *jsonrpc.Error) { + if !id.Latest { + return nil, ErrStorageProofNotSupported + } + + stateReader, stateCloser, err := h.bcReader.HeadState() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageProof") + + head, err := h.bcReader.Head() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + trieReader, stateCloser2, err := h.bcReader.HeadTrie() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + defer h.callAndLogErr(stateCloser2, "Error closing trie reader in getStorageProof") + + storageRoot, classRoot, err := trieReader.StateAndClassRoot() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + result := &StorageProofResult{ + GlobalRoots: &GlobalRoots{ + ContractsTreeRoot: storageRoot, + ClassesTreeRoot: classRoot, + BlockHash: head.Hash, + }, + } + + result.ClassesProof, err = getClassesProof(trieReader, classes) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + result.ContractsProof, err = getContractsProof(stateReader, trieReader, contracts) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + result.ContractsStorageProofs, err = getContractsStorageProofs(trieReader, storageKeys) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + return result, nil +} + +// StorageKeys represents an item in `contracts_storage_keys. parameter +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L938 +type StorageKeys struct { + Contract felt.Felt `json:"contract_address"` + Keys []felt.Felt `json:"storage_keys"` +} + +// MerkleNode represents a proof node in a trie +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3632 +// Implemented by MerkleBinaryNode, MerkleEdgeNode +type MerkleNode interface { + AsProofNode() trie.ProofNode +} + +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3644 +type MerkleBinaryNode struct { + Left *felt.Felt `json:"left"` + Right *felt.Felt `json:"right"` +} + +func (mbn *MerkleBinaryNode) AsProofNode() trie.ProofNode { + return &trie.Binary{ + LeftHash: mbn.Left, + RightHash: mbn.Right, + } +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L3720 +type MerkleEdgeNode struct { + Path string `json:"path"` + Length int `json:"length"` + Child *felt.Felt `json:"child"` +} + +func (men *MerkleEdgeNode) AsProofNode() trie.ProofNode { + f, _ := new(felt.Felt).SetString(men.Path) + pbs := f.Bytes() + path := trie.NewKey(uint8(men.Length), pbs[:]) + + return &trie.Edge{ + Path: &path, + Child: men.Child, + } +} + +// HashToNode represents an item in `NODE_HASH_TO_NODE_MAPPING` specified here +// https://github.com/starkware-libs/starknet-specs/blob/647caa00c0223e1daab1b2f3acc4e613ba2138aa/api/starknet_api_openrpc.json#L3667 +type HashToNode struct { + Hash *felt.Felt `json:"node_hash"` + Node MerkleNode `json:"node"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L986 +type LeafData struct { + Nonce *felt.Felt `json:"nonce"` + ClassHash *felt.Felt `json:"class_hash"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L979 +type ContractProof struct { + Nodes []*HashToNode `json:"nodes"` + LeavesData []*LeafData `json:"contract_leaves_data"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L1011 +type GlobalRoots struct { + ContractsTreeRoot *felt.Felt `json:"contracts_tree_root"` + ClassesTreeRoot *felt.Felt `json:"classes_tree_root"` + BlockHash *felt.Felt `json:"block_hash"` +} + +// https://github.com/starkware-libs/starknet-specs/blob/8cf463b79ba1dd876f67c7f637e5ea48beb07b5b/api/starknet_api_openrpc.json#L970 +type StorageProofResult struct { + ClassesProof []*HashToNode `json:"classes_proof"` + ContractsProof *ContractProof `json:"contracts_proof"` + ContractsStorageProofs [][]*HashToNode `json:"contracts_storage_proofs"` + GlobalRoots *GlobalRoots `json:"global_roots"` +} + +func getClassesProof(reader core.TrieReader, classes []felt.Felt) ([]*HashToNode, error) { + cTrie, _, err := reader.ClassTrie() + if err != nil { + return nil, err + } + result := []*HashToNode{} + for _, class := range classes { + nodes, err := getProof(cTrie, &class) + if err != nil { + return nil, err + } + result = append(result, nodes...) + } + return result, nil +} + +func getContractsProof(stReader core.StateReader, trReader core.TrieReader, contracts []felt.Felt) (*ContractProof, error) { + sTrie, _, err := trReader.StorageTrie() + if err != nil { + return nil, err + } + + result := &ContractProof{ + Nodes: []*HashToNode{}, + LeavesData: make([]*LeafData, 0, len(contracts)), + } + + for _, contract := range contracts { + leafData, err := getLeafData(stReader, &contract) + if err != nil { + return nil, err + } + result.LeavesData = append(result.LeavesData, leafData) + + nodes, err := getProof(sTrie, &contract) + if err != nil { + return nil, err + } + result.Nodes = append(result.Nodes, nodes...) + } + + return result, nil +} + +func getLeafData(reader core.StateReader, contract *felt.Felt) (*LeafData, error) { + nonce, err := reader.ContractNonce(contract) + if errors.Is(err, db.ErrKeyNotFound) { + return nil, nil + } + if err != nil { + return nil, err + } + classHash, err := reader.ContractClassHash(contract) + if err != nil { + return nil, err + } + + return &LeafData{ + Nonce: nonce, + ClassHash: classHash, + }, nil +} + +func getContractsStorageProofs(reader core.TrieReader, keys []StorageKeys) ([][]*HashToNode, error) { + result := make([][]*HashToNode, 0, len(keys)) + + for _, key := range keys { + csTrie, err := reader.StorageTrieForAddr(&key.Contract) + if err != nil { + // Note: if contract does not exist, `StorageTrieForAddr()` returns an empty trie, not an error + return nil, err + } + + nodes := []*HashToNode{} + for _, slot := range key.Keys { + proof, err := getProof(csTrie, &slot) + if err != nil { + return nil, err + } + nodes = append(nodes, proof...) + } + result = append(result, nodes) + } + + return result, nil +} + +func getProof(t *trie.Trie, elt *felt.Felt) ([]*HashToNode, error) { + feltBytes := elt.Bytes() + key := trie.NewKey(core.ContractStorageTrieHeight, feltBytes[:]) + nodes, err := trie.GetProof(&key, t) + if err != nil { + return nil, err + } + + // adapt proofs to the expected format + hashNodes := make([]*HashToNode, len(nodes)) + for i, node := range nodes { + var merkle MerkleNode + + if binary, ok := node.(*trie.Binary); ok { + merkle = &MerkleBinaryNode{ + Left: binary.LeftHash, + Right: binary.RightHash, + } + } + if edge, ok := node.(*trie.Edge); ok { + path := edge.Path + f := path.Felt() + merkle = &MerkleEdgeNode{ + Path: f.String(), + Length: int(edge.Len()), + Child: edge.Child, + } + } + + hashNodes[i] = &HashToNode{ + Hash: node.Hash(t.HashFunc()), + Node: merkle, + } + } + + return hashNodes, nil +} diff --git a/rpc/storage_test.go b/rpc/storage_test.go new file mode 100644 index 0000000000..40f22991c8 --- /dev/null +++ b/rpc/storage_test.go @@ -0,0 +1,472 @@ +package rpc_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/NethermindEth/juno/blockchain" + "github.com/NethermindEth/juno/clients/feeder" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/mocks" + "github.com/NethermindEth/juno/rpc" + adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" + "github.com/NethermindEth/juno/sync" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestStorageAt(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + log := utils.NewNopZapLogger() + handler := rpc.New(mockReader, nil, nil, "", log) + + t.Run("empty blockchain", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + + t.Run("non-existent contract", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, errors.New("non-existent contract")) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrContractNotFound, rpcErr) + }) + + t.Run("non-existent key", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(&felt.Zero, errors.New("non-existent key")) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrContractNotFound, rpcErr) + }) + + expectedStorage := new(felt.Felt).SetUint64(1) + + t.Run("blockID - latest", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) + + t.Run("blockID - hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) + + t.Run("blockID - number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) +} + +func TestStorageProof(t *testing.T) { + // dummy values + var ( + blkHash = utils.HexToFelt(t, "0x11ead") + clsRoot = utils.HexToFelt(t, "0xc1a55") + stgRoot = utils.HexToFelt(t, "0xc0ffee") + key = new(felt.Felt).SetUint64(1) + noSuchKey = new(felt.Felt).SetUint64(0) + value = new(felt.Felt).SetUint64(51) + blockLatest = rpc.BlockID{Latest: true} + blockNumber = uint64(1313) + nopCloser = func() error { + return nil + } + ) + + tempTrie := emptyTrie(t) + + _, err := tempTrie.Put(key, value) + require.NoError(t, err) + _, err = tempTrie.Put(new(felt.Felt).SetUint64(8), new(felt.Felt).SetUint64(59)) + require.NoError(t, err) + require.NoError(t, tempTrie.Commit()) + + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockTrie := mocks.NewMockTrieReader(mockCtrl) + + mockReader.EXPECT().HeadState().Return(mockState, func() error { + return nil + }, nil).AnyTimes() + mockReader.EXPECT().HeadTrie().Return(mockTrie, func() error { return nil }, nil).AnyTimes() + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}}, nil).AnyTimes() + mockTrie.EXPECT().StateAndClassRoot().Return(stgRoot, clsRoot, nil).AnyTimes() + mockTrie.EXPECT().ClassTrie().Return(tempTrie, nopCloser, nil).AnyTimes() + mockTrie.EXPECT().StorageTrie().Return(tempTrie, nopCloser, nil).AnyTimes() + + log := utils.NewNopZapLogger() + handler := rpc.New(mockReader, nil, nil, "", log) + + verifyIf := func(proof []*rpc.HashToNode, key *felt.Felt, value *felt.Felt) { + root, err := tempTrie.Root() + require.NoError(t, err) + + pnodes := []trie.ProofNode{} + for _, hn := range proof { + pnodes = append(pnodes, hn.Node.AsProofNode()) + } + + kbs := key.Bytes() + kkey := trie.NewKey(251, kbs[:]) + require.True(t, trie.VerifyProof(root, &kkey, value, pnodes, tempTrie.HashFunc())) + } + + t.Run("Trie proofs sanity check", func(t *testing.T) { + kbs := key.Bytes() + kKey := trie.NewKey(251, kbs[:]) + proof, err := trie.GetProof(&kKey, tempTrie) + require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) + require.True(t, trie.VerifyProof(root, &kKey, value, proof, tempTrie.HashFunc())) + + // non-membership test + kbs = noSuchKey.Bytes() + kKey = trie.NewKey(251, kbs[:]) + proof, err = trie.GetProof(&kKey, tempTrie) + require.NoError(t, err) + require.True(t, trie.VerifyProof(root, &kKey, nil, proof, tempTrie.HashFunc())) + }) + t.Run("global roots are filled", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) + require.Nil(t, rpcErr) + + require.NotNil(t, proof) + require.NotNil(t, proof.GlobalRoots) + require.Equal(t, blkHash, proof.GlobalRoots.BlockHash) + require.Equal(t, clsRoot, proof.GlobalRoots.ClassesTreeRoot) + require.Equal(t, stgRoot, proof.GlobalRoots.ContractsTreeRoot) + }) + t.Run("error is returned whenever not latest block is requested", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: 1}, nil, nil, nil) + assert.Equal(t, rpc.ErrStorageProofNotSupported, rpcErr) + require.Nil(t, proof) + }) + t.Run("error is returned even when blknum matches head", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: blockNumber}, nil, nil, nil) + assert.Equal(t, rpc.ErrStorageProofNotSupported, rpcErr) + require.Nil(t, proof) + }) + t.Run("empty request", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 0, 0, 0) + }) + t.Run("class trie hash does not exist in a trie", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*noSuchKey}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 0, 0, 0) + verifyIf(proof.ClassesProof, noSuchKey, nil) + }) + t.Run("class trie hash exists in a trie", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 0, 0, 0) + verifyIf(proof.ClassesProof, key, value) + }) + t.Run("storage trie address does not exist in a trie", func(t *testing.T) { + mockState.EXPECT().ContractNonce(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(0) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*noSuchKey}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 3, 1, 0) + require.Nil(t, proof.ContractsProof.LeavesData[0]) + + verifyIf(proof.ContractsProof.Nodes, noSuchKey, nil) + }) + t.Run("storage trie address exists in a trie", func(t *testing.T) { + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil).Times(1) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 3, 1, 0) + + require.NotNil(t, proof.ContractsProof.LeavesData[0]) + ld := proof.ContractsProof.LeavesData[0] + require.Equal(t, nonce, ld.Nonce) + require.Equal(t, classHasah, ld.ClassHash) + + verifyIf(proof.ContractsProof.Nodes, key, value) + }) + t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { + contract := utils.HexToFelt(t, "0xdead") + mockTrie.EXPECT().StorageTrieForAddr(contract).Return(emptyTrie(t), nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 0) + }) + //nolint:dupl + t.Run("contract storage trie key slot does not exist in a trie", func(t *testing.T) { + contract := utils.HexToFelt(t, "0xabcd") + mockTrie.EXPECT().StorageTrieForAddr(gomock.Any()).Return(tempTrie, nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*noSuchKey}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 3) + + verifyIf(proof.ContractsStorageProofs[0], noSuchKey, nil) + }) + //nolint:dupl + t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { + contract := utils.HexToFelt(t, "0xabcd") + mockTrie.EXPECT().StorageTrieForAddr(gomock.Any()).Return(tempTrie, nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 3) + + verifyIf(proof.ContractsStorageProofs[0], key, value) + }) + t.Run("class & storage tries proofs requested", func(t *testing.T) { + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 3, 1, 0) + }) +} + +func arityTest(t *testing.T, + proof *rpc.StorageProofResult, + classesProofArity int, + contractsProofNodesArity int, + contractsProofLeavesArity int, + contractStorageArity int, +) { + require.Len(t, proof.ClassesProof, classesProofArity) + require.Len(t, proof.ContractsStorageProofs, contractStorageArity) + require.NotNil(t, proof.ContractsProof) + require.Len(t, proof.ContractsProof.Nodes, contractsProofNodesArity) + require.Len(t, proof.ContractsProof.LeavesData, contractsProofLeavesArity) +} + +func emptyTrie(t *testing.T) *trie.Trie { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + return tempTrie +} + +func TestStorageRoots(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + client := feeder.NewTestClient(t, &utils.Mainnet) + gw := adaptfeeder.New(client) + + log := utils.NewNopZapLogger() + testDB := pebble.NewMemTest(t) + bc := blockchain.New(testDB, &utils.Mainnet) + synchronizer := sync.New(bc, gw, log, time.Duration(0), false) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + + require.NoError(t, synchronizer.Run(ctx)) + cancel() + + var ( + expectedBlockHash = utils.HexToFelt(t, "0x4e1f77f39545afe866ac151ac908bd1a347a2a8a7d58bef1276db4f06fdf2f6") + expectedGlobalRoot = utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9") + expectedClsRoot = utils.HexToFelt(t, "0x0") + expectedStgRoot = utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9") + expectedContractAddress = utils.HexToFelt(t, "0x2d6c9569dea5f18628f1ef7c15978ee3093d2d3eec3b893aac08004e678ead3") + expectedContractLeaf = utils.HexToFelt(t, "0x7036d8dd68dc9539c6db8c88f72b1ab16e76d62b5f09118eca5ae78276b0ee4") + ) + + t.Run("sanity check - mainnet block 2", func(t *testing.T) { + expectedBlockNumber := uint64(2) + + blk, err := bc.Head() + assert.NoError(t, err) + assert.Equal(t, expectedBlockNumber, blk.Number) + assert.Equal(t, expectedBlockHash, blk.Hash, blk.Hash.String()) + assert.Equal(t, expectedGlobalRoot, blk.GlobalStateRoot, blk.GlobalStateRoot.String()) + }) + + t.Run("check class and storage roots matches the global", func(t *testing.T) { + reader, closer, err := bc.HeadTrie() + assert.NoError(t, err) + defer func() { _ = closer() }() + + stgRoot, clsRoot, err := reader.StateAndClassRoot() + assert.NoError(t, err) + + assert.Equal(t, expectedClsRoot, clsRoot, clsRoot.String()) + assert.Equal(t, expectedStgRoot, stgRoot, stgRoot.String()) + + verifyGlobalStateRoot(t, expectedGlobalRoot, clsRoot, stgRoot) + }) + + t.Run("check requested contract and storage slot exists", func(t *testing.T) { + trieReader, closer, err := bc.HeadTrie() + assert.NoError(t, err) + defer func() { _ = closer() }() + + sTrie, sCloser, err := trieReader.StorageTrie() + assert.NoError(t, err) + defer func() { _ = sCloser() }() + + leaf, err := sTrie.Get(expectedContractAddress) + assert.NoError(t, err) + assert.Equal(t, leaf, expectedContractLeaf, leaf.String()) + + stateReader, stCloser, err := bc.HeadState() + assert.NoError(t, err) + defer func() { _ = stCloser() }() + + clsHash, err := stateReader.ContractClassHash(expectedContractAddress) + assert.NoError(t, err) + assert.Equal(t, clsHash, utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), clsHash.String()) + }) + + t.Run("get contract proof", func(t *testing.T) { + handler := rpc.New(bc, nil, nil, "", log) + result, rpcErr := handler.StorageProof( + rpc.BlockID{Latest: true}, nil, []felt.Felt{*expectedContractAddress}, nil) + require.Nil(t, rpcErr) + + expectedResult := rpc.StorageProofResult{ + ClassesProof: []*rpc.HashToNode{}, + ContractsStorageProofs: [][]*rpc.HashToNode{}, + ContractsProof: &rpc.ContractProof{ + LeavesData: []*rpc.LeafData{ + { + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), + }, + }, + Nodes: []*rpc.HashToNode{ + { + Hash: utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x4e1f289e55ac8a821fd463478e6f5543256beb934a871be91d00a0d3f2e7964"), + Right: utils.HexToFelt(t, "0x67d9833b51e7bf1cab0e71e68477bf7f0b704391d753f9d793008e4f6587c53"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x4e1f289e55ac8a821fd463478e6f5543256beb934a871be91d00a0d3f2e7964"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x1ef87d62309ff1cad58d39e8f5480f9caa9acd78a43f139d87220a1babe38a4"), + Right: utils.HexToFelt(t, "0x9a258d24b3aeb7e263e910d68a18d85305703a2f20df2e806ecbb1fb28760f"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x9a258d24b3aeb7e263e910d68a18d85305703a2f20df2e806ecbb1fb28760f"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x53f61d0cb8099e2e7ffc214c4ef7ac8520abb5327510f84affe90b1890d314c"), + Right: utils.HexToFelt(t, "0x45ca67f381dcd01fec774743a4aaed6b36e1bda979185cf5dce538ad0007914"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x53f61d0cb8099e2e7ffc214c4ef7ac8520abb5327510f84affe90b1890d314c"), + Node: &rpc.MerkleBinaryNode{ + Left: utils.HexToFelt(t, "0x17d6fc8431c48e41222a3ede441d1e2d91c31eb67a8aa9c030c99c510e9f34c"), + Right: utils.HexToFelt(t, "0x1cf95259ae39c038e87224fa5fdb7c7eeba6dd4263e05e80c9a8e27c3240f2c"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x1cf95259ae39c038e87224fa5fdb7c7eeba6dd4263e05e80c9a8e27c3240f2c"), + Node: &rpc.MerkleEdgeNode{ + Path: "0x56c9569dea5f18628f1ef7c15978ee3093d2d3eec3b893aac08004e678ead3", + Length: 247, + Child: expectedContractLeaf, + }, + }, + }, + }, + GlobalRoots: &rpc.GlobalRoots{ + BlockHash: expectedBlockHash, + ClassesTreeRoot: expectedClsRoot, + ContractsTreeRoot: expectedStgRoot, + }, + } + + assert.Equal(t, expectedResult, *result) + }) +} + +func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { + stateVersion := new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) + if classRoot.IsZero() { + assert.Equal(t, globalStateRoot, storageRoot) + } else { + assert.Equal(t, globalStateRoot, crypto.PoseidonArray(stateVersion, storageRoot, classRoot)) + } +}