diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index dc0563ade502..ae22b036b65f 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -1,6 +1,7 @@ //! State root task related functionality. use alloy_primitives::map::{HashMap, HashSet}; +use reth_evm::system_calls::OnStateHook; use reth_provider::{ providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, StateCommitmentProvider, @@ -20,7 +21,7 @@ use std::{ collections::BTreeMap, ops::Deref, sync::{ - mpsc::{self, Receiver, Sender}, + mpsc::{self, channel, Receiver, Sender}, Arc, }, time::{Duration, Instant}, @@ -249,11 +250,9 @@ where + 'static, { /// Creates a new state root task with the unified message channel - pub(crate) fn new( - config: StateRootConfig, - tx: Sender, - rx: Receiver, - ) -> Self { + pub(crate) fn new(config: StateRootConfig) -> Self { + let (tx, rx) = channel(); + Self { config, rx, @@ -279,6 +278,15 @@ where StateRootHandle::new(rx) } + /// Returns a state hook to be used to send state updates to this task. + pub(crate) fn state_hook(&self) -> impl OnStateHook { + let state_hook = StateHookSender::new(self.tx.clone()); + + move |state: &EvmState| { + let _ = state_hook.send(StateRootMessage::StateUpdate(state.clone())); + } + } + /// Handles state updates. /// /// Returns proof targets derived from the state update. @@ -670,7 +678,6 @@ mod tests { reth_tracing::init_test_tracing(); let factory = create_test_provider_factory(); - let (tx, rx) = std::sync::mpsc::channel(); let state_updates = create_mock_state_updates(10, 10); let mut hashed_state = HashedPostState::default(); @@ -721,16 +728,14 @@ mod tests { consistent_view: ConsistentDbView::new(factory, None), input: Arc::new(TrieInput::from_state(hashed_state)), }; - let task = StateRootTask::new(config, tx.clone(), rx); + let task = StateRootTask::new(config); + let mut state_hook = task.state_hook(); let handle = task.spawn(); - let state_hook_sender = StateHookSender::new(tx); for update in state_updates { - state_hook_sender - .send(StateRootMessage::StateUpdate(update)) - .expect("failed to send state"); + state_hook.on_state(&update); } - drop(state_hook_sender); + drop(state_hook); let (root_from_task, _) = handle.wait_for_result().expect("task failed"); let root_from_base = state_root(accumulated_state);