Skip to content

Commit

Permalink
fix(forge): stack pranks, restore pranks at earlier call depths (#10018)
Browse files Browse the repository at this point in the history
  • Loading branch information
grandizzy authored Mar 7, 2025
1 parent f474801 commit 23191fb
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 59 deletions.
10 changes: 4 additions & 6 deletions crates/cheatcodes/src/evm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ impl Cheatcode for coolCall {
impl Cheatcode for readCallersCall {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self {} = self;
read_callers(ccx.state, &ccx.ecx.env.tx.caller)
read_callers(ccx.state, &ccx.ecx.env.tx.caller, ccx.ecx.journaled_state.depth())
}
}

Expand Down Expand Up @@ -1068,19 +1068,17 @@ fn derive_snapshot_name(
/// - If no caller modification is active:
/// - caller_mode will be equal to [CallerMode::None],
/// - `msg.sender` and `tx.origin` will be equal to the default sender address.
fn read_callers(state: &Cheatcodes, default_sender: &Address) -> Result {
let Cheatcodes { prank, broadcast, .. } = state;

fn read_callers(state: &Cheatcodes, default_sender: &Address, call_depth: u64) -> Result {
let mut mode = CallerMode::None;
let mut new_caller = default_sender;
let mut new_origin = default_sender;
if let Some(prank) = prank {
if let Some(prank) = state.get_prank(call_depth) {
mode = if prank.single_call { CallerMode::Prank } else { CallerMode::RecurrentPrank };
new_caller = &prank.new_caller;
if let Some(new) = &prank.new_origin {
new_origin = new;
}
} else if let Some(broadcast) = broadcast {
} else if let Some(broadcast) = &state.broadcast {
mode = if broadcast.single_call {
CallerMode::Broadcast
} else {
Expand Down
39 changes: 20 additions & 19 deletions crates/cheatcodes/src/evm/prank.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{Cheatcode, Cheatcodes, CheatsCtxt, Result, Vm::*};
use crate::{Cheatcode, CheatsCtxt, Result, Vm::*};
use alloy_primitives::Address;

/// Prank information.
#[derive(Clone, Debug, Default)]
#[derive(Clone, Copy, Debug, Default)]
pub struct Prank {
/// Address of the contract that initiated the prank
pub prank_caller: Address,
Expand Down Expand Up @@ -45,13 +45,13 @@ impl Prank {
}
}

/// Apply the prank by setting `used` to true iff it is false
/// Apply the prank by setting `used` to true if it is false
/// Only returns self in the case it is updated (first application)
pub fn first_time_applied(&self) -> Option<Self> {
if self.used {
None
} else {
Some(Self { used: true, ..self.clone() })
Some(Self { used: true, ..*self })
}
}
}
Expand Down Expand Up @@ -113,9 +113,9 @@ impl Cheatcode for startPrank_3Call {
}

impl Cheatcode for stopPrankCall {
fn apply(&self, state: &mut Cheatcodes) -> Result {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self {} = self;
state.prank = None;
ccx.state.pranks.remove(&ccx.ecx.journaled_state.depth());
Ok(Default::default())
}
}
Expand All @@ -127,39 +127,40 @@ fn prank(
single_call: bool,
delegate_call: bool,
) -> Result {
let prank = Prank::new(
ccx.caller,
ccx.ecx.env.tx.caller,
*new_caller,
new_origin.copied(),
ccx.ecx.journaled_state.depth(),
single_call,
delegate_call,
);

// Ensure that code exists at `msg.sender` if delegate calling.
if delegate_call {
let code = ccx.code(*new_caller)?;
ensure!(!code.is_empty(), "cannot `prank` delegate call from an EOA");
}

if let Some(Prank { used, single_call: current_single_call, .. }) = ccx.state.prank {
let depth = ccx.ecx.journaled_state.depth();
if let Some(Prank { used, single_call: current_single_call, .. }) = ccx.state.get_prank(depth) {
ensure!(used, "cannot overwrite a prank until it is applied at least once");
// This case can only fail if the user calls `vm.startPrank` and then `vm.prank` later on.
// This should not be possible without first calling `stopPrank`
ensure!(
single_call == current_single_call,
single_call == *current_single_call,
"cannot override an ongoing prank with a single vm.prank; \
use vm.startPrank to override the current prank"
);
}

let prank = Prank::new(
ccx.caller,
ccx.ecx.env.tx.caller,
*new_caller,
new_origin.copied(),
depth,
single_call,
delegate_call,
);

ensure!(
ccx.state.broadcast.is_none(),
"cannot `prank` for a broadcasted transaction; \
pass the desired `tx.origin` into the `broadcast` cheatcode call"
);

ccx.state.prank = Some(prank);
ccx.state.pranks.insert(prank.depth, prank);
Ok(Default::default())
}
66 changes: 37 additions & 29 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ pub struct Cheatcodes {
/// Address labels
pub labels: AddressHashMap<String>,

/// Prank information
pub prank: Option<Prank>,
/// Prank information, mapped to the call depth where pranks were added.
pub pranks: BTreeMap<u64, Prank>,

/// Expected revert information
pub expected_revert: Option<ExpectedRevert>,
Expand Down Expand Up @@ -511,7 +511,7 @@ impl Cheatcodes {
block: Default::default(),
active_delegation: Default::default(),
gas_price: Default::default(),
prank: Default::default(),
pranks: Default::default(),
expected_revert: Default::default(),
assume_no_revert: Default::default(),
fork_revert_diagnostic: Default::default(),
Expand Down Expand Up @@ -543,6 +543,13 @@ impl Cheatcodes {
}
}

/// Returns the configured prank at given depth or the first prank configured at a lower depth.
/// For example, if pranks configured for depth 1, 3 and 5, the prank for depth 4 is the one
/// configured at depth 3.
pub fn get_prank(&self, depth: u64) -> Option<&Prank> {
self.pranks.range(..=depth).last().map(|(_, prank)| prank)
}

/// Returns the configured wallets if available, else creates a new instance.
pub fn wallets(&mut self) -> &Wallets {
self.wallets.get_or_insert_with(|| Wallets::new(MultiWallet::default(), None))
Expand Down Expand Up @@ -637,10 +644,11 @@ impl Cheatcodes {
{
let ecx = &mut ecx.inner;
let gas = Gas::new(input.gas_limit());
let curr_depth = ecx.journaled_state.depth();

// Apply our prank
if let Some(prank) = &self.prank {
if ecx.journaled_state.depth() >= prank.depth && input.caller() == prank.prank_caller {
if let Some(prank) = &self.get_prank(curr_depth) {
if curr_depth >= prank.depth && input.caller() == prank.prank_caller {
// At the target depth we set `msg.sender`
if ecx.journaled_state.depth() == prank.depth {
input.set_caller(prank.new_caller);
Expand All @@ -655,9 +663,7 @@ impl Cheatcodes {

// Apply our broadcast
if let Some(broadcast) = &self.broadcast {
if ecx.journaled_state.depth() >= broadcast.depth &&
input.caller() == broadcast.original_caller
{
if curr_depth >= broadcast.depth && input.caller() == broadcast.original_caller {
if let Err(err) =
ecx.journaled_state.load_account(broadcast.new_origin, &mut ecx.db)
{
Expand All @@ -673,7 +679,7 @@ impl Cheatcodes {

ecx.env.tx.caller = broadcast.new_origin;

if ecx.journaled_state.depth() == broadcast.depth {
if curr_depth == broadcast.depth {
input.set_caller(broadcast.new_origin);
let is_fixed_gas_limit = check_if_fixed_gas_limit(ecx, input.gas_limit());

Expand Down Expand Up @@ -718,7 +724,7 @@ impl Cheatcodes {
reverted: false,
deployedCode: Bytes::new(), // updated on (eof)create_end
storageAccesses: vec![], // updated on (eof)create_end
depth: ecx.journaled_state.depth(),
depth: curr_depth,
}]);
}

Expand All @@ -734,22 +740,23 @@ impl Cheatcodes {
) -> CreateOutcome
where {
let ecx = &mut ecx.inner;
let curr_depth = ecx.journaled_state.depth();

// Clean up pranks
if let Some(prank) = &self.prank {
if ecx.journaled_state.depth() == prank.depth {
if let Some(prank) = &self.get_prank(curr_depth) {
if curr_depth == prank.depth {
ecx.env.tx.caller = prank.prank_origin;

// Clean single-call prank once we have returned to the original depth
if prank.single_call {
std::mem::take(&mut self.prank);
std::mem::take(&mut self.pranks);
}
}
}

// Clean up broadcasts
if let Some(broadcast) = &self.broadcast {
if ecx.journaled_state.depth() == broadcast.depth {
if curr_depth == broadcast.depth {
ecx.env.tx.caller = broadcast.original_origin;

// Clean single-call broadcast once we have returned to the original depth
Expand All @@ -761,7 +768,7 @@ where {

// Handle expected reverts
if let Some(expected_revert) = &self.expected_revert {
if ecx.journaled_state.depth() <= expected_revert.depth &&
if curr_depth <= expected_revert.depth &&
matches!(expected_revert.kind, ExpectedRevertKind::Default)
{
let mut expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
Expand Down Expand Up @@ -798,7 +805,7 @@ where {
// previous call depth's recorded accesses, if any
if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack {
// The root call cannot be recorded.
if ecx.journaled_state.depth() > 0 {
if curr_depth > 0 {
if let Some(last_depth) = &mut recorded_account_diffs_stack.pop() {
// Update the reverted status of all deeper calls if this call reverted, in
// accordance with EVM behavior
Expand Down Expand Up @@ -879,11 +886,12 @@ where {
executor: &mut impl CheatcodesExecutor,
) -> Option<CallOutcome> {
let gas = Gas::new(call.gas_limit);
let curr_depth = ecx.journaled_state.depth();

// At the root call to test function or script `run()`/`setUp()` functions, we are
// decreasing sender nonce to ensure that it matches on-chain nonce once we start
// broadcasting.
if ecx.journaled_state.depth == 0 {
if curr_depth == 0 {
let sender = ecx.env.tx.caller;
let account = match super::evm::journaled_account(ecx, sender) {
Ok(account) => account,
Expand Down Expand Up @@ -991,7 +999,7 @@ where {
}

// Apply our prank
if let Some(prank) = &self.prank {
if let Some(prank) = &self.get_prank(curr_depth) {
// Apply delegate call, `call.caller`` will not equal `prank.prank_caller`
if let CallScheme::DelegateCall | CallScheme::ExtDelegateCall = call.scheme {
if prank.delegate_call {
Expand All @@ -1005,11 +1013,11 @@ where {
}
}

if ecx.journaled_state.depth() >= prank.depth && call.caller == prank.prank_caller {
if curr_depth >= prank.depth && call.caller == prank.prank_caller {
let mut prank_applied = false;

// At the target depth we set `msg.sender`
if ecx.journaled_state.depth() == prank.depth {
if curr_depth == prank.depth {
call.caller = prank.new_caller;
prank_applied = true;
}
Expand All @@ -1023,7 +1031,7 @@ where {
// If prank applied for first time, then update
if prank_applied {
if let Some(applied_prank) = prank.first_time_applied() {
self.prank = Some(applied_prank);
self.pranks.insert(curr_depth, applied_prank);
}
}
}
Expand All @@ -1035,9 +1043,7 @@ where {
//
// We do this because any subsequent contract calls *must* exist on chain and
// we only want to grab *this* call, not internal ones
if ecx.journaled_state.depth() == broadcast.depth &&
call.caller == broadcast.original_caller
{
if curr_depth == broadcast.depth && call.caller == broadcast.original_caller {
// At the target depth we set `msg.sender` & tx.origin.
// We are simulating the caller as being an EOA, so *both* must be set to the
// broadcast.origin.
Expand Down Expand Up @@ -1304,13 +1310,14 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
// This should be placed before the revert handling, because we might exit early there
if !cheatcode_call {
// Clean up pranks
if let Some(prank) = &self.prank {
if ecx.journaled_state.depth() == prank.depth {
let curr_depth = ecx.journaled_state.depth();
if let Some(prank) = &self.get_prank(curr_depth) {
if curr_depth == prank.depth {
ecx.env.tx.caller = prank.prank_origin;

// Clean single-call prank once we have returned to the original depth
if prank.single_call {
let _ = self.prank.take();
self.pranks.remove(&curr_depth);
}
}
}
Expand Down Expand Up @@ -1719,15 +1726,16 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
impl InspectorExt for Cheatcodes {
fn should_use_create2_factory(&mut self, ecx: Ecx, inputs: &mut CreateInputs) -> bool {
if let CreateScheme::Create2 { .. } = inputs.scheme {
let target_depth = if let Some(prank) = &self.prank {
let depth = ecx.journaled_state.depth();
let target_depth = if let Some(prank) = &self.get_prank(depth) {
prank.depth
} else if let Some(broadcast) = &self.broadcast {
broadcast.depth
} else {
1
};

ecx.journaled_state.depth() == target_depth &&
depth == target_depth &&
(self.broadcast.is_some() || self.config.always_use_create_2_factory)
} else {
false
Expand Down
5 changes: 3 additions & 2 deletions crates/cheatcodes/src/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,9 @@ impl Wallets {

/// Sets up broadcasting from a script using `new_origin` as the sender.
fn broadcast(ccx: &mut CheatsCtxt, new_origin: Option<&Address>, single_call: bool) -> Result {
let depth = ccx.ecx.journaled_state.depth();
ensure!(
ccx.state.prank.is_none(),
ccx.state.get_prank(depth).is_none(),
"you have an active prank; broadcasting and pranks are not compatible"
);
ensure!(ccx.state.broadcast.is_none(), "a broadcast is active already");
Expand All @@ -269,7 +270,7 @@ fn broadcast(ccx: &mut CheatsCtxt, new_origin: Option<&Address>, single_call: bo
new_origin: new_origin.unwrap_or(ccx.ecx.env.tx.caller),
original_caller: ccx.caller,
original_origin: ccx.ecx.env.tx.caller,
depth: ccx.ecx.journaled_state.depth(),
depth,
single_call,
};
debug!(target: "cheatcodes", ?broadcast, "started");
Expand Down
Loading

0 comments on commit 23191fb

Please sign in to comment.