diff --git a/crates/engine/local/src/service.rs b/crates/engine/local/src/service.rs index dd6a644acb8b7..5c55812b1dffa 100644 --- a/crates/engine/local/src/service.rs +++ b/crates/engine/local/src/service.rs @@ -27,7 +27,7 @@ use reth_engine_tree::{ RequestHandlerEvent, }, persistence::PersistenceHandle, - tree::{EngineApiTreeHandler, InvalidBlockHook, TreeConfig}, + tree::{root::BasicStateRootTaskFactory, EngineApiTreeHandler, InvalidBlockHook, TreeConfig}, }; use reth_evm::{execute::BlockExecutorProvider, ConfigureEvm}; use reth_node_types::{BlockTy, HeaderTy, TxTy}; @@ -95,8 +95,10 @@ where PersistenceHandle::::spawn_service(provider, pruner, sync_metrics_tx); let canonical_in_memory_state = blockchain_db.canonical_in_memory_state(); + let state_root_task_factory = BasicStateRootTaskFactory::new(); + let (to_tree_tx, from_tree) = - EngineApiTreeHandler::::spawn_new( + EngineApiTreeHandler::::spawn_new( blockchain_db.clone(), executor_factory, consensus, @@ -108,6 +110,7 @@ where invalid_block_hook, engine_kind, evm_config, + state_root_task_factory, ); let handler = EngineApiRequestHandler::new(to_tree_tx, from_tree); diff --git a/crates/engine/service/src/service.rs b/crates/engine/service/src/service.rs index f6b791e744c00..917e1f50332d9 100644 --- a/crates/engine/service/src/service.rs +++ b/crates/engine/service/src/service.rs @@ -8,7 +8,7 @@ use reth_engine_tree::{ download::BasicBlockDownloader, engine::{EngineApiKind, EngineApiRequest, EngineApiRequestHandler, EngineHandler}, persistence::PersistenceHandle, - tree::{EngineApiTreeHandler, InvalidBlockHook, TreeConfig}, + tree::{root::BasicStateRootTaskFactory, EngineApiTreeHandler, InvalidBlockHook, TreeConfig}, }; pub use reth_engine_tree::{ chain::{ChainEvent, ChainOrchestrator}, @@ -105,8 +105,10 @@ where let canonical_in_memory_state = blockchain_db.canonical_in_memory_state(); + let state_root_task_factory = BasicStateRootTaskFactory::new(); + let (to_tree_tx, from_tree) = - EngineApiTreeHandler::::spawn_new( + EngineApiTreeHandler::::spawn_new( blockchain_db, executor_factory, consensus, @@ -118,6 +120,7 @@ where invalid_block_hook, engine_kind, evm_config, + state_root_task_factory, ); let engine_handler = EngineApiRequestHandler::new(to_tree_tx, from_tree); diff --git a/crates/engine/tree/src/tree/mod.rs b/crates/engine/tree/src/tree/mod.rs index 912daa72b7bbf..85c2395786fbf 100644 --- a/crates/engine/tree/src/tree/mod.rs +++ b/crates/engine/tree/src/tree/mod.rs @@ -3,10 +3,6 @@ use crate::{ chain::FromOrchestrator, engine::{DownloadRequest, EngineApiEvent, EngineApiKind, EngineApiRequest, FromEngine}, persistence::PersistenceHandle, - tree::{ - cached_state::{CachedStateMetrics, CachedStateProvider, ProviderCacheBuilder}, - metrics::EngineApiMetrics, - }, }; use alloy_consensus::{transaction::Recovered, BlockHeader}; use alloy_eips::BlockNumHash; @@ -18,9 +14,11 @@ use alloy_primitives::{ use alloy_rpc_types_engine::{ ForkchoiceState, PayloadStatus, PayloadStatusEnum, PayloadValidationError, }; -use cached_state::{ProviderCaches, SavedCache}; +use cached_state::{ + CachedStateMetrics, CachedStateProvider, ProviderCacheBuilder, ProviderCaches, SavedCache, +}; use error::{InsertBlockError, InsertBlockErrorKind, InsertBlockFatalError}; -use metrics::PrewarmThreadMetrics; +use metrics::{EngineApiMetrics, PrewarmThreadMetrics}; use persistence_state::CurrentPersistenceAction; use reth_chain_state::{ CanonicalInMemoryState, ExecutedBlock, ExecutedBlockWithTrieUpdates, @@ -59,7 +57,8 @@ use reth_trie::{ use reth_trie_db::DatabaseTrieCursorFactory; use reth_trie_parallel::root::{ParallelStateRoot, ParallelStateRootError}; use root::{ - StateRootComputeOutcome, StateRootConfig, StateRootHandle, StateRootMessage, StateRootTask, + StateRootComputeHandle, StateRootComputeOutcome, StateRootConfig, StateRootMessage, + StateRootTaskFactory, StateRootTaskRunner, }; use std::{ cmp::Ordering, @@ -556,10 +555,11 @@ pub enum TreeAction { /// /// This type is responsible for processing engine API requests, maintaining the canonical state and /// emitting events. -pub struct EngineApiTreeHandler +pub struct EngineApiTreeHandler where N: NodePrimitives, T: EngineTypes, + F: StateRootTaskFactory

, { provider: P, executor_provider: E, @@ -603,14 +603,15 @@ where engine_kind: EngineApiKind, /// The most recent cache used for execution. most_recent_cache: Option, - /// Thread pool used for the state root task and prewarming - thread_pool: Arc, + /// Factory for state root tasks. + state_root_task_factory: F, } -impl std::fmt::Debug - for EngineApiTreeHandler +impl std::fmt::Debug + for EngineApiTreeHandler where N: NodePrimitives, + F: StateRootTaskFactory

, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EngineApiTreeHandler") @@ -634,7 +635,7 @@ where } } -impl EngineApiTreeHandler +impl EngineApiTreeHandler where N: NodePrimitives, P: DatabaseProviderFactory @@ -651,6 +652,7 @@ where C: ConfigureEvm

, T: EngineTypes, V: EngineValidator, + F: StateRootTaskFactory

, { /// Creates a new [`EngineApiTreeHandler`]. #[expect(clippy::too_many_arguments)] @@ -668,19 +670,10 @@ where config: TreeConfig, engine_kind: EngineApiKind, evm_config: C, + state_root_task_factory: F, ) -> Self { let (incoming_tx, incoming) = std::sync::mpsc::channel(); - let num_threads = root::thread_pool_size(); - - let thread_pool = Arc::new( - rayon::ThreadPoolBuilder::new() - .num_threads(num_threads) - .thread_name(|i| format!("srt-worker-{}", i)) - .build() - .expect("Failed to create proof worker thread pool"), - ); - Self { provider, executor_provider, @@ -701,7 +694,7 @@ where invalid_block_hook: Box::new(NoopInvalidBlockHook), engine_kind, most_recent_cache: None, - thread_pool, + state_root_task_factory, } } @@ -728,6 +721,7 @@ where invalid_block_hook: Box>, kind: EngineApiKind, evm_config: C, + state_root_task_factory: F, ) -> (Sender, N::Block>>, UnboundedReceiver>) { let best_block_number = provider.best_block_number().unwrap_or(0); @@ -760,6 +754,7 @@ where config, kind, evm_config, + state_root_task_factory, ); task.set_invalid_block_hook(invalid_block_hook); let incoming = task.incoming_tx.clone(); @@ -2462,9 +2457,9 @@ where .set(config_elapsed.as_secs_f64()); let state_root_task = - StateRootTask::new(state_root_config.clone(), self.thread_pool.clone()); + self.state_root_task_factory.create_task(state_root_config.clone()); let state_root_sender = state_root_task.state_root_message_sender(); - let state_hook = Box::new(state_root_task.state_hook()) as Box; + let state_hook = state_root_task.state_hook() as Box; ( Some(state_root_task.spawn()), Some(state_root_config), @@ -2744,7 +2739,7 @@ where let evm_config = self.evm_config.clone(); // spawn task executing the individual tx - self.thread_pool.spawn(move || { + self.state_root_task_factory.thread_pool().spawn(move || { let thread_start = Instant::now(); let in_progress = task_finished.read().unwrap(); @@ -2898,7 +2893,7 @@ where /// the hash builder-based state root calculation if it fails. fn handle_state_root_result( &self, - state_root_handle: StateRootHandle, + state_root_handle: ::ResultHandle, state_root_task_config: StateRootConfig

, sealed_block: &SealedBlock, hashed_state: &HashedPostState, @@ -3235,7 +3230,10 @@ pub enum InsertPayloadOk { #[cfg(test)] mod tests { - use super::*; + use super::{ + root::{StateRootComputeHandle, StateRootTaskRunner}, + *, + }; use crate::persistence::PersistenceAction; use alloy_consensus::Header; use alloy_primitives::Bytes; @@ -3251,11 +3249,12 @@ mod tests { use reth_ethereum_consensus::EthBeaconConsensus; use reth_ethereum_engine_primitives::{EthEngineTypes, EthereumEngineValidator}; use reth_ethereum_primitives::{Block, EthPrimitives}; - use reth_evm::test_utils::MockExecutorProvider; + use reth_evm::{system_calls::StateChangeSource, test_utils::MockExecutorProvider}; use reth_evm_ethereum::EthEvmConfig; use reth_primitives_traits::Block as _; use reth_provider::test_utils::MockEthProvider; use reth_trie::{updates::TrieUpdates, HashedPostState}; + use revm_primitives::EvmState; use std::{ str::FromStr, sync::mpsc::{channel, Sender}, @@ -3316,6 +3315,73 @@ mod tests { } } + struct MockStateRootHandle { + root: B256, + } + + impl StateRootComputeHandle for MockStateRootHandle { + fn wait_for_result(self) -> Result { + Ok(StateRootComputeOutcome { + state_root: (self.root, TrieUpdates::default()), + total_time: Duration::from_secs(0), + time_from_last_update: Duration::from_secs(0), + }) + } + } + + struct MockStateRootTask { + root: B256, + } + + impl MockStateRootTask { + fn new(root: B256) -> Self { + Self { root } + } + } + + impl StateRootTaskRunner for MockStateRootTask { + type ResultHandle = MockStateRootHandle; + + fn spawn(self) -> Self::ResultHandle { + MockStateRootHandle { root: self.root } + } + + fn state_hook(&self) -> Box { + Box::new(move |_: StateChangeSource, _: &EvmState| {}) + } + + fn state_root_message_sender(&self) -> Sender { + let (tx, _rx) = channel(); + tx + } + } + + struct MockStateRootTaskFactory { + roots: Vec, + } + + impl MockStateRootTaskFactory { + fn new() -> Self { + Self { roots: Vec::new() } + } + + fn add_state_root(&mut self, root: B256) { + self.roots.push(root); + } + } + + impl StateRootTaskFactory for MockStateRootTaskFactory { + type Runner = MockStateRootTask; + + fn create_task(&mut self, _config: StateRootConfig) -> Self::Runner { + MockStateRootTask::new(self.roots.pop().unwrap()) + } + + fn thread_pool(&self) -> Arc { + Arc::new(rayon::ThreadPoolBuilder::new().build().unwrap()) + } + } + struct TestHarness { tree: EngineApiTreeHandler< EthPrimitives, @@ -3324,6 +3390,7 @@ mod tests { EthEngineTypes, EthereumEngineValidator, EthEvmConfig, + MockStateRootTaskFactory, >, to_tree_tx: Sender, Block>>, from_tree_rx: UnboundedReceiver, @@ -3372,6 +3439,8 @@ mod tests { let evm_config = EthEvmConfig::new(chain_spec.clone()); + let state_root_task_factory = MockStateRootTaskFactory::new(); + let tree = EngineApiTreeHandler::new( provider.clone(), executor_provider.clone(), @@ -3383,10 +3452,10 @@ mod tests { persistence_handle, PersistenceState::default(), payload_builder, - // TODO: fix tests for state root task https://github.com/paradigmxyz/reth/issues/14376 - TreeConfig::default().with_legacy_state_root(true), + TreeConfig::default(), EngineApiKind::Ethereum, evm_config, + state_root_task_factory, ); let block_builder = TestBlockBuilder::default().with_chain_spec((*chain_spec).clone()); @@ -3463,7 +3532,7 @@ mod tests { ) -> Result> { let execution_outcome = self.block_builder.get_execution_outcome(block.clone()); self.extend_execution_outcome([execution_outcome]); - self.tree.provider.add_state_root(block.state_root); + self.tree.state_root_task_factory.add_state_root(block.state_root); self.tree.insert_block(block) } @@ -3666,7 +3735,7 @@ mod tests { } else { block.state_root }; - self.tree.provider.add_state_root(state_root); + self.tree.state_root_task_factory.add_state_root(state_root); execution_outcomes.push(execution_outcome); } self.extend_execution_outcome(execution_outcomes); diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 463c41927b6e2..e06ca25fdb01b 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -492,6 +492,110 @@ struct StateRootTaskMetrics { pub state_root_iterations_histogram: Histogram, } +/// Result provided by a state root calculation +pub trait StateRootComputeHandle: Send + 'static { + /// Waits for the state root calculation to complete and returns the result. + fn wait_for_result(self) -> Result; +} + +impl StateRootComputeHandle for StateRootHandle { + fn wait_for_result(self) -> Result { + self.wait_for_result() + } +} + +/// Public interface of the state root task. +pub trait StateRootTaskRunner: Send + 'static { + /// Type of state root result returned by this runner + type ResultHandle: StateRootComputeHandle; + + /// Spawns the state root task and returns a handle to await its result. + fn spawn(self) -> Self::ResultHandle; + + /// Returns a state hook that can be used to send state updates to this task. + fn state_hook(&self) -> Box; + + /// Returns a [`StateRootMessage`] sender. + fn state_root_message_sender(&self) -> Sender; +} + +impl StateRootTaskRunner for StateRootTask +where + Factory: + DatabaseProviderFactory + StateCommitmentProvider + Clone + 'static, +{ + type ResultHandle = StateRootHandle; + + fn spawn(self) -> StateRootHandle { + self.spawn() + } + + fn state_hook(&self) -> Box { + Box::new(self.state_hook()) + } + + fn state_root_message_sender(&self) -> Sender { + self.state_root_message_sender() + } +} + +/// Factory trait for creating state root task runners. +pub trait StateRootTaskFactory: Send + 'static { + /// The type of task runner this factory creates. + type Runner: StateRootTaskRunner; + + /// Creates a new state root task runner. + fn create_task(&mut self, config: StateRootConfig) -> Self::Runner; + + /// Creates a new state root task runner. + fn thread_pool(&self) -> Arc; +} + +/// Factory for creating real state root tasks. +#[derive(Debug)] +pub struct BasicStateRootTaskFactory { + /// Thread pool used for parallel proof generation. + thread_pool: Arc, +} + +impl BasicStateRootTaskFactory { + /// Creates a new factory. + pub fn new() -> Self { + let num_threads = thread_pool_size(); + let thread_pool = Arc::new( + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .thread_name(|i| format!("srt-worker-{}", i)) + .build() + .expect("Failed to create proof worker thread pool"), + ); + + Self { thread_pool } + } +} + +impl Default for BasicStateRootTaskFactory { + fn default() -> Self { + Self::new() + } +} + +impl StateRootTaskFactory for BasicStateRootTaskFactory +where + Provider: + DatabaseProviderFactory + StateCommitmentProvider + Clone + 'static, +{ + type Runner = StateRootTask; + + fn create_task(&mut self, config: StateRootConfig) -> Self::Runner { + StateRootTask::new(config, self.thread_pool.clone()) + } + + fn thread_pool(&self) -> Arc { + self.thread_pool.clone() + } +} + /// Standalone task that receives a transaction state stream and updates relevant /// data structures to calculate state root. ///