From 4c5b8538e5e9bc84e0964ea77ebbe98d29872316 Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 10 Oct 2023 16:44:22 -0500 Subject: [PATCH] feat: use OTS namespace for get contract creation receipt in ape-geth [APE-1453] (#1697) --- src/ape/api/providers.py | 2 + src/ape_geth/__init__.py | 6 +++ src/ape_geth/provider.py | 60 +++++++++++++++++++++++++---- src/ape_geth/query.py | 35 +++++++++++++++++ tests/functional/geth/test_query.py | 31 +++++++++++++++ 5 files changed, 127 insertions(+), 7 deletions(-) create mode 100644 src/ape_geth/query.py create mode 100644 tests/functional/geth/test_query.py diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index 5e52f79671..b3399c0230 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -1358,6 +1358,7 @@ def get_contract_creation_receipts( # TODO: Handle when code is nonzero but doesn't match # TODO: Handle when code is empty after it's not (re-init) elif HexBytes(self.get_code(address, block_id=mid_block)) == contract_code: + # If the code exists, we need to look backwards. yield from self.get_contract_creation_receipts( address, start_block=start_block, @@ -1366,6 +1367,7 @@ def get_contract_creation_receipts( ) elif mid_block + 1 <= stop_block: + # The code does not exist yet, we need to look ahead. yield from self.get_contract_creation_receipts( address, start_block=mid_block + 1, diff --git a/src/ape_geth/__init__.py b/src/ape_geth/__init__.py index 93efbd0009..05b66fe3bd 100644 --- a/src/ape_geth/__init__.py +++ b/src/ape_geth/__init__.py @@ -3,6 +3,7 @@ from .provider import Geth as GethProvider from .provider import GethConfig, GethDev, GethNetworkConfig +from .query import OTSQueryEngine @plugins.register(plugins.Config) @@ -18,3 +19,8 @@ def providers(): yield "ethereum", network_name, GethProvider yield "ethereum", LOCAL_NETWORK_NAME, GethDev + + +@plugins.register(plugins.QueryPlugin) +def query_engines(): + yield OTSQueryEngine diff --git a/src/ape_geth/provider.py b/src/ape_geth/provider.py index 2d414c130f..872a85a879 100644 --- a/src/ape_geth/provider.py +++ b/src/ape_geth/provider.py @@ -3,6 +3,7 @@ import shutil import sys from abc import ABC +from functools import cached_property from itertools import tee from pathlib import Path from subprocess import DEVNULL, PIPE, Popen @@ -38,15 +39,21 @@ from ape._pydantic_compat import Extra from ape.api import ( PluginConfig, + ReceiptAPI, SubprocessProvider, TestProviderAPI, TransactionAPI, UpstreamProvider, Web3Provider, ) -from ape.exceptions import APINotImplementedError, ProviderError +from ape.exceptions import ( + ApeException, + APINotImplementedError, + ContractNotFoundError, + ProviderError, +) from ape.logging import LogLevel, logger -from ape.types import CallTreeNode, SnapshotID, SourceTraceback, TraceFrame +from ape.types import AddressType, CallTreeNode, SnapshotID, SourceTraceback, TraceFrame from ape.utils import ( DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_CHAIN_ID, @@ -290,23 +297,43 @@ def data_dir(self) -> Path: return _get_default_data_dir() + @cached_property + def _ots_api_level(self) -> Optional[int]: + # NOTE: Returns None when OTS namespace is not enabled. + try: + result = self._make_request("ots_getApiLevel") + except (NotImplementedError, ApeException, ValueError): + return None + + if isinstance(result, int): + return result + + elif isinstance(result, str) and result.isnumeric(): + return int(result) + + return None + def _set_web3(self): self._client_version = None # Clear cached version when connecting to another URI. self._web3 = _create_web3(self.uri, ipc_path=self.ipc_path) def _complete_connect(self): - if "geth" in self.client_version.lower(): + client_version = self.client_version.lower() + if "geth" in client_version: self._log_connection("Geth") - elif "erigon" in self.client_version.lower(): + elif "reth" in client_version: + self._log_connection("Reth") + elif "erigon" in client_version: self._log_connection("Erigon") self.concurrency = 8 self.block_page_size = 40_000 - elif "nethermind" in self.client_version.lower(): + elif "nethermind" in client_version: self._log_connection("Nethermind") self.concurrency = 32 self.block_page_size = 50_000 else: - client_name = self.client_version.split("/")[0] + client_name = client_version.split("/")[0] + logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.") logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.") self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) @@ -399,7 +426,26 @@ def _log_connection(self, client_name: str): ) logger.info(f"{msg} {suffix}.") - def _make_request(self, endpoint: str, parameters: List) -> Any: + def ots_get_contract_creator(self, address: AddressType) -> Optional[Dict]: + if self._ots_api_level is None: + return None + + result = self._make_request("ots_getContractCreator", [address]) + if result is None: + # NOTE: Skip the explorer part of the error message via `has_explorer=True`. + raise ContractNotFoundError(address, has_explorer=True, provider_name=self.name) + + return result + + def _get_contract_creation_receipt(self, address: AddressType) -> Optional[ReceiptAPI]: + if result := self.ots_get_contract_creator(address): + tx_hash = result["hash"] + return self.get_receipt(tx_hash) + + return None + + def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: + parameters = parameters or [] try: return super()._make_request(endpoint, parameters) except ProviderError as err: diff --git a/src/ape_geth/query.py b/src/ape_geth/query.py new file mode 100644 index 0000000000..98d6de374f --- /dev/null +++ b/src/ape_geth/query.py @@ -0,0 +1,35 @@ +from functools import singledispatchmethod +from typing import Iterator, Optional + +from ape.api import ReceiptAPI +from ape.api.query import ContractCreationQuery, QueryAPI, QueryType +from ape.exceptions import QueryEngineError +from ape_geth.provider import BaseGethProvider + + +class OTSQueryEngine(QueryAPI): + @singledispatchmethod + def estimate_query(self, query: QueryType) -> Optional[int]: # type: ignore[override] + return None + + @singledispatchmethod + def perform_query(self, query: QueryType) -> Iterator: # type: ignore[override] + raise QueryEngineError( + f"{self.__class__.__name__} cannot handle {query.__class__.__name__} queries." + ) + + @estimate_query.register + def estimate_contract_creation_query(self, query: ContractCreationQuery) -> Optional[int]: + if provider := self.network_manager.active_provider: + if not isinstance(provider, BaseGethProvider): + return None + elif uri := provider.http_uri: + return 225 if uri.startswith("http://") else 600 + + return None + + @perform_query.register + def get_contract_creation_receipt(self, query: ContractCreationQuery) -> Iterator[ReceiptAPI]: + if self.network_manager.active_provider and isinstance(self.provider, BaseGethProvider): + if receipt := self.provider._get_contract_creation_receipt(query.contract): + yield receipt diff --git a/tests/functional/geth/test_query.py b/tests/functional/geth/test_query.py new file mode 100644 index 0000000000..0416cb0543 --- /dev/null +++ b/tests/functional/geth/test_query.py @@ -0,0 +1,31 @@ +from typing import List, Tuple + +from ape.exceptions import ChainError +from tests.conftest import geth_process_test + + +@geth_process_test +def test_get_contract_creation_receipts(mock_geth, geth_contract, chain, networks, geth_provider): + geth_provider.__dict__["explorer"] = None + provider = networks.active_provider + networks.active_provider = mock_geth + mock_geth._web3.eth.get_block.side_effect = geth_provider.get_block + + try: + mock_geth._web3.eth.get_code.return_value = b"123" + + # NOTE: Due to mocks, this next part may not actually find the contract. + # but that is ok but we mostly want to make sure it tries OTS. There + # are other tests for the brute-force logic. + try: + next(chain.contracts.get_creation_receipt(geth_contract.address), None) + except ChainError: + pass + + # Ensure we tried using OTS. + actual = mock_geth._web3.provider.make_request.call_args + expected: Tuple[str, List] = ("ots_getApiLevel", []) + assert any(arguments == expected for arguments in actual) + + finally: + networks.active_provider = provider