Skip to content
This repository has been archived by the owner on Jul 22, 2024. It is now read-only.

update and use new native JIT api #1139

Merged
merged 14 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ cairo-lang-runner = { workspace = true }
cairo-lang-sierra = { workspace = true }
cairo-lang-starknet = { workspace = true }
cairo-lang-utils = { workspace = true }
cairo-native = { git = "https://github.com/lambdaclass/cairo_native", rev = "f668096bd6382066392cd563873dda5e7885a388", optional = true }
cairo-native = { git = "https://github.com/lambdaclass/cairo_native", rev = "9b669cf8fefbff5e3dced87b70a5d957bdc3e85c", optional = true }
cairo-vm = { workspace = true }
flate2 = "1.0.25"
getset = "0.1.2"
Expand Down
32 changes: 21 additions & 11 deletions bench/internals.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![deny(warnings)]

#[cfg(feature = "cairo-native")]
use cairo_native::cache::ProgramCache;

use cairo_vm::felt;
use felt::{felt_str, Felt252};
use lazy_static::lazy_static;
Expand Down Expand Up @@ -56,22 +58,24 @@ fn scope<T>(f: impl FnOnce() -> T) -> T {
// FnOnce calls for each test, that are merged in the flamegraph.
fn main() {
#[cfg(feature = "cairo-native")]
let program_cache = Rc::new(RefCell::new(ProgramCache::new(
starknet_in_rust::utils::get_native_context(),
)));
{
let program_cache = Rc::new(RefCell::new(ProgramCache::new(
starknet_in_rust::utils::get_native_context(),
)));

deploy_account(program_cache.clone());
declare(program_cache.clone());
deploy(program_cache.clone());
invoke(program_cache.clone());
deploy_account(program_cache.clone());
declare(program_cache.clone());
deploy(program_cache.clone());
invoke(program_cache.clone());
}

// The black_box ensures there's no tail-call optimization.
// If not, the flamegraph ends up less nice.
black_box(());
}

#[inline(never)]
fn deploy_account(
pub fn deploy_account(
#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCache<ClassHash>>>,
) {
const RUNS: usize = 500;
Expand Down Expand Up @@ -120,7 +124,9 @@ fn deploy_account(
}

#[inline(never)]
fn declare(#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCache<ClassHash>>>) {
pub fn declare(
#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCache<ClassHash>>>,
) {
const RUNS: usize = 5;

let state_reader = Arc::new(InMemoryStateReader::default());
Expand Down Expand Up @@ -160,7 +166,9 @@ fn declare(#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCac
}

#[inline(never)]
fn deploy(#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCache<ClassHash>>>) {
pub fn deploy(
#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCache<ClassHash>>>,
) {
const RUNS: usize = 8;

let state_reader = Arc::new(InMemoryStateReader::default());
Expand Down Expand Up @@ -206,7 +214,9 @@ fn deploy(#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCach
}

#[inline(never)]
fn invoke(#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCache<ClassHash>>>) {
pub fn invoke(
#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCache<ClassHash>>>,
) {
const RUNS: usize = 100;

let state_reader = Arc::new(InMemoryStateReader::default());
Expand Down
87 changes: 8 additions & 79 deletions src/execution/execution_entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,7 @@ use std::sync::Arc;
#[cfg(feature = "cairo-native")]
use {
crate::syscalls::native_syscall_handler::NativeSyscallHandler,
cairo_native::{
execution_result::NativeExecutionResult, metadata::syscall_handler::SyscallHandlerMeta,
utils::felt252_bigint,
},
core::cell::RefCell,
serde_json::Value,
std::rc::Rc,
cairo_native::metadata::syscall_handler::SyscallHandlerMeta, core::cell::RefCell, std::rc::Rc,
};

#[derive(Debug, Default)]
Expand Down Expand Up @@ -671,14 +665,11 @@ impl ExecutionEntryPoint {
class_hash: &ClassHash,
program_cache: Rc<RefCell<ProgramCache<'_, ClassHash>>>,
) -> Result<CallInfo, TransactionError> {
use cairo_native::values::JITValue;

use crate::{
syscalls::business_logic_syscall_handler::SYSCALL_BASE, utils::NATIVE_CONTEXT,
};
use cairo_lang_sierra::{
extensions::core::{CoreLibfunc, CoreType, CoreTypeConcrete},
program_registry::ProgramRegistry,
};
use serde_json::json;

// Ensure we're using the global context, if initialized.
if let Some(native_context) = NATIVE_CONTEXT.get() {
Expand Down Expand Up @@ -706,9 +697,6 @@ impl ExecutionEntryPoint {
.unwrap(),
};

let program_registry: ProgramRegistry<CoreType, CoreLibfunc> =
ProgramRegistry::new(sierra_program).unwrap();

let native_executor = {
let mut cache = program_cache.borrow_mut();
if let Some(executor) = cache.get(*class_hash) {
Expand Down Expand Up @@ -744,85 +732,26 @@ impl ExecutionEntryPoint {
.get_module_mut()
.insert_metadata(SyscallHandlerMeta::new(&mut syscall_handler));

let syscall_addr = native_executor
.borrow()
.get_module()
.get_metadata::<SyscallHandlerMeta>()
.unwrap()
.as_ptr()
.as_ptr() as *const () as usize;

let entry_point_fn = &sierra_program
.funcs
.iter()
.find(|x| x.id.id == (entry_point.function_idx as u64))
.unwrap();
let ret_types: Vec<&CoreTypeConcrete> = entry_point_fn
.signature
.ret_types
.iter()
.map(|x| program_registry.get_type(x).unwrap())
.collect();
let entry_point_id = &entry_point_fn.id;

let required_init_gas = native_executor
.borrow()
.get_module()
.get_required_init_gas(entry_point_id);
let entry_point_id = &entry_point_fn.id;

let calldata: Vec<_> = self
.calldata
.iter()
.map(|felt| felt252_bigint(felt.to_bigint()))
.collect();

/*
Below we construct `params`, the Serde value that MLIR expects. It consists of the following:

- One `null` value for each builtin that is going to be used.
- The maximum amout of gas allowed by the call.
- `syscall_addr`, the address of the syscall handler.
- `calldata`, an array of Felt arguments to the method being called.
*/

let wrapped_calldata = vec![calldata];
let params: Vec<Value> = sierra_program.funcs[entry_point_id.id as usize]
.params
.iter()
.map(|param| {
match param.ty.debug_name.as_ref().unwrap().as_str() {
"GasBuiltin" => {
json!(self.initial_gas)
}
"Pedersen" | "SegmentArena" | "RangeCheck" | "Bitwise" | "Poseidon" => {
json!(null)
}
"System" => {
json!(syscall_addr)
}
// calldata
"core::array::Span::<core::felt252>" => json!(wrapped_calldata),
x => {
unimplemented!("unhandled param type: {:?}", x);
}
}
})
.cloned()
.map(JITValue::Felt252)
.collect();

let mut writer: Vec<u8> = Vec::new();
let returns = &mut serde_json::Serializer::new(&mut writer);

native_executor
let value = native_executor
.borrow()
.execute(entry_point_id, json!(params), returns, required_init_gas)
.execute_contract(entry_point_id, &calldata, self.initial_gas)
.map_err(|e| TransactionError::CustomError(format!("cairo-native error: {:?}", e)))?;

let value = NativeExecutionResult::deserialize_from_ret_types(
&mut serde_json::Deserializer::from_slice(&writer),
&ret_types,
)
.expect("failed to serialize starknet execution result");

Ok(CallInfo {
caller_address: self.caller_address.clone(),
call_type: Some(self.call_type.clone()),
Expand Down
24 changes: 8 additions & 16 deletions tests/cairo_native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ fn replace_class_test() {
let casm_replace_selector = &casm_entrypoints.external.get(0).unwrap().selector;

// Create state reader with class hash data
let mut contract_class_cache = PermanentContractClassCache::default();
let contract_class_cache = PermanentContractClassCache::default();

let address = Address(1111.into());
let casm_address = Address(2222.into());
Expand All @@ -906,7 +906,7 @@ fn replace_class_test() {

let nonce = Felt252::zero();

insert_sierra_class_into_cache(&mut contract_class_cache, CLASS_HASH_A, contract_class_a);
insert_sierra_class_into_cache(&contract_class_cache, CLASS_HASH_A, contract_class_a);

contract_class_cache.set_contract_class(
CASM_CLASS_HASH_A,
Expand Down Expand Up @@ -1066,7 +1066,7 @@ fn replace_class_contract_call() {

// Create state reader with class hash data
let contract_class_cache = PermanentContractClassCache::default();
let mut native_contract_class_cache = PermanentContractClassCache::default();
let native_contract_class_cache = PermanentContractClassCache::default();

let address = Address(Felt252::one());
let class_hash_a: ClassHash = ClassHash([1; 32]);
Expand All @@ -1076,11 +1076,7 @@ fn replace_class_contract_call() {
class_hash_a,
CompiledClass::Casm(Arc::new(casm_contract_class_a)),
);
insert_sierra_class_into_cache(
&mut native_contract_class_cache,
class_hash_a,
sierra_class_a,
);
insert_sierra_class_into_cache(&native_contract_class_cache, class_hash_a, sierra_class_a);

let mut state_reader = InMemoryStateReader::default();
state_reader
Expand Down Expand Up @@ -1114,11 +1110,7 @@ fn replace_class_contract_call() {
class_hash_b,
CompiledClass::Casm(Arc::new(contract_class_b)),
);
insert_sierra_class_into_cache(
&mut native_contract_class_cache,
class_hash_b,
sierra_class_b,
);
insert_sierra_class_into_cache(&native_contract_class_cache, class_hash_b, sierra_class_b);

// SET GET_NUMBER_WRAPPER

Expand Down Expand Up @@ -1151,7 +1143,7 @@ fn replace_class_contract_call() {
CompiledClass::Casm(Arc::new(wrapper_contract_class)),
);
insert_sierra_class_into_cache(
&mut native_contract_class_cache,
&native_contract_class_cache,
wrapper_class_hash,
wrapper_sierra_class,
);
Expand Down Expand Up @@ -1776,14 +1768,14 @@ fn get_execution_info_test() {
let selector = &entrypoints.external.get(0).unwrap().selector;

// Create state reader with class hash data
let mut contract_class_cache = PermanentContractClassCache::default();
let contract_class_cache = PermanentContractClassCache::default();

// Contract data
let address = Address(1111.into());
let class_hash: ClassHash = ClassHash([1; 32]);
let nonce = Felt252::zero();

insert_sierra_class_into_cache(&mut contract_class_cache, class_hash, sierra_contract_class);
insert_sierra_class_into_cache(&contract_class_cache, class_hash, sierra_contract_class);

let mut state_reader = InMemoryStateReader::default();

Expand Down