Skip to content

Commit

Permalink
feat: use OTS namespace for get contract creation receipt in ape-geth…
Browse files Browse the repository at this point in the history
… [APE-1453] (#1697)
  • Loading branch information
antazoey authored Oct 10, 2023
1 parent 69529fe commit 4c5b853
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,7 @@ def get_contract_creation_receipts(
# TODO: Handle when code is nonzero but doesn't match
# TODO: Handle when code is empty after it's not (re-init)
elif HexBytes(self.get_code(address, block_id=mid_block)) == contract_code:
# If the code exists, we need to look backwards.
yield from self.get_contract_creation_receipts(
address,
start_block=start_block,
Expand All @@ -1366,6 +1367,7 @@ def get_contract_creation_receipts(
)

elif mid_block + 1 <= stop_block:
# The code does not exist yet, we need to look ahead.
yield from self.get_contract_creation_receipts(
address,
start_block=mid_block + 1,
Expand Down
6 changes: 6 additions & 0 deletions src/ape_geth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .provider import Geth as GethProvider
from .provider import GethConfig, GethDev, GethNetworkConfig
from .query import OTSQueryEngine


@plugins.register(plugins.Config)
Expand All @@ -18,3 +19,8 @@ def providers():
yield "ethereum", network_name, GethProvider

yield "ethereum", LOCAL_NETWORK_NAME, GethDev


@plugins.register(plugins.QueryPlugin)
def query_engines():
yield OTSQueryEngine
60 changes: 53 additions & 7 deletions src/ape_geth/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import sys
from abc import ABC
from functools import cached_property
from itertools import tee
from pathlib import Path
from subprocess import DEVNULL, PIPE, Popen
Expand Down Expand Up @@ -38,15 +39,21 @@
from ape._pydantic_compat import Extra
from ape.api import (
PluginConfig,
ReceiptAPI,
SubprocessProvider,
TestProviderAPI,
TransactionAPI,
UpstreamProvider,
Web3Provider,
)
from ape.exceptions import APINotImplementedError, ProviderError
from ape.exceptions import (
ApeException,
APINotImplementedError,
ContractNotFoundError,
ProviderError,
)
from ape.logging import LogLevel, logger
from ape.types import CallTreeNode, SnapshotID, SourceTraceback, TraceFrame
from ape.types import AddressType, CallTreeNode, SnapshotID, SourceTraceback, TraceFrame
from ape.utils import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_CHAIN_ID,
Expand Down Expand Up @@ -290,23 +297,43 @@ def data_dir(self) -> Path:

return _get_default_data_dir()

@cached_property
def _ots_api_level(self) -> Optional[int]:
# NOTE: Returns None when OTS namespace is not enabled.
try:
result = self._make_request("ots_getApiLevel")
except (NotImplementedError, ApeException, ValueError):
return None

if isinstance(result, int):
return result

elif isinstance(result, str) and result.isnumeric():
return int(result)

return None

def _set_web3(self):
self._client_version = None # Clear cached version when connecting to another URI.
self._web3 = _create_web3(self.uri, ipc_path=self.ipc_path)

def _complete_connect(self):
if "geth" in self.client_version.lower():
client_version = self.client_version.lower()
if "geth" in client_version:
self._log_connection("Geth")
elif "erigon" in self.client_version.lower():
elif "reth" in client_version:
self._log_connection("Reth")
elif "erigon" in client_version:
self._log_connection("Erigon")
self.concurrency = 8
self.block_page_size = 40_000
elif "nethermind" in self.client_version.lower():
elif "nethermind" in client_version:
self._log_connection("Nethermind")
self.concurrency = 32
self.block_page_size = 50_000
else:
client_name = self.client_version.split("/")[0]
client_name = client_version.split("/")[0]
logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.")
logger.warning(f"Connecting Geth plugin to non-Geth client '{client_name}'.")

self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy)
Expand Down Expand Up @@ -399,7 +426,26 @@ def _log_connection(self, client_name: str):
)
logger.info(f"{msg} {suffix}.")

def _make_request(self, endpoint: str, parameters: List) -> Any:
def ots_get_contract_creator(self, address: AddressType) -> Optional[Dict]:
if self._ots_api_level is None:
return None

result = self._make_request("ots_getContractCreator", [address])
if result is None:
# NOTE: Skip the explorer part of the error message via `has_explorer=True`.
raise ContractNotFoundError(address, has_explorer=True, provider_name=self.name)

return result

def _get_contract_creation_receipt(self, address: AddressType) -> Optional[ReceiptAPI]:
if result := self.ots_get_contract_creator(address):
tx_hash = result["hash"]
return self.get_receipt(tx_hash)

return None

def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any:
parameters = parameters or []
try:
return super()._make_request(endpoint, parameters)
except ProviderError as err:
Expand Down
35 changes: 35 additions & 0 deletions src/ape_geth/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from functools import singledispatchmethod
from typing import Iterator, Optional

from ape.api import ReceiptAPI
from ape.api.query import ContractCreationQuery, QueryAPI, QueryType
from ape.exceptions import QueryEngineError
from ape_geth.provider import BaseGethProvider


class OTSQueryEngine(QueryAPI):
@singledispatchmethod
def estimate_query(self, query: QueryType) -> Optional[int]: # type: ignore[override]
return None

@singledispatchmethod
def perform_query(self, query: QueryType) -> Iterator: # type: ignore[override]
raise QueryEngineError(
f"{self.__class__.__name__} cannot handle {query.__class__.__name__} queries."
)

@estimate_query.register
def estimate_contract_creation_query(self, query: ContractCreationQuery) -> Optional[int]:
if provider := self.network_manager.active_provider:
if not isinstance(provider, BaseGethProvider):
return None
elif uri := provider.http_uri:
return 225 if uri.startswith("http://") else 600

return None

@perform_query.register
def get_contract_creation_receipt(self, query: ContractCreationQuery) -> Iterator[ReceiptAPI]:
if self.network_manager.active_provider and isinstance(self.provider, BaseGethProvider):
if receipt := self.provider._get_contract_creation_receipt(query.contract):
yield receipt
31 changes: 31 additions & 0 deletions tests/functional/geth/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Tuple

from ape.exceptions import ChainError
from tests.conftest import geth_process_test


@geth_process_test
def test_get_contract_creation_receipts(mock_geth, geth_contract, chain, networks, geth_provider):
geth_provider.__dict__["explorer"] = None
provider = networks.active_provider
networks.active_provider = mock_geth
mock_geth._web3.eth.get_block.side_effect = geth_provider.get_block

try:
mock_geth._web3.eth.get_code.return_value = b"123"

# NOTE: Due to mocks, this next part may not actually find the contract.
# but that is ok but we mostly want to make sure it tries OTS. There
# are other tests for the brute-force logic.
try:
next(chain.contracts.get_creation_receipt(geth_contract.address), None)
except ChainError:
pass

# Ensure we tried using OTS.
actual = mock_geth._web3.provider.make_request.call_args
expected: Tuple[str, List] = ("ots_getApiLevel", [])
assert any(arguments == expected for arguments in actual)

finally:
networks.active_provider = provider

0 comments on commit 4c5b853

Please sign in to comment.