diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 6855cbce88..1206bbe087 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -1005,7 +1005,6 @@ def chain_id(self) -> int: **NOTE**: Unless overridden, returns same as :py:attr:`ape.api.providers.ProviderAPI.chain_id`. """ - return self.provider.chain_id @property diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index 0a1866a4b6..475f00577f 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -757,7 +757,6 @@ def chain_id(self) -> int: The blockchain ID. See `ChainList `__ for a comprehensive list of IDs. """ - network_name = self.provider.network.name if network_name not in self._chain_id_map: self._chain_id_map[network_name] = self.provider.chain_id diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index a5989084ad..40a3706890 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -572,13 +572,13 @@ def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] = @cached_property def chain_id(self) -> int: default_chain_id = None - if self.network.name != "custom" and not self.network.is_dev: - # If using a live network, the chain ID is hardcoded. + if self.network.name not in ("adhoc", "custom") and not self.network.is_dev: + # If using a live plugin-based network, the chain ID is hardcoded. default_chain_id = self.network.chain_id try: if hasattr(self.web3, "eth"): - return self.web3.eth.chain_id + return self._get_chain_id() except ProviderNotConnectedError: if default_chain_id is not None: @@ -586,6 +586,14 @@ def chain_id(self) -> int: raise # Original error + except ValueError as err: + # Possible syncing error. + raise ProviderError( + err.args[0].get("message") + if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict))) + else "Error getting chain ID." + ) + if default_chain_id is not None: return default_chain_id @@ -606,6 +614,10 @@ def priority_fee(self) -> int: "eth_maxPriorityFeePerGas not supported in this RPC. Please specify manually." ) from err + def _get_chain_id(self) -> int: + result = self.make_request("eth_chainId", []) + return result if isinstance(result, int) else int(result, 16) + def get_block(self, block_id: "BlockID") -> BlockAPI: if isinstance(block_id, str) and block_id.isnumeric(): block_id = int(block_id) @@ -1603,15 +1615,7 @@ def _complete_connect(self): if not self.network.is_dev: self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) - # Check for chain errors, including syncing - try: - chain_id = self.web3.eth.chain_id - except ValueError as err: - raise ProviderError( - err.args[0].get("message") - if all((hasattr(err, "args"), err.args, isinstance(err.args[0], dict))) - else "Error getting chain id." - ) + chain_id = self.chain_id # NOTE: We have to check both earliest and latest # because if the chain was _ever_ PoA, we need diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index 2af47b08ab..b2f32b8a57 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -240,7 +240,14 @@ def test_connect_to_chain_that_started_poa(mock_web3, web3_factory, ethereum): to fetch blocks during the PoA portion of the chain. """ mock_web3.eth.get_block.side_effect = ExtraDataLengthError - mock_web3.eth.chain_id = ethereum.sepolia.chain_id + + def make_request(rpc, arguments): + if rpc == "eth_chainId": + return {"result": ethereum.sepolia.chain_id} + + return None + + mock_web3.provider.make_request.side_effect = make_request web3_factory.return_value = mock_web3 provider = ethereum.sepolia.get_provider("node") provider.provider_settings = {"uri": "http://node.example.com"} # fake diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index 57b6b6d5e1..ae959ec3af 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -105,6 +105,35 @@ def test_chain_id_is_cached(eth_tester_provider): eth_tester_provider._web3 = web3 # Undo +def test_chain_id_from_ethereum_base_provider_is_cached(mock_web3, ethereum, eth_tester_provider): + """ + Simulated chain ID from a plugin (using base-ethereum class) to ensure is + also cached. + """ + + def make_request(rpc, arguments): + if rpc == "eth_chainId": + return {"result": 11155111} # Sepolia + + return eth_tester_provider.make_request(rpc, arguments) + + mock_web3.provider.make_request.side_effect = make_request + + class PluginProvider(Web3Provider): + def connect(self): + return + + def disconnect(self): + return + + provider = PluginProvider(name="sim", network=ethereum.sepolia) + provider._web3 = mock_web3 + assert provider.chain_id == 11155111 + # Unset to web3 to prove it does not check it again (else it would fail). + provider._web3 = None + assert provider.chain_id == 11155111 + + def test_chain_id_when_disconnected(eth_tester_provider): eth_tester_provider.disconnect() try: @@ -658,3 +687,32 @@ def test_update_settings_invalidates_snapshots(eth_tester_provider, chain): assert snapshot in chain._snapshots[eth_tester_provider.chain_id] eth_tester_provider.update_settings({}) assert snapshot not in chain._snapshots[eth_tester_provider.chain_id] + + +def test_connect_uses_cached_chain_id(mocker, mock_web3, ethereum, eth_tester_provider): + class PluginProvider(EthereumNodeProvider): + pass + + web3_factory_patch = mocker.patch("ape_ethereum.provider._create_web3") + web3_factory_patch.return_value = mock_web3 + + class ChainIDTracker: + call_count = 0 + + def make_request(self, rpc, args): + if rpc == "eth_chainId": + self.call_count += 1 + return {"result": "0xaa36a7"} # Sepolia + + return eth_tester_provider.make_request(rpc, args) + + chain_id_tracker = ChainIDTracker() + mock_web3.provider.make_request.side_effect = chain_id_tracker.make_request + + provider = PluginProvider(name="node", network=ethereum.sepolia) + provider.connect() + assert chain_id_tracker.call_count == 1 + provider.disconnect() + provider.connect() + # It is still cached from the previous connection. + assert chain_id_tracker.call_count == 1