Skip to content

Commit

Permalink
fix(starknet_class_manager): fix get_executable to fallback to depr…
Browse files Browse the repository at this point in the history
…ecated (#4130)

Added a test that failed before the fix.
  • Loading branch information
elintul authored Feb 13, 2025
1 parent cd92f4c commit 7ca1752
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 4 deletions.
51 changes: 50 additions & 1 deletion crates/starknet_class_manager/src/class_manager_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use mockall::predicate::eq;
use starknet_api::core::CompiledClassHash;
use starknet_api::core::{ClassHash, CompiledClassHash};
use starknet_api::felt;
use starknet_api::state::SierraContractClass;
use starknet_class_manager_types::ClassHashes;
Expand All @@ -27,6 +27,9 @@ impl ClassManager<FsClassStorage> {
}
}

// TODO(Elin): consider sharing setup code, keeping it clear for the test reader how the compiler is
// mocked per test.

#[tokio::test]
async fn class_manager() {
// Setup.
Expand Down Expand Up @@ -68,3 +71,49 @@ async fn class_manager() {
let class_hashes = class_manager.add_class(class).await.unwrap();
assert_eq!(class_hashes, expected_class_hashes);
}

#[tokio::test]
#[ignore = "Test deprecated class API"]
async fn class_manager_deprecated_class_api() {
todo!("Test deprecated class API");
}

#[tokio::test]
async fn class_manager_get_executable() {
// Setup.

// Prepare mock compiler.
let mut compiler = MockSierraCompilerClient::new();
let class = RawClass::try_from(SierraContractClass::default()).unwrap();
let expected_executable_class = RawExecutableClass(vec![4, 5, 6].into());
let expected_executable_class_for_closure = expected_executable_class.clone();
let expected_executable_class_hash = CompiledClassHash(felt!("0x5678"));
compiler.expect_compile().with(eq(class.clone())).times(1).return_once(move |_| {
Ok((expected_executable_class_for_closure, expected_executable_class_hash))
});

// Prepare class manager.
let persistent_root = create_tmp_dir().unwrap();
let class_hash_storage_path_prefix = create_tmp_dir().unwrap();
let mut class_manager =
ClassManager::new_for_testing(compiler, &persistent_root, &class_hash_storage_path_prefix);

// Test.

// Add classes: deprecated and non-deprecated, under different hashes.
let ClassHashes { class_hash, executable_class_hash: _ } =
class_manager.add_class(class.clone()).await.unwrap();

let deprecated_class_hash = ClassHash(felt!("0x1806"));
let deprecated_executable_class = RawExecutableClass(vec![1, 2, 3].into());
class_manager
.add_deprecated_class(deprecated_class_hash, deprecated_executable_class.clone())
.unwrap();

// Get both executable classes.
assert_eq!(class_manager.get_executable(class_hash).unwrap(), Some(expected_executable_class));
assert_eq!(
class_manager.get_executable(deprecated_class_hash).unwrap(),
Some(deprecated_executable_class)
);
}
22 changes: 19 additions & 3 deletions crates/starknet_class_manager/src/class_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use papyrus_storage::class_hash::{ClassHashStorageReader, ClassHashStorageWriter};
use serde::{Deserialize, Serialize};
use starknet_api::class_cache::GlobalContractCache;
use starknet_api::contract_class::ContractClass;
use starknet_api::core::ChainId;
use starknet_class_manager_types::{CachedClassStorageError, ClassId, ExecutableClassHash};
use starknet_sierra_multicompile_types::{RawClass, RawExecutableClass};
Expand Down Expand Up @@ -164,15 +165,28 @@ impl<S: ClassStorage> ClassStorage for CachedClassStorage<S> {
}

fn get_executable(&self, class_id: ClassId) -> Result<Option<RawExecutableClass>, Self::Error> {
if let Some(class) = self.executable_classes.get(&class_id) {
if let Some(class) = self
.executable_classes
.get(&class_id)
.or_else(|| self.deprecated_classes.get(&class_id))
{
return Ok(Some(class));
}

let Some(class) = self.storage.get_executable(class_id)? else {
return Ok(None);
};

self.executable_classes.set(class_id, class.clone());
// TODO(Elin): separate Cairo0<>1 getters to avoid deserializing here.
match ContractClass::try_from(class.clone()).unwrap() {
ContractClass::V0(_) => {
self.deprecated_classes.set(class_id, class.clone());
}
ContractClass::V1(_) => {
self.executable_classes.set(class_id, class.clone());
}
}

Ok(Some(class))
}

Expand Down Expand Up @@ -440,7 +454,9 @@ impl ClassStorage for FsClassStorage {
}

fn get_executable(&self, class_id: ClassId) -> Result<Option<RawExecutableClass>, Self::Error> {
if !self.contains_class(class_id)? {
let contains_class =
self.contains_class(class_id)? || self.contains_deprecated_class(class_id);
if !contains_class {
return Ok(None);
}

Expand Down

0 comments on commit 7ca1752

Please sign in to comment.