Skip to content

Commit

Permalink
feat(starknet_os): split Hint to OsHint and syscall + test syscall eq…
Browse files Browse the repository at this point in the history
…uivalence with blockifier
  • Loading branch information
TzahiTaub committed Feb 9, 2025
1 parent 250f9df commit a89711e
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 92 deletions.
178 changes: 91 additions & 87 deletions crates/starknet_os/src/hints/enum_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,97 @@ use crate::{define_hint_enum, define_hint_extension_enum};
pub mod test;

define_hint_enum!(
Hint,
SyscallHint,
(
CallContract,
call_contract,
"syscall_handler.call_contract(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
DelegateCall,
delegate_call,
"syscall_handler.delegate_call(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
DelegateL1Handler,
delegate_l1_handler,
"syscall_handler.delegate_l1_handler(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(Deploy, deploy, "syscall_handler.deploy(segments=segments, syscall_ptr=ids.syscall_ptr)"),
(
EmitEvent,
emit_event,
"syscall_handler.emit_event(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetBlockNumber,
get_block_number,
"syscall_handler.get_block_number(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetBlockTimestamp,
get_block_timestamp,
"syscall_handler.get_block_timestamp(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetCallerAddress,
get_caller_address,
"syscall_handler.get_caller_address(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetContractAddress,
get_contract_address,
"syscall_handler.get_contract_address(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetSequencerAddress,
get_sequencer_address,
"syscall_handler.get_sequencer_address(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetTxInfo,
get_tx_info,
"syscall_handler.get_tx_info(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetTxSignature,
get_tx_signature,
"syscall_handler.get_tx_signature(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
LibraryCall,
library_call,
"syscall_handler.library_call(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
LibraryCallL1Handler,
library_call_l1_handler,
"syscall_handler.library_call_l1_handler(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
ReplaceClass,
replace_class,
"syscall_handler.replace_class(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
SendMessageToL1,
send_message_to_l1,
"syscall_handler.send_message_to_l1(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
StorageRead,
storage_read,
"syscall_handler.storage_read(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
StorageWrite,
storage_write,
"syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
);

define_hint_enum!(
OsHint,
(
LoadClassFacts,
load_class_facts,
Expand Down Expand Up @@ -1330,92 +1420,6 @@ memory[ap] = 1 if case != 'both' else 0"#
segments.write_arg(ids.res.address_, split(ids.value))"#
}
),
(
CallContract,
call_contract,
"syscall_handler.call_contract(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
DelegateCall,
delegate_call,
"syscall_handler.delegate_call(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
DelegateL1Handler,
delegate_l1_handler,
"syscall_handler.delegate_l1_handler(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(Deploy, deploy, "syscall_handler.deploy(segments=segments, syscall_ptr=ids.syscall_ptr)"),
(
EmitEvent,
emit_event,
"syscall_handler.emit_event(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetBlockNumber,
get_block_number,
"syscall_handler.get_block_number(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetBlockTimestamp,
get_block_timestamp,
"syscall_handler.get_block_timestamp(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetCallerAddress,
get_caller_address,
"syscall_handler.get_caller_address(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetContractAddress,
get_contract_address,
"syscall_handler.get_contract_address(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetSequencerAddress,
get_sequencer_address,
"syscall_handler.get_sequencer_address(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetTxInfo,
get_tx_info,
"syscall_handler.get_tx_info(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
GetTxSignature,
get_tx_signature,
"syscall_handler.get_tx_signature(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
LibraryCall,
library_call,
"syscall_handler.library_call(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
LibraryCallL1Handler,
library_call_l1_handler,
"syscall_handler.library_call_l1_handler(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
ReplaceClass,
replace_class,
"syscall_handler.replace_class(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
SendMessageToL1,
send_message_to_l1,
"syscall_handler.send_message_to_l1(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
StorageRead,
storage_read,
"syscall_handler.storage_read(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
StorageWrite,
storage_write,
"syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr)"
),
(
SetSyscallPtr,
set_syscall_ptr,
Expand Down
29 changes: 24 additions & 5 deletions crates/starknet_os/src/hints/enum_definition_test.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,43 @@
use std::collections::HashSet;

use blockifier::execution::hint_code::SYSCALL_HINTS;
use strum::IntoEnumIterator;

use super::{Hint, HintExtension};
use super::{HintExtension, OsHint, SyscallHint};
use crate::hints::types::HintEnum;

#[test]
fn test_hint_strings_are_unique() {
let hint_strings = Hint::iter().map(|hint| hint.to_str()).collect::<Vec<_>>();
let hint_strings = OsHint::iter().map(|hint| hint.to_str()).collect::<Vec<_>>();
let hint_extension_strings =
HintExtension::iter().map(|hint| hint.to_str()).collect::<Vec<_>>();
let syscall_hint_strings = OsHint::iter().map(|hint| hint.to_str()).collect::<Vec<_>>();
let hint_strings_set: HashSet<&&str> = HashSet::from_iter(hint_strings.iter());
let hint_extension_strings_set = HashSet::from_iter(hint_extension_strings.iter());
let syscall_hint_strings_set: HashSet<&&str> = HashSet::from_iter(syscall_hint_strings.iter());
assert_eq!(hint_strings.len(), hint_strings_set.len(), "Duplicate hint strings.");
assert_eq!(
hint_extension_strings.len(),
hint_extension_strings_set.len(),
"Duplicate hint extension strings."
);
let ambiguous_strings =
hint_strings_set.intersection(&hint_extension_strings_set).collect::<Vec<_>>();
assert!(ambiguous_strings.is_empty(), "Ambiguous hint strings: {ambiguous_strings:?}");
assert_eq!(
syscall_hint_strings.len(),
syscall_hint_strings_set.len(),
"Duplicate syscall hint strings."
);

let first_intersection =
hint_strings_set.intersection(&hint_extension_strings_set).cloned().collect::<HashSet<_>>();
let mut ambiguous_strings = first_intersection.intersection(&syscall_hint_strings_set);
let common_value = ambiguous_strings.next();
assert!(common_value.is_none(), "Ambiguous hint strings: {common_value:?}");
}

#[test]
fn test_syscall_compatibility_with_blockifier() {
let syscall_hint_strings =
SyscallHint::iter().map(|hint| hint.to_str()).collect::<HashSet<_>>();
let blockifier_syscall_strings: HashSet<_> = SYSCALL_HINTS.iter().cloned().collect();
assert_eq!(blockifier_syscall_strings, syscall_hint_strings);
}

0 comments on commit a89711e

Please sign in to comment.