Skip to content

Commit

Permalink
fix: cache .get_code() (#2480)
Browse files Browse the repository at this point in the history
Co-authored-by: antazoey <[email protected]>
  • Loading branch information
antazoey and antazoey authored Jan 28, 2025
1 parent 601e2c4 commit 9b4184d
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 25 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ norecursedirs = "projects"
# And 'pytest_ethereum' is not used and causes issues in some environments.
addopts = """
-p no:pytest_ethereum
-p no:boa_test
"""

python_files = "test_*.py"
Expand Down
5 changes: 2 additions & 3 deletions src/ape/api/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,8 @@ def code(self) -> "ContractCode":
"""
The raw bytes of the smart-contract code at the address.
"""

# TODO: Explore caching this (based on `self.provider.network` and examining code)
return self.provider.get_code(self.address)
# NOTE: Chain manager handles code caching.
return self.chain_manager.get_code(self.address)

@property
def codesize(self) -> int:
Expand Down
6 changes: 0 additions & 6 deletions src/ape/managers/_contractscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,6 @@ def get(
self.contract_types[address_key] = contract_type_to_cache
return contract_type_to_cache

if not self.provider.get_code(address_key):
if default:
self.contract_types[address_key] = default

return default

# Also gets cached to disk for faster lookup next time.
if fetch_from_explorer:
contract_type = self._get_contract_type_from_explorer(address_key)
Expand Down
29 changes: 26 additions & 3 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
if TYPE_CHECKING:
from rich.console import Console as RichConsole

from ape.types.trace import GasReport, SourceTraceback
from ape.types.vm import SnapshotID
from ape.types import BlockID, ContractCode, GasReport, SnapshotID, SourceTraceback


class BlockContainer(BaseManager):
Expand Down Expand Up @@ -703,7 +702,7 @@ def _get_console(self, *args, **kwargs):
class ChainManager(BaseManager):
"""
A class for managing the state of the active blockchain.
Also handy for querying data about the chain and managing local caches.
Also, handy for querying data about the chain and managing local caches.
Access the chain manager singleton from the root ``ape`` namespace.
Usage example::
Expand All @@ -716,6 +715,7 @@ class ChainManager(BaseManager):
_block_container_map: dict[int, BlockContainer] = {}
_transaction_history_map: dict[int, TransactionHistory] = {}
_reports: ReportManager = ReportManager()
_code: dict[str, dict[str, dict[AddressType, "ContractCode"]]] = {}

@cached_property
def contracts(self) -> ContractCache:
Expand Down Expand Up @@ -965,3 +965,26 @@ def get_receipt(self, transaction_hash: str) -> ReceiptAPI:
raise TransactionNotFoundError(transaction_hash=transaction_hash)

return receipt

def get_code(
self, address: AddressType, block_id: Optional["BlockID"] = None
) -> "ContractCode":
network = self.provider.network

# Two reasons to avoid caching:
# 1. dev networks - chain isolation makes this mess up
# 2. specifying block_id= kwarg - likely checking if code
# exists at the time and shouldn't use cache.
skip_cache = network.is_dev or block_id is not None
if skip_cache:
return self.provider.get_code(address, block_id=block_id)

self._code.setdefault(network.ecosystem.name, {})
self._code[network.ecosystem.name].setdefault(network.name, {})
if address in self._code[network.ecosystem.name][network.name]:
return self._code[network.ecosystem.name][network.name][address]

# Get from RPC for the first time AND use cache.
code = self.provider.get_code(address)
self._code[network.ecosystem.name][network.name][address] = code
return code
17 changes: 10 additions & 7 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def encode_contract_blueprint(
)

def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfo]:
contract_code = self.provider.get_code(address)
contract_code = self.chain_manager.get_code(address)
if isinstance(contract_code, bytes):
contract_code = to_hex(contract_code)

Expand Down Expand Up @@ -1150,12 +1150,15 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:
except KeyError:
name = call["method_id"]
else:
assert isinstance(method_abi, MethodABI) # For mypy

# Check if method name duplicated. If that is the case, use selector.
times = len([x for x in contract_type.methods if x.name == method_abi.name])
name = (method_abi.name if times == 1 else method_abi.selector) or call["method_id"]
call = self._enrich_calldata(call, method_abi, **kwargs)
if isinstance(method_abi, MethodABI):
# Check if method name duplicated. If that is the case, use selector.
times = len([x for x in contract_type.methods if x.name == method_abi.name])
name = (method_abi.name if times == 1 else method_abi.selector) or call[
"method_id"
]
call = self._enrich_calldata(call, method_abi, **kwargs)
else:
name = call.get("method_id") or "0x"
else:
name = call.get("method_id") or "0x"

Expand Down
6 changes: 3 additions & 3 deletions src/ape_ethereum/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ def perform_contract_creation_query(
Find when a contract was deployed using binary search and block tracing.
"""
# skip the search if there is still no code at address at head
if not self.provider.get_code(query.contract):
if not self.chain_manager.get_code(query.contract):
return None

def find_creation_block(lo, hi):
# perform a binary search to find the block when the contract was deployed.
# takes log2(height), doesn't work with contracts that have been reinit.
while hi - lo > 1:
mid = (lo + hi) // 2
code = self.provider.get_code(query.contract, block_id=mid)
code = self.chain_manager.get_code(query.contract, block_id=mid)
if not code:
lo = mid
else:
hi = mid

if self.provider.get_code(query.contract, block_id=hi):
if self.chain_manager.get_code(query.contract, block_id=hi):
return hi

return None
Expand Down
7 changes: 6 additions & 1 deletion src/ape_test/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI":
account = self.init_test_account(
new_index, generated_account.address, generated_account.private_key
)
self.generated_accounts.append(account)

# Only cache if being created outside the expected number of accounts.
# Else, ends up cached twice and caused logic problems elsewhere.
if new_index >= self.number_of_accounts:
self.generated_accounts.append(account)

return account

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def create_mock_sepolia(ethereum, eth_tester_provider, vyper_contract_instance):
@contextmanager
def fn():
# Ensuring contract exists before hack.
# This allow the network to be past genesis which is more realistic.
# This allows the network to be past genesis which is more realistic.
_ = vyper_contract_instance
eth_tester_provider.network.name = "sepolia"
yield eth_tester_provider.network
Expand Down
18 changes: 18 additions & 0 deletions tests/functional/geth/test_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from tests.conftest import geth_process_test


@geth_process_test
def test_get_code(mocker, chain, geth_contract, mock_sepolia):
# NOTE: Using mock_sepolia because code doesn't get cached in local networks.
actual = chain.get_code(geth_contract.address)
expected = chain.provider.get_code(geth_contract.address)
assert actual == expected

# Ensure uses cache (via not using provider).
provider_spy = mocker.spy(chain.provider.web3.eth, "get_code")
_ = chain.get_code(geth_contract.address)
assert provider_spy.call_count == 0

# block_id test, cache should interfere.
actual_2 = chain.get_code(geth_contract.address, block_id=0)
assert not actual_2 # Doesn't exist at block 0.
2 changes: 1 addition & 1 deletion tests/functional/geth/test_contracts_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_contract_type(address, *args, **kwargs):
raise ValueError("Fake explorer only knows about proxy and target contracts.")

with create_mock_sepolia() as network:
# Setup our network to use our fake explorer.
# Set up our network to use our fake explorer.
mock_explorer.get_contract_type.side_effect = get_contract_type
network.__dict__["explorer"] = mock_explorer

Expand Down
1 change: 1 addition & 0 deletions tests/functional/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def test_isolate_in_tempdir_does_not_alter_sources(project):
# First, create a bad source.
with project.temp_config(contracts_folder="build"):
new_src = project.contracts_folder / "newsource.json"
new_src.parent.mkdir(exist_ok=True, parents=True)
new_src.write_text("this is not json, oops")
project.sources.refresh() # Only need to be called when run with other tests.

Expand Down
1 change: 1 addition & 0 deletions tests/functional/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


def test_minimal_proxy(ethereum, minimal_proxy, chain):
chain.provider.network.__dict__["explorer"] = None # Ensure no explorer, messes up test.
actual = ethereum.get_proxy_info(minimal_proxy.address)
assert actual is not None
assert actual.type == ProxyType.Minimal
Expand Down

0 comments on commit 9b4184d

Please sign in to comment.