diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..f1dc7a3 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,52 @@ +name: test +# This is the main CI workflow that runs the test suite on all pushes to main +# and all pull requests. It runs the following jobs: +# - required: runs the test suite on ubuntu with stable and beta rust +# toolchains. +permissions: + contents: read +on: + push: + branches: [ main, release/* ] + pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true +env: + CARGO_TERM_COLOR: always +jobs: + required: + runs-on: ubuntu-latest + name: ubuntu / ${{ matrix.toolchain }} + strategy: + matrix: + # Run on stable and beta to ensure that tests won't break on the next + # version of the rust toolchain. + toolchain: [ stable, beta ] + steps: + - uses: actions/checkout@v4 + with: + submodules: true + + - name: Install rust ${{ matrix.toolchain }} + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: ${{ matrix.toolchain }} + rustflags: "" + + - name: Install nextest + uses: taiki-e/install-action@v2 + with: + tool: cargo-nextest + + - name: Cargo generate-lockfile + # Enable this ci template to run regardless of whether the lockfile is + # checked in or not. + if: hashFiles('Cargo.lock') == '' + run: cargo generate-lockfile + + - name: Run unit tests + run: cargo nextest run --locked --all-targets + + - name: Run doc tests + run: cargo test --locked --doc diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b2371e..340e73c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Mocks for the `msg::sender()` #14 +- Mocks for the `msg::value()` and `contract::balance()` #31 - Mocks for the external contract calls. Two and more contracts can be injected into test #14 - Option to inject `Account` or `Address` in the test #14 diff --git a/crates/motsu/src/context.rs b/crates/motsu/src/context.rs index c742bbc..45f7d75 100644 --- a/crates/motsu/src/context.rs +++ b/crates/motsu/src/context.rs @@ -1,37 +1,38 @@ //! Unit-testing context for Stylus contracts. -use std::{collections::HashMap, ptr, slice, thread::ThreadId}; +use std::{collections::HashMap, hash::Hash, ptr, slice, thread::ThreadId}; -use alloy_primitives::Address; +use alloy_primitives::{Address, B256, U256}; use dashmap::{mapref::one::RefMut, DashMap}; use once_cell::sync::Lazy; use stylus_sdk::{alloy_primitives::uint, prelude::StorageType, ArbResult}; use crate::{ - prelude::{Bytes32, WORD_BYTES}, - router::{RouterContext, TestRouter}, + router::{TestRouter, VMRouterContext}, + storage_access::AccessStorage, }; -/// Storage mock. +/// Motsu VM Storage. /// /// A global mutable key-value store that allows concurrent access. /// -/// The key is the test [`Context`], an id of the test thread. +/// The key is the test [`VMContext`], an id of the test thread. /// -/// The value is the [`MockStorage`], a storage of the test case. +/// The value is the [`VMContextStorage`], a storage of the test case. /// -/// NOTE: The [`DashMap`] will deadlock execution, when the same key is +/// NOTE: The [`VMContext::storage`] will panic on lock, when the same key is /// accessed twice from the same thread. -static STORAGE: Lazy> = Lazy::new(DashMap::new); +static MOTSU_VM: Lazy> = + Lazy::new(DashMap::new); -/// Context of stylus unit tests associated with the current test thread. +/// Context of Motsu test VM associated with the current test thread. #[allow(clippy::module_name_repetitions)] #[derive(Hash, Eq, PartialEq, Copy, Clone)] -pub struct Context { +pub struct VMContext { thread_id: ThreadId, } -impl Context { +impl VMContext { /// Get test context associated with the current test thread. #[must_use] pub fn current() -> Self { @@ -77,23 +78,47 @@ impl Context { .insert(key, value); } - /// Set the message sender address. - fn set_msg_sender(self, msg_sender: Address) -> Option
{ + /// Set the message sender address and return the previous sender if any. + fn replace_msg_sender(self, msg_sender: Address) -> Option
{ self.storage().msg_sender.replace(msg_sender) } /// Get the message sender address. #[must_use] - pub fn msg_sender(self) -> Option
{ + pub(crate) fn msg_sender(self) -> Option
{ self.storage().msg_sender } - /// Set the address of the contract, that is called. - fn set_contract_address(self, address: Address) -> Option
{ + /// Replace the address of the contract, and return the previous address if + /// any. + fn replace_contract_address(self, address: Address) -> Option
{ self.storage().contract_address.replace(address) } - /// Get the address of the contract, that is called. + /// Replace an optional message with `value` and return the previous value + /// if any. + /// + /// Setting `value` to `None` will effectively clear the message value, e.g. + /// for non "payable" call. + pub(crate) fn replace_optional_msg_value( + self, + value: Option, + ) -> Option { + std::mem::replace(&mut self.storage().msg_value, value) + } + + /// Write the value sent to the contract to `output`. + pub(crate) unsafe fn msg_value_raw(self, output: *mut u8) { + let value: U256 = self.msg_value().unwrap_or_default(); + write_u256(output, value); + } + + /// Get the value sent to the contract as [`U256`]. + pub(crate) fn msg_value(self) -> Option { + self.storage().msg_value + } + + /// Get the address of the contract that is called. pub(crate) fn contract_address(self) -> Option
{ self.storage().contract_address } @@ -104,7 +129,7 @@ impl Context { self, contract_address: Address, ) { - if STORAGE + if MOTSU_VM .entry(self) .or_default() .contract_data @@ -117,20 +142,20 @@ impl Context { self.router(contract_address).init_storage::(); } - /// Reset storage for the current [`Context`] and `contract_address`. + /// Reset storage for the current [`VMContext`] and `contract_address`. /// /// If all test contracts are removed, flush storage for the current - /// test [`Context`]. + /// test [`VMContext`]. fn reset_storage(self, contract_address: Address) { let mut storage = self.storage(); storage.contract_data.remove(&contract_address); // if no more contracts left, if storage.contract_data.is_empty() { - // drop guard to a concurrent hash map to avoid a deadlock, + // drop guard to a concurrent hash map to avoid a panic on lock, drop(storage); // and erase the test context. - let _ = STORAGE.remove(&self); + _ = MOTSU_VM.remove(&self); } self.router(contract_address).reset_storage(); @@ -142,46 +167,78 @@ impl Context { address: *const u8, calldata: *const u8, calldata_len: usize, - return_data_len: *mut usize, + return_data_size: *mut usize, + ) -> u8 { + let address = read_address(address); + let (selector, input) = decode_calldata(calldata, calldata_len); + + let result = self.call_contract(address, selector, &input, None); + self.process_arb_result_raw(result, return_data_size) + } + + /// Call the contract at raw `address` with the given raw `calldata` and + /// `value`. + pub(crate) unsafe fn call_contract_with_value_raw( + self, + address: *const u8, + calldata: *const u8, + calldata_len: usize, + value: *const u8, + return_data_size: *mut usize, ) -> u8 { - let address_bytes = slice::from_raw_parts(address, 20); - let address = Address::from_slice(address_bytes); + let address = read_address(address); + let value = read_u256(value); + let (selector, input) = decode_calldata(calldata, calldata_len); - let input = slice::from_raw_parts(calldata, calldata_len); - let selector = - u32::from_be_bytes(TryInto::try_into(&input[..4]).unwrap()); + let result = self.call_contract(address, selector, &input, Some(value)); + self.process_arb_result_raw(result, return_data_size) + } - match self.call_contract(address, selector, &input[4..]) { + /// Based on `result`, set the return data. + /// Return 0 if `result` is `Ok`, otherwise 1. + unsafe fn process_arb_result_raw( + self, + result: ArbResult, + return_data_size: *mut usize, + ) -> u8 { + match result { Ok(res) => { - return_data_len.write(res.len()); + return_data_size.write(res.len()); self.set_return_data(res); 0 } Err(err) => { - return_data_len.write(err.len()); + return_data_size.write(err.len()); self.set_return_data(err); 1 } } } - /// Call the function associated with the given `selector` and pass `input` - /// to it, at the given `contract_address`. + /// Call the function associated with the given `selector` at the given + /// `contract_address`. Pass `input` and optional `value` to it. fn call_contract( self, contract_address: Address, selector: u32, input: &[u8], + value: Option, ) -> ArbResult { // Set the caller contract as message sender and callee contract as // a receiver (`contract_address`). let previous_contract_address = self - .set_contract_address(contract_address) + .replace_contract_address(contract_address) .expect("contract_address should be set"); let previous_msg_sender = self - .set_msg_sender(previous_contract_address) + .replace_msg_sender(previous_contract_address) .expect("msg_sender should be set"); + // Set new msg_value, and store the previous one. + let previous_msg_value = self.replace_optional_msg_value(value); + + // Transfer value sent by message sender. + self.transfer_value(); + // Call external contract. let result = self .router(contract_address) @@ -191,8 +248,11 @@ impl Context { }); // Set the previous message sender and contract address back. - let _ = self.set_contract_address(previous_contract_address); - let _ = self.set_msg_sender(previous_msg_sender); + _ = self.replace_contract_address(previous_contract_address); + _ = self.replace_msg_sender(previous_msg_sender); + + // Set the previous msg_value. + self.replace_optional_msg_value(previous_msg_value); result } @@ -200,8 +260,8 @@ impl Context { /// Set return data as bytes. fn set_return_data(self, data: Vec) { let mut call_storage = self.storage(); - let _ = call_storage.call_output_len.insert(data.len()); - let _ = call_storage.call_output.insert(data); + _ = call_storage.return_data_size.insert(data.len()); + _ = call_storage.return_data.insert(data); } /// Read the return data (with a given `size`) from the last contract call @@ -219,20 +279,19 @@ impl Context { /// Return data's size in bytes from the last contract call. pub(crate) fn return_data_size(self) -> usize { self.storage() - .call_output_len + .return_data_size .take() .expect("call_output_len should be set") } /// Return data's bytes from the last contract call. fn return_data(self) -> Vec { - self.storage().call_output.take().expect("call_output should be set") + self.storage().return_data.take().expect("call_output should be set") } /// Check if the contract at raw `address` has code. pub(crate) unsafe fn has_code_raw(self, address: *const u8) -> bool { - let address_bytes = slice::from_raw_parts(address, 20); - let address = Address::from_slice(address_bytes); + let address = read_address(address); self.has_code(address) } @@ -242,51 +301,178 @@ impl Context { self.router(address).exists() } + /// Get the balance of account at `address`. + pub(crate) fn balance(self, address: Address) -> U256 { + self.storage().balances.get(&address).copied().unwrap_or_default() + } + + /// Transfer value from the message sender to the contract. + /// No-op if `msg_sender` or `contract_address` weren't set. + /// + /// # Panics + /// + /// * If there is not enough funds to transfer. + fn transfer_value(self) { + let storage = self.storage(); + let Some(msg_sender) = storage.msg_sender else { + return; + }; + let Some(contract_address) = storage.contract_address else { + return; + }; + + // We should transfer the value only if it is set. + if let Some(msg_value) = storage.msg_value { + // Drop storage to avoid a panic on lock. + drop(storage); + + // Transfer and panic if there is not enough funds. + self.transfer(msg_sender, contract_address, msg_value); + } + } + + /// Transfer `value` from `from` account to `to` account. + /// + /// # Panics + /// + /// * If there is not enough funds to transfer. + fn transfer(self, from: Address, to: Address, value: U256) { + // Transfer and panic if there is not enough funds. + self.checked_transfer(from, to, value) + .unwrap_or_else(|| panic!("{from} account should have enough funds to transfer {value} value")); + } + + /// Transfer `value` from `from` account to `to` account. + /// + /// Returns `None` if there is not enough funds to transfer. + fn checked_transfer( + self, + from: Address, + to: Address, + value: U256, + ) -> Option<()> { + self.checked_sub_assign_balance(from, value)?; + self.add_assign_balance(to, value); + Some(()) + } + + /// Subtract `value` from the balance of `address` account. + /// + /// Returns `None` if there is not enough of funds. + fn checked_sub_assign_balance( + self, + address: Address, + value: U256, + ) -> Option { + let mut storage = self.storage(); + let balance = storage.balances.entry(address).or_default(); + if *balance < value { + return None; + } + *balance -= value; + Some(*balance) + } + + /// Add `value` to the balance of `address` account. + fn add_assign_balance(self, address: Address, value: U256) -> U256 { + *self + .storage() + .balances + .entry(address) + .and_modify(|v| *v += value) + .or_insert(value) + } + /// Get reference to the storage for the current test thread. - fn storage(self) -> RefMut<'static, Context, MockStorage> { - STORAGE.get_mut(&self).expect("contract should be initialised first") + fn storage(self) -> RefMut<'static, VMContext, VMContextStorage> { + MOTSU_VM.access_storage(&self) } /// Get router for the contract at `address`. - fn router(self, address: Address) -> RouterContext { - RouterContext::new(self.thread_id, address) + fn router(self, address: Address) -> VMRouterContext { + VMRouterContext::new(self.thread_id, address) } } /// Read the word from location pointed by `ptr`. -unsafe fn read_bytes32(ptr: *const u8) -> Bytes32 { +pub(crate) unsafe fn read_bytes32(ptr: *const u8) -> Bytes32 { let mut res = Bytes32::default(); ptr::copy(ptr, res.as_mut_ptr(), WORD_BYTES); res } /// Write the word `bytes` to the location pointed by `ptr`. -unsafe fn write_bytes32(ptr: *mut u8, bytes: Bytes32) { +pub(crate) unsafe fn write_bytes32(ptr: *mut u8, bytes: Bytes32) { ptr::copy(bytes.as_ptr(), ptr, WORD_BYTES); } +/// Read the [`Address`] from the raw pointer. +pub(crate) unsafe fn read_address(ptr: *const u8) -> Address { + let address_bytes = slice::from_raw_parts(ptr, 20); + Address::from_slice(address_bytes) +} + +/// Write the [`Address`] `address` to the location pointed by `ptr`. +pub(crate) unsafe fn write_address(ptr: *mut u8, address: Address) { + ptr::copy(address.as_ptr(), ptr, 20); +} + +/// Read the [`U256`] from the raw pointer. +pub(crate) unsafe fn read_u256(ptr: *const u8) -> U256 { + let mut data = B256::ZERO; + ptr::copy(ptr, data.as_mut_ptr(), 32); + data.into() +} + +/// Write the [`U256`] `value` to the location pointed by `ptr`. +pub(crate) unsafe fn write_u256(ptr: *mut u8, value: U256) { + let bytes: B256 = value.into(); + ptr::copy(bytes.as_ptr(), ptr, 32); +} + +/// Decode the selector as [`u32`], and function input as [`Vec`] from the +/// raw pointer. +unsafe fn decode_calldata( + calldata: *const u8, + calldata_len: usize, +) -> (u32, Vec) { + let calldata = slice::from_raw_parts(calldata, calldata_len); + let selector = + u32::from_be_bytes(TryInto::try_into(&calldata[..4]).unwrap()); + let input = calldata[4..].to_vec(); + (selector, input) +} + /// Storage for unit test's mock data. #[derive(Default)] -struct MockStorage { +struct VMContextStorage { /// Address of the message sender. msg_sender: Option
, + /// The ETH value in wei sent to the program. + msg_value: Option, /// Address of the contract that is currently receiving the message. contract_address: Option
, /// Contract's address to mock data storage mapping. contract_data: HashMap, + /// Account's address to balance mapping. + balances: HashMap, // Output of a contract call. - call_output: Option>, + return_data: Option>, // Output length of a contract call. - call_output_len: Option, + return_data_size: Option, } +/// Contract's byte storage type ContractStorage = HashMap; +pub(crate) const WORD_BYTES: usize = 32; +pub(crate) type Bytes32 = [u8; WORD_BYTES]; /// Contract call entity, related to the contract type `ST` and the caller's /// account. pub struct ContractCall<'a, ST: StorageType> { storage: ST, - caller_address: Address, + msg_sender: Address, + msg_value: Option, /// We need to hold a reference to [`Contract`], because /// `Contract::::new().sender(alice)` can accidentally drop /// [`Contract`]. @@ -297,15 +483,12 @@ pub struct ContractCall<'a, ST: StorageType> { } impl ContractCall<'_, ST> { - /// Get the contract's address. - pub fn address(&self) -> Address { - self.contract_ref.address - } - /// Preset the call parameters. fn set_call_params(&self) { - let _ = Context::current().set_msg_sender(self.caller_address); - let _ = Context::current().set_contract_address(self.address()); + _ = VMContext::current().replace_optional_msg_value(self.msg_value); + _ = VMContext::current().replace_msg_sender(self.msg_sender); + _ = VMContext::current() + .replace_contract_address(self.contract_ref.address); } } @@ -315,6 +498,7 @@ impl ::core::ops::Deref for ContractCall<'_, ST> { #[inline] fn deref(&self) -> &Self::Target { self.set_call_params(); + VMContext::current().transfer_value(); &self.storage } } @@ -323,6 +507,7 @@ impl ::core::ops::DerefMut for ContractCall<'_, ST> { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { self.set_call_params(); + VMContext::current().transfer_value(); &mut self.storage } } @@ -335,7 +520,7 @@ pub struct Contract { impl Drop for Contract { fn drop(&mut self) { - Context::current().reset_storage(self.address); + VMContext::current().reset_storage(self.address); } } @@ -355,7 +540,7 @@ impl Contract { /// Create a new contract with the given `address`. #[must_use] pub fn new_at(address: Address) -> Self { - Context::current().init_storage::(address); + VMContext::current().init_storage::(address); Self { phantom: ::core::marker::PhantomData, address } } @@ -364,10 +549,10 @@ impl Contract { /// the given `account`. pub fn init, Output>( &self, - account: A, + sender: A, initializer: impl FnOnce(&mut ST) -> Output, ) -> Output { - initializer(&mut self.sender(account.into())) + initializer(&mut self.sender(sender.into())) } /// Create a new contract with default storage on the random address. @@ -387,7 +572,26 @@ impl Contract { pub fn sender>(&self, account: A) -> ContractCall { ContractCall { storage: unsafe { ST::new(uint!(0_U256), 0) }, - caller_address: account.into(), + msg_sender: account.into(), + msg_value: None, + contract_ref: self, + } + } + + /// Call contract `self` with `account` as a sender and `value`. + #[must_use] + pub fn sender_and_value, V: Into>( + &self, + sender: A, + value: V, + ) -> ContractCall { + let caller_address = sender.into(); + let value = value.into(); + + ContractCall { + storage: unsafe { ST::new(uint!(0_U256), 0) }, + msg_sender: caller_address, + msg_value: Some(value), contract_ref: self, } } @@ -424,3 +628,42 @@ impl Account { self.address } } + +/// Fund the account. +pub trait Funding { + /// Fund the account with the given `value`. + fn fund(&self, value: U256); + + /// Get the balance of the account. + fn balance(&self) -> U256; +} + +impl Funding for Address { + fn fund(&self, value: U256) { + VMContext::current().add_assign_balance(*self, value); + } + + fn balance(&self) -> U256 { + VMContext::current().balance(*self) + } +} + +impl Funding for Account { + fn fund(&self, value: U256) { + self.address().fund(value); + } + + fn balance(&self) -> U256 { + self.address().balance() + } +} + +impl Funding for Contract { + fn fund(&self, value: U256) { + self.address().fund(value); + } + + fn balance(&self) -> U256 { + self.address().balance() + } +} diff --git a/crates/motsu/src/lib.rs b/crates/motsu/src/lib.rs index 620f5a6..f348b84 100644 --- a/crates/motsu/src/lib.rs +++ b/crates/motsu/src/lib.rs @@ -54,6 +54,7 @@ mod context; pub mod prelude; mod router; mod shims; +mod storage_access; pub use motsu_proc::test; @@ -87,7 +88,7 @@ mod ping_pong_tests { let value = receiver.pong(call, value)?; let pings_count = self.pings_count.get(); - self.pings_count.set(pings_count + uint!(1_U256)); + self.pings_count.set(pings_count + ONE); self.pinged_from.set(msg::sender()); self.contract_address.set(contract::address()); @@ -131,6 +132,9 @@ mod ping_pong_tests { const MAGIC_ERROR_VALUE: U256 = uint!(42_U256); + const ONE: U256 = uint!(1_U256); + const TEN: U256 = uint!(10_U256); + #[storage] struct PongContract { pongs_count: StorageU256, @@ -146,12 +150,12 @@ mod ping_pong_tests { } let pongs_count = self.pongs_count.get(); - self.pongs_count.set(pongs_count + uint!(1_U256)); + self.pongs_count.set(pongs_count + ONE); self.ponged_from.set(msg::sender()); self.contract_address.set(contract::address()); - Ok(value + uint!(1_U256)) + Ok(value + ONE) } fn can_pong(&self) -> bool { @@ -167,15 +171,15 @@ mod ping_pong_tests { pong: Contract, alice: Account, ) { - let value = uint!(10_U256); + let value = TEN; let ponged_value = ping .sender(alice) .ping(pong.address(), value) .expect("should ping successfully"); - assert_eq!(ponged_value, value + uint!(1_U256)); - assert_eq!(ping.sender(alice).pings_count.get(), uint!(1_U256)); - assert_eq!(pong.sender(alice).pongs_count.get(), uint!(1_U256)); + assert_eq!(ponged_value, value + ONE); + assert_eq!(ping.sender(alice).pings_count.get(), ONE); + assert_eq!(pong.sender(alice).pongs_count.get(), ONE); } #[motsu_proc::test] @@ -215,9 +219,9 @@ mod ping_pong_tests { assert_eq!(ping.sender(alice).pinged_from.get(), Address::ZERO); assert_eq!(pong.sender(alice).ponged_from.get(), Address::ZERO); - let _ = ping + _ = ping .sender(alice) - .ping(pong.address(), uint!(10_U256)) + .ping(pong.address(), TEN) .expect("should ping successfully"); assert_eq!(ping.sender(alice).pinged_from.get(), alice.address()); @@ -242,9 +246,9 @@ mod ping_pong_tests { assert_eq!(ping.sender(alice).contract_address.get(), Address::ZERO); assert_eq!(pong.sender(alice).contract_address.get(), Address::ZERO); - let _ = ping + _ = ping .sender(alice) - .ping(pong.address(), uint!(10_U256)) + .ping(pong.address(), TEN) .expect("should ping successfully"); assert_eq!(ping.sender(alice).contract_address.get(), ping.address()); @@ -257,10 +261,11 @@ mod ping_pong_tests { let ping = Contract::::new(); let mut ping = ping.sender(alice); let pong = Contract::::new(); + let pong_address = pong.address(); let pong = pong.sender(alice); - let _ = ping - .ping(pong.address(), uint!(10_U256)) + _ = ping + .ping(pong_address, TEN) .expect("contract ping should not drop"); } } @@ -270,16 +275,23 @@ mod proxies_tests { use alloy_primitives::{uint, Address, U256}; use stylus_sdk::{ call::Call, + contract, msg, prelude::{public, storage, TopLevelStorage}, storage::StorageAddress, }; - use crate::context::{Account, Contract}; + use crate::prelude::*; stylus_sdk::stylus_proc::sol_interface! { interface IProxy { #[allow(missing_docs)] function callProxy(uint256 value) external returns (uint256); + #[allow(missing_docs)] + function payProxy() external payable; + #[allow(missing_docs)] + function passProxyWithFixedValue(uint256 pass_value) external payable; + #[allow(missing_docs)] + function payProxyWithHalfBalance() external payable; } } @@ -294,7 +306,7 @@ mod proxies_tests { let next_proxy = self.next_proxy.get(); // Add one to the value. - let value = value + uint!(1_U256); + let value = value + ONE; // If there is no next proxy, return the value. if next_proxy.is_zero() { @@ -306,12 +318,76 @@ mod proxies_tests { proxy.call_proxy(call, value).expect("should call proxy") } } + + #[payable] + fn pay_proxy(&mut self) { + let next_proxy = self.next_proxy.get(); + + // If there is a next proxy. + if !next_proxy.is_zero() { + // Add one to the message value. + let value = msg::value() + ONE; + + // Pay the next proxy. + let proxy = IProxy::new(next_proxy); + let call = Call::new_in(self).value(value); + proxy.pay_proxy(call).expect("should pay proxy"); + } + } + + #[payable] + fn pass_proxy_with_fixed_value(&mut self, this_value: U256) { + let next_proxy = self.next_proxy.get(); + + // If there is a next proxy. + if !next_proxy.is_zero() { + // Pay the next proxy. + let proxy = IProxy::new(next_proxy); + let call = Call::new_in(self).value(this_value); + let value_for_next_next_proxy = this_value / TWO; + proxy + .pass_proxy_with_fixed_value( + call, + value_for_next_next_proxy, + ) + .expect("should pass half the value to the next proxy"); + } + } + + #[payable] + fn pay_proxy_with_half_balance(&mut self) { + let next_proxy = self.next_proxy.get(); + + // If there is a next proxy. + if !next_proxy.is_zero() { + let half_balance = contract::balance() / TWO; + // Pay the next proxy. + let proxy = IProxy::new(next_proxy); + let call = Call::new_in(self).value(half_balance); + proxy + .pay_proxy_with_half_balance(call) + .expect("should pass half the value to the next proxy"); + } + } + } + + impl Proxy { + fn init(&mut self, next_proxy: Address) { + self.next_proxy.set(next_proxy); + } } unsafe impl TopLevelStorage for Proxy {} + const ONE: U256 = uint!(1_U256); + const TWO: U256 = uint!(2_U256); + const FOUR: U256 = uint!(4_U256); + const EIGHT: U256 = uint!(8_U256); + + const TEN: U256 = uint!(10_U256); + #[motsu_proc::test] - fn three_proxies( + fn call_three_proxies( proxy1: Contract, proxy2: Contract, proxy3: Contract, @@ -319,21 +395,125 @@ mod proxies_tests { ) { // Set up a chain of three proxies. // With the given call chain: proxy1 -> proxy2 -> proxy3. - proxy1.init(alice, |storage| { - storage.next_proxy.set(proxy2.address()); - }); - proxy2.init(alice, |storage| { - storage.next_proxy.set(proxy3.address()); - }); - proxy3.init(alice, |storage| { - storage.next_proxy.set(Address::ZERO); - }); + proxy1.sender(alice).init(proxy2.address()); + proxy2.sender(alice).init(proxy3.address()); + proxy3.sender(alice).init(Address::ZERO); // Call the first proxy. - let value = uint!(10_U256); - let result = proxy1.sender(alice).call_proxy(value); + let result = proxy1.sender(alice).call_proxy(TEN); // The value is incremented by 1 for each proxy. - assert_eq!(result, value + uint!(3_U256)); + assert_eq!(result, TEN + ONE + ONE + ONE); + } + + #[motsu_proc::test] + fn pay_three_proxies( + proxy1: Contract, + proxy2: Contract, + proxy3: Contract, + alice: Account, + ) { + // Set up a chain of three proxies. + // With the given call chain: proxy1 -> proxy2 -> proxy3. + proxy1.sender(alice).init(proxy2.address()); + proxy2.sender(alice).init(proxy3.address()); + proxy3.sender(alice).init(Address::ZERO); + + // Fund accounts. + alice.fund(TEN); + proxy1.fund(TEN); + proxy2.fund(TEN); + proxy3.fund(TEN); + + // Call the first proxy. + proxy1.sender_and_value(alice, ONE).pay_proxy(); + + // By the end, each actor will lose `ONE`, except last proxy. + assert_eq!(alice.balance(), TEN - ONE); + assert_eq!(proxy1.balance(), TEN - ONE); + assert_eq!(proxy2.balance(), TEN - ONE); + assert_eq!(proxy3.balance(), TEN + ONE + ONE + ONE); + } + + #[motsu_proc::test] + fn pass_proxy_with_fixed_value( + proxy1: Contract, + proxy2: Contract, + proxy3: Contract, + alice: Account, + ) { + // Set up a chain of three proxies. + // With the given call chain: proxy1 -> proxy2 -> proxy3. + proxy1.sender(alice).init(proxy2.address()); + proxy2.sender(alice).init(proxy3.address()); + proxy3.sender(alice).init(Address::ZERO); + + // Fund alice, proxies have no funds. + alice.fund(EIGHT); + + assert_eq!(alice.balance(), EIGHT); + assert_eq!(proxy1.balance(), U256::ZERO); + assert_eq!(proxy2.balance(), U256::ZERO); + assert_eq!(proxy3.balance(), U256::ZERO); + + // Call the first proxy. + proxy1.sender_and_value(alice, EIGHT).pass_proxy_with_fixed_value(FOUR); + + assert_eq!(alice.balance(), U256::ZERO); + assert_eq!(proxy1.balance(), FOUR); + assert_eq!(proxy2.balance(), TWO); + assert_eq!(proxy3.balance(), TWO); + } + + #[motsu_proc::test] + fn pay_proxy_with_half_balance( + proxy1: Contract, + proxy2: Contract, + proxy3: Contract, + alice: Account, + ) { + // Set up a chain of three proxies. + // With the given call chain: proxy1 -> proxy2 -> proxy3. + proxy1.sender(alice).init(proxy2.address()); + proxy2.sender(alice).init(proxy3.address()); + proxy3.sender(alice).init(Address::ZERO); + + // Fund alice, proxies have no funds. + alice.fund(EIGHT); + + assert_eq!(alice.balance(), EIGHT); + assert_eq!(proxy1.balance(), U256::ZERO); + assert_eq!(proxy2.balance(), U256::ZERO); + assert_eq!(proxy3.balance(), U256::ZERO); + + // Call the first proxy. + proxy1.sender_and_value(alice, EIGHT).pay_proxy_with_half_balance(); + + assert_eq!(alice.balance(), U256::ZERO); + assert_eq!(proxy1.balance(), FOUR); + assert_eq!(proxy2.balance(), TWO); + assert_eq!(proxy3.balance(), TWO); + } + + #[motsu_proc::test] + fn no_locks_with_panics() { + for _ in 0..1000 { + let proxy1 = Contract::::new(); + let proxy2 = Contract::::new(); + let proxy3 = Contract::::new(); + let alice = Account::random(); + + // Set up a chain of three proxies. + // With the given call chain: proxy1 -> proxy2 -> proxy3. + proxy1.sender(alice).init(proxy2.address()); + proxy2.sender(alice).init(proxy3.address()); + proxy3.sender(alice).init(Address::ZERO); + + // Call the first proxy. + let result = proxy1.sender(alice).call_proxy(TEN); + + // The value is incremented by 1 for each proxy. + assert_eq!(result, TEN + ONE + ONE + ONE); + } } } diff --git a/crates/motsu/src/prelude.rs b/crates/motsu/src/prelude.rs index a881a05..e52a05a 100644 --- a/crates/motsu/src/prelude.rs +++ b/crates/motsu/src/prelude.rs @@ -1,5 +1,2 @@ //! Common imports for `motsu` tests. -pub use crate::{ - context::{Account, Context, Contract, ContractCall}, - shims::*, -}; +pub use crate::context::{Account, Contract, ContractCall, Funding, VMContext}; diff --git a/crates/motsu/src/router.rs b/crates/motsu/src/router.rs index f12ea20..14cc6fd 100644 --- a/crates/motsu/src/router.rs +++ b/crates/motsu/src/router.rs @@ -1,9 +1,10 @@ //! Router context for external calls mocks. -//! -//! NOTE: [`ROUTER_STORAGE`] should be separated from the main test storage to -//! avoid deadlocks. -use std::{borrow::BorrowMut, sync::Mutex, thread::ThreadId}; +use std::{ + borrow::BorrowMut, + sync::{Arc, Mutex, TryLockError}, + thread::ThreadId, +}; use alloy_primitives::{uint, Address}; use dashmap::{mapref::one::RefMut, DashMap}; @@ -14,46 +15,46 @@ use stylus_sdk::{ ArbResult, }; -/// Router Storage. +use crate::storage_access::AccessStorage; + +/// Motsu VM Router Storage. /// /// A global mutable key-value store that allows concurrent access. /// -/// The key is the [`RouterContext`], a combination of [`ThreadId`] and -/// [`Address`] to avoid a deadlock, while calling more than two contracts +/// The key is the [`VMRouterContext`], a combination of [`ThreadId`] and +/// [`Address`] to avoid a panic on lock, while calling more than two contracts /// consecutive. /// -/// The value is the [`RouterStorage`], a router of the contract generated by +/// The value is the [`VMRouterStorage`], a router of the contract generated by /// `stylus-sdk`. /// -/// NOTE: The [`DashMap`] will deadlock execution, when the same key is -/// accessed twice from the same thread. -static ROUTER_STORAGE: Lazy> = +/// NOTE: The [`VMRouterContext::storage`] will panic on lock, when the same key +/// is accessed twice from the same thread. +static MOTSU_VM_ROUTERS: Lazy> = Lazy::new(DashMap::new); -/// Context for the router of a test contract for current test thread and +/// Context of Motsu test VM router associated with the current test thread and /// contract's address. #[derive(Hash, Eq, PartialEq, Copy, Clone)] -pub(crate) struct RouterContext { +pub(crate) struct VMRouterContext { thread_id: ThreadId, contract_address: Address, } -impl RouterContext { +impl VMRouterContext { /// Create a new router context. pub(crate) fn new(thread: ThreadId, contract_address: Address) -> Self { Self { thread_id: thread, contract_address } } /// Get reference to the call storage for the current test thread. - fn storage(self) -> RefMut<'static, RouterContext, RouterStorage> { - ROUTER_STORAGE - .get_mut(&self) - .expect("contract should be initialised first") + fn storage(self) -> RefMut<'static, VMRouterContext, VMRouterStorage> { + MOTSU_VM_ROUTERS.access_storage(&self) } /// Check if the router exists for the contract. pub(crate) fn exists(self) -> bool { - ROUTER_STORAGE.contains_key(&self) + MOTSU_VM_ROUTERS.contains_key(&self) } pub(crate) fn route( @@ -61,20 +62,38 @@ impl RouterContext { selector: u32, input: &[u8], ) -> Option { - let router = &self.storage().router; - let mut router = router.lock().expect("should lock test router"); - router.route(selector, input) + let storage = self.storage(); + let router = Arc::clone(&storage.router); + + // Drop the storage reference to avoid a panic on lock. + drop(storage); + + // Try to get lock on the router. + let lock_result = router.try_lock(); + match lock_result { + // If lock is acquired, route the message. + Ok(mut router) => router.route(selector, input), + // Panic instead of locking. + Err(TryLockError::WouldBlock) => { + panic!("recursive calls are not supported in motsu") + } + // This branch should not be reached, since we don't catch + // panics. + Err(TryLockError::Poisoned(_)) => { + panic!("should not call contract that panicked") + } + } } /// Initialise contract router for the current test thread and /// `contract_address`. pub(crate) fn init_storage(self) { let contract_address = self.contract_address; - if ROUTER_STORAGE + if MOTSU_VM_ROUTERS .insert( self, - RouterStorage { - router: Mutex::new(Box::new(unsafe { + VMRouterStorage { + router: Arc::new(Mutex::new(unsafe { ST::new(uint!(0_U256), 0) })), }, @@ -85,17 +104,17 @@ impl RouterContext { } } - /// Reset router storage for the current [`RouterContext`]. + /// Reset router storage for the current [`VMRouterContext`]. pub(crate) fn reset_storage(self) { - ROUTER_STORAGE.remove(&self); + MOTSU_VM_ROUTERS.remove(&self); } } /// Metadata related to the router of an external contract. -struct RouterStorage { +struct VMRouterStorage { // Contract's router. // NOTE: Mutex is important since contract type is not `Sync`. - router: Mutex>, + router: Arc>, } /// A trait for routing messages to the appropriate selector in tests. diff --git a/crates/motsu/src/shims.rs b/crates/motsu/src/shims.rs index 95c639a..1bded9a 100644 --- a/crates/motsu/src/shims.rs +++ b/crates/motsu/src/shims.rs @@ -8,7 +8,7 @@ //! //! ## Motivation //! -//! Without these shims we can't currently run unit tests for stylus contracts, +//! Without these shims, we can't currently run unit tests for stylus contracts, //! since the symbols the compiled binaries expect to find are not there. //! //! If you run `cargo test` on a fresh Stylus project, it will error with: @@ -16,36 +16,28 @@ //! ```terminal //! dyld[97792]: missing symbol called //! ``` -//! -//! This crate is a temporary solution until the Stylus team provides us with a -//! different and more stable mechanism for unit-testing our contracts. -//! -//! ## Usage -//! -//! Import these shims in your test modules as `motsu::prelude::*` to populate -//! the namespace with the appropriate symbols. -//! -//! ```rust,ignore -//! #[cfg(test)] -//! mod tests { -//! use contracts::token::erc20::Erc20; -//! -//! #[motsu::test] -//! fn reads_balance(contract: Erc20) { -//! let balance = contract.balance_of(Address::ZERO); // Access storage. -//! assert_eq!(balance, U256::ZERO); -//! } -//! } -//! ``` +#![allow(dead_code)] #![allow(clippy::missing_safety_doc)] use std::slice; use tiny_keccak::{Hasher, Keccak}; -use crate::context::Context; +use crate::context::{ + read_address, write_address, write_bytes32, write_u256, VMContext, + WORD_BYTES, +}; -pub(crate) const WORD_BYTES: usize = 32; -pub(crate) type Bytes32 = [u8; WORD_BYTES]; +/// Arbitrum's CHAID ID. +const CHAIN_ID: u64 = 42161; + +/// Externally Owned Account (EOA) code hash (wallet account). +const EOA_CODEHASH: &[u8; 66] = + b"0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"; + +/// Contract Account (CA) code hash (smart contract code). +/// NOTE: can be any 256-bit value to pass `has_code` check. +const CA_CODEHASH: &[u8; 66] = + b"0x1111111111111111111111111111111111111111111111111111111111111111"; /// Efficiently computes the [`keccak256`] hash of the given preimage. /// The semantics are equivalent to that of the EVM's [`SHA3`] opcode. @@ -53,7 +45,7 @@ pub(crate) type Bytes32 = [u8; WORD_BYTES]; /// [`keccak256`]: https://en.wikipedia.org/wiki/SHA-3 /// [`SHA3`]: https://www.evm.codes/#20 #[no_mangle] -pub unsafe extern "C" fn native_keccak256( +unsafe extern "C" fn native_keccak256( bytes: *const u8, len: usize, output: *mut u8, @@ -80,8 +72,8 @@ pub unsafe extern "C" fn native_keccak256( /// /// May panic if unable to lock `STORAGE`. #[no_mangle] -pub unsafe extern "C" fn storage_load_bytes32(key: *const u8, out: *mut u8) { - Context::current().get_bytes_raw(key, out); +unsafe extern "C" fn storage_load_bytes32(key: *const u8, out: *mut u8) { + VMContext::current().get_bytes_raw(key, out); } /// Writes a 32-byte value to the permanent storage cache. @@ -100,11 +92,8 @@ pub unsafe extern "C" fn storage_load_bytes32(key: *const u8, out: *mut u8) { /// /// May panic if unable to lock `STORAGE`. #[no_mangle] -pub unsafe extern "C" fn storage_cache_bytes32( - key: *const u8, - value: *const u8, -) { - Context::current().set_bytes_raw(key, value); +unsafe extern "C" fn storage_cache_bytes32(key: *const u8, value: *const u8) { + VMContext::current().set_bytes_raw(key, value); } /// Persists any dirty values in the storage cache to the EVM state trie, @@ -113,26 +102,15 @@ pub unsafe extern "C" fn storage_cache_bytes32( /// /// [`SSTORE`]: https://www.evm.codes/#55 #[no_mangle] -pub unsafe extern "C" fn storage_flush_cache(_: bool) { +unsafe extern "C" fn storage_flush_cache(_: bool) { // No-op: we don't use the cache in our unit-tests. } -/// Arbitrum's CHAID ID. -pub const CHAIN_ID: u64 = 42161; - -/// Externally Owned Account (EOA) code hash (wallet account). -pub const EOA_CODEHASH: &[u8; 66] = - b"0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"; - -/// Contract Account (CA) code hash (smart contract code). -/// NOTE: can be any 256-bit value to pass `has_code` check. -pub const CA_CODEHASH: &[u8; 66] = - b"0x1111111111111111111111111111111111111111111111111111111111111111"; - /// Gets the address of the account that called the program. /// -/// For normal L2-to-L2 transactions the semantics are equivalent to that of the -/// EVM's [`CALLER`] opcode, including in cases arising from [`DELEGATE_CALL`]. +/// For normal L2-to-L2 transactions, the semantics are equivalent to that of +/// the EVM's [`CALLER`] opcode, including in cases arising from +/// [`DELEGATE_CALL`]. /// /// For L1-to-L2 retryable ticket transactions, the top-level sender's address /// will be aliased. See [`Retryable Ticket Address Aliasing`][aliasing] for @@ -146,17 +124,16 @@ pub const CA_CODEHASH: &[u8; 66] = /// /// May panic if fails to parse `MSG_SENDER` as an address. #[no_mangle] -pub unsafe extern "C" fn msg_sender(sender: *mut u8) { +unsafe extern "C" fn msg_sender(sender: *mut u8) { let msg_sender = - Context::current().msg_sender().expect("msg_sender should be set"); - std::ptr::copy(msg_sender.as_ptr(), sender, 20); + VMContext::current().msg_sender().expect("msg_sender should be set"); + write_address(sender, msg_sender); } /// Get the ETH value (U256) in wei sent to the program. #[no_mangle] -pub unsafe extern "C" fn msg_value(value: *mut u8) { - let dummy_msg_value: Bytes32 = Bytes32::default(); - std::ptr::copy(dummy_msg_value.as_ptr(), value, 32); +unsafe extern "C" fn msg_value(value: *mut u8) { + VMContext::current().msg_value_raw(value); } /// Gets the address of the current program. The semantics are equivalent to @@ -168,11 +145,11 @@ pub unsafe extern "C" fn msg_value(value: *mut u8) { /// /// May panic if fails to parse `CONTRACT_ADDRESS` as an address. #[no_mangle] -pub unsafe extern "C" fn contract_address(address: *mut u8) { - let contract_address = Context::current() +unsafe extern "C" fn contract_address(address: *mut u8) { + let contract_address = VMContext::current() .contract_address() .expect("contract_address should be set"); - std::ptr::copy(contract_address.as_ptr(), address, 20); + write_address(address, contract_address); } /// Gets the chain ID of the current chain. The semantics are equivalent to @@ -180,7 +157,7 @@ pub unsafe extern "C" fn contract_address(address: *mut u8) { /// /// [`CHAINID`]: https://www.evm.codes/#46 #[no_mangle] -pub unsafe extern "C" fn chainid() -> u64 { +unsafe extern "C" fn chainid() -> u64 { CHAIN_ID } @@ -197,7 +174,7 @@ pub unsafe extern "C" fn chainid() -> u64 { /// [`LOG3`]: https://www.evm.codes/#a3 /// [`LOG4`]: https://www.evm.codes/#a4 #[no_mangle] -pub unsafe extern "C" fn emit_log(_: *const u8, _: usize, _: usize) { +unsafe extern "C" fn emit_log(_: *const u8, _: usize, _: usize) { // No-op: we don't check for events in our unit-tests. } @@ -214,29 +191,40 @@ pub unsafe extern "C" fn emit_log(_: *const u8, _: usize, _: usize) { /// /// May panic if fails to parse `ACCOUNT_CODEHASH` as a keccack hash. #[no_mangle] -pub unsafe extern "C" fn account_codehash(address: *const u8, dest: *mut u8) { - let code_hash = if Context::current().has_code_raw(address) { +unsafe extern "C" fn account_codehash(address: *const u8, dest: *mut u8) { + let code_hash = if VMContext::current().has_code_raw(address) { CA_CODEHASH } else { EOA_CODEHASH }; let account_codehash = - const_hex::const_decode_to_array::<32>(code_hash).unwrap(); + const_hex::const_decode_to_array::(code_hash).unwrap(); - std::ptr::copy(account_codehash.as_ptr(), dest, 32); + write_bytes32(dest, account_codehash); +} + +/// Gets the ETH balance in wei of the account at the given address. +/// The semantics are equivalent to that of the EVM's [`BALANCE`] opcode. +/// +/// [`BALANCE`]: https://www.evm.codes/#31 +#[no_mangle] +unsafe extern "C" fn account_balance(address: *const u8, dest: *mut u8) { + let address = read_address(address); + let balance = VMContext::current().balance(address); + write_u256(dest, balance); } /// Returns the length of the last EVM call or deployment return result, or `0` -/// if neither have happened during the program's execution. +/// if neither has happened during the program's execution. /// /// The semantics are equivalent to that of the EVM's [`RETURN_DATA_SIZE`] /// opcode. /// /// [`RETURN_DATA_SIZE`]: https://www.evm.codes/#3d #[no_mangle] -pub unsafe extern "C" fn return_data_size() -> usize { - Context::current().return_data_size() +unsafe extern "C" fn return_data_size() -> usize { + VMContext::current().return_data_size() } /// Copies the bytes of the last EVM call or deployment return result. @@ -249,41 +237,42 @@ pub unsafe extern "C" fn return_data_size() -> usize { /// /// [`RETURN_DATA_COPY`]: https://www.evm.codes/#3e #[no_mangle] -pub unsafe extern "C" fn read_return_data( +unsafe extern "C" fn read_return_data( dest: *mut u8, _offset: usize, size: usize, ) -> usize { - Context::current().read_return_data_raw(dest, size) + VMContext::current().read_return_data_raw(dest, size) } /// Calls the contract at the given address with options for passing value and /// to limit the amount of gas supplied. The return status indicates whether the -/// call succeeded, and is nonzero on failure. +/// call succeeded and is nonzero on failure. /// -/// In both cases `return_data_len` will store the length of the result, the +/// In both cases, `return_data_len` will store the length of the result, the /// bytes of which can be read via the `read_return_data` hostio. The bytes are /// not returned directly so that the programmer can potentially save gas by /// choosing which subset of the return result they'd like to copy. /// /// The semantics are equivalent to that of the EVM's [`CALL`] opcode, including -/// callvalue stipends and the 63/64 gas rule. This means that supplying the +/// call value stipends and the 63/64 gas rule. This means that supplying the /// `u64::MAX` gas can be used to send as much as possible. /// /// [`CALL`]: https://www.evm.codes/#f1 #[no_mangle] -pub unsafe extern "C" fn call_contract( +unsafe extern "C" fn call_contract( contract: *const u8, calldata: *const u8, calldata_len: usize, - _value: *const u8, + value: *const u8, _gas: u64, return_data_len: *mut usize, ) -> u8 { - Context::current().call_contract_raw( + VMContext::current().call_contract_with_value_raw( contract, calldata, calldata_len, + value, return_data_len, ) } @@ -303,14 +292,14 @@ pub unsafe extern "C" fn call_contract( /// /// [`STATIC_CALL`]: https://www.evm.codes/#FA #[no_mangle] -pub unsafe extern "C" fn static_call_contract( +unsafe extern "C" fn static_call_contract( contract: *const u8, calldata: *const u8, calldata_len: usize, _gas: u64, return_data_len: *mut usize, ) -> u8 { - Context::current().call_contract_raw( + VMContext::current().call_contract_raw( contract, calldata, calldata_len, @@ -320,9 +309,9 @@ pub unsafe extern "C" fn static_call_contract( /// Delegate calls the contract at the given address, with the option to limit /// the amount of gas supplied. The return status indicates whether the call -/// succeeded, and is nonzero on failure. +/// succeeded and is nonzero on failure. /// -/// In both cases `return_data_len` will store the length of the result, the +/// In both cases, `return_data_len` will store the length of the result, the /// bytes of which can be read via the `read_return_data` hostio. The bytes are /// not returned directly so that the programmer can potentially save gas by /// choosing which subset of the return result they'd like to copy. @@ -333,14 +322,14 @@ pub unsafe extern "C" fn static_call_contract( /// /// [`DELEGATE_CALL`]: https://www.evm.codes/#F4 #[no_mangle] -pub unsafe extern "C" fn delegate_call_contract( +unsafe extern "C" fn delegate_call_contract( contract: *const u8, calldata: *const u8, calldata_len: usize, _gas: u64, return_data_len: *mut usize, ) -> u8 { - Context::current().call_contract_raw( + VMContext::current().call_contract_raw( contract, calldata, calldata_len, @@ -354,7 +343,7 @@ pub unsafe extern "C" fn delegate_call_contract( /// /// [`Block Numbers and Time`]: https://developer.arbitrum.io/time #[no_mangle] -pub unsafe extern "C" fn block_timestamp() -> u64 { +unsafe extern "C" fn block_timestamp() -> u64 { // Epoch timestamp: 1st January 2025 00::00::00 1_735_689_600 } diff --git a/crates/motsu/src/storage_access.rs b/crates/motsu/src/storage_access.rs new file mode 100644 index 0000000..8a583fe --- /dev/null +++ b/crates/motsu/src/storage_access.rs @@ -0,0 +1,65 @@ +//! Tooling for accessing Motsu test VM storage. + +use std::hash::Hash; + +use dashmap::{mapref::one::RefMut, try_result::TryResult, DashMap}; + +/// Trait for Motsu test VM storage access. +pub(crate) trait AccessStorage { + type Key; + type Value; + + /// Get mutable access to storage with `key`. + /// + /// # Panics + /// + /// * After 10 attempts to access the storage. + /// * If the contract wasn't initialized. + fn access_storage( + &self, + key: &Self::Key, + ) -> RefMut { + self.access_storage_with_backoff(key, 10) + } + + /// Get mutable access to storage with `key`, with `backoff` number of + /// attempts. + /// + /// # Panics + /// + /// * After `backoff` attempts to access the storage. + /// * If the contract wasn't initialized. + fn access_storage_with_backoff( + &self, + key: &Self::Key, + backoff: u32, + ) -> RefMut; +} + +impl AccessStorage for DashMap { + type Key = K; + type Value = V; + + fn access_storage_with_backoff( + &self, + key: &Self::Key, + backoff: u32, + ) -> RefMut { + { + match self.try_get_mut(key) { + TryResult::Present(router) => router, + TryResult::Absent => { + panic!("contract should be initialised first") + } + TryResult::Locked => { + if backoff == 0 { + panic!("storage is locked") + } else { + std::thread::sleep(std::time::Duration::from_millis(1)); + self.access_storage_with_backoff(key, backoff - 1) + } + } + } + } + } +}