diff --git a/crates/starknet_integration_tests/src/state_reader.rs b/crates/starknet_integration_tests/src/state_reader.rs index 7e752b2714..bf0ff923e5 100644 --- a/crates/starknet_integration_tests/src/state_reader.rs +++ b/crates/starknet_integration_tests/src/state_reader.rs @@ -63,6 +63,8 @@ impl StorageTestSetup { path: Option, ) -> Self { let preset_test_contracts = PresetTestContracts::new(); + let classes = + TestClasses::new(test_defined_accounts.clone(), preset_test_contracts.clone()); let batcher_db_path = path.as_ref().map(|p| p.join("batcher")); let ((_, mut batcher_storage_writer), batcher_storage_config, batcher_storage_handle) = @@ -75,6 +77,7 @@ impl StorageTestSetup { chain_info, test_defined_accounts.clone(), preset_test_contracts.clone(), + &classes, ); let state_sync_db_path = path.as_ref().map(|p| p.join("state_sync")); @@ -91,6 +94,7 @@ impl StorageTestSetup { chain_info, test_defined_accounts, preset_test_contracts, + &classes, ); // TODO(Yair): restructure this. @@ -124,12 +128,14 @@ fn create_test_state( chain_info: &ChainInfo, test_defined_accounts: Vec, preset_test_contracts: PresetTestContracts, + classes: &TestClasses, ) { initialize_papyrus_test_state( storage_writer, chain_info, test_defined_accounts, preset_test_contracts, + classes, ); } @@ -160,30 +166,40 @@ impl PresetTestContracts { } } +struct TestClasses { + pub cairo0_contract_classes: Vec<(ClassHash, DeprecatedContractClass)>, + pub sierra_vec: Vec<(ClassHash, SierraContractClass)>, + pub cairo1_contract_classes: Vec<(ClassHash, CasmContractClass)>, +} + +impl TestClasses { + pub fn new( + test_defined_accounts: Vec, + preset_test_contracts: PresetTestContracts, + ) -> TestClasses { + let contract_classes_to_retrieve = test_defined_accounts + .into_iter() + .map(|acc| acc.account) + .chain(preset_test_contracts.default_test_contracts) + .chain([preset_test_contracts.erc20_contract]); + let sierra_vec: Vec<_> = prepare_sierra_classes(contract_classes_to_retrieve.clone()); + let (cairo0_contract_classes, cairo1_contract_classes) = + prepare_compiled_contract_classes(contract_classes_to_retrieve); + + Self { sierra_vec, cairo0_contract_classes, cairo1_contract_classes } + } +} + fn initialize_papyrus_test_state( storage_writer: &mut StorageWriter, chain_info: &ChainInfo, test_defined_accounts: Vec, preset_test_contracts: PresetTestContracts, + classes: &TestClasses, ) { let state_diff = prepare_state_diff(chain_info, &test_defined_accounts, &preset_test_contracts); - let contract_classes_to_retrieve = test_defined_accounts - .into_iter() - .map(|acc| acc.account) - .chain(preset_test_contracts.default_test_contracts) - .chain([preset_test_contracts.erc20_contract]); - let sierra_vec: Vec<_> = prepare_sierra_classes(contract_classes_to_retrieve.clone()); - let (cairo0_contract_classes, cairo1_contract_classes) = - prepare_compiled_contract_classes(contract_classes_to_retrieve); - - write_state_to_papyrus_storage( - storage_writer, - state_diff, - &cairo0_contract_classes, - &cairo1_contract_classes, - &sierra_vec, - ) + write_state_to_papyrus_storage(storage_writer, state_diff, classes) } fn prepare_state_diff( @@ -266,12 +282,12 @@ fn prepare_compiled_contract_classes( fn write_state_to_papyrus_storage( storage_writer: &mut StorageWriter, state_diff: ThinStateDiff, - cairo0_contract_classes: &[(ClassHash, DeprecatedContractClass)], - cairo1_contract_classes: &[(ClassHash, CasmContractClass)], - cairo1_sierra: &[(ClassHash, SierraContractClass)], + classes: &TestClasses, ) { let block_number = BlockNumber(0); let block_header = test_block_header(block_number); + let TestClasses { sierra_vec: cairo1_sierra, cairo0_contract_classes, cairo1_contract_classes } = + classes; let cairo0_contract_classes: Vec<_> = cairo0_contract_classes.iter().map(|(hash, contract)| (*hash, contract)).collect();