Skip to content

Commit

Permalink
feat(cheatcodes): Add vm.mockCalls to mock different return data fo…
Browse files Browse the repository at this point in the history
…r multiple calls (#9024)

* Refactor vm.mockCall to be based on mutable VecDeque

* Add vm.mockCalls cheatcode

* Refactor mock_call to be wrapper for mock_calls

* Add a test to vm.mockCalls

* Add test for vm.mockCalls with msg.value

* Fix fmt & clippy following vm.mockCalls implementation

* Fix Solidity fmt in testdata/default/cheats/MockCalls.t.sol

* Add test in MockCalls.t.sol to check last mocked data persists

* Remove allow(clippy::ptr_arg) from mock_call & mock_calls

* Apply suggestions from code review

---------

Co-authored-by: zerosnacks <[email protected]>
Co-authored-by: DaniPopes <[email protected]>
  • Loading branch information
3 people authored Oct 7, 2024
1 parent 22a72d5 commit d7d9b40
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 18 deletions.
40 changes: 40 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,14 @@ interface Vm {
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCall(address callee, uint256 msgValue, bytes calldata data, bytes calldata returnData) external;

/// Mocks multiple calls to an address, returning specified data for each call.
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCalls(address callee, bytes calldata data, bytes[] calldata returnData) external;

/// Mocks multiple calls to an address with a specific `msg.value`, returning specified data for each call.
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCalls(address callee, uint256 msgValue, bytes calldata data, bytes[] calldata returnData) external;

/// Reverts a call to an address with specified revert data.
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCallRevert(address callee, bytes calldata data, bytes calldata revertData) external;
Expand Down
38 changes: 35 additions & 3 deletions crates/cheatcodes/src/evm/mock.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{inspector::InnerEcx, Cheatcode, Cheatcodes, CheatsCtxt, Result, Vm::*};
use alloy_primitives::{Address, Bytes, U256};
use revm::{interpreter::InstructionResult, primitives::Bytecode};
use std::cmp::Ordering;
use std::{cmp::Ordering, collections::VecDeque};

/// Mocked call data.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -65,6 +65,25 @@ impl Cheatcode for mockCall_1Call {
}
}

impl Cheatcode for mockCalls_0Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self { callee, data, returnData } = self;
let _ = make_acc_non_empty(callee, ccx.ecx)?;

mock_calls(ccx.state, callee, data, None, returnData, InstructionResult::Return);
Ok(Default::default())
}
}

impl Cheatcode for mockCalls_1Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self { callee, msgValue, data, returnData } = self;
ccx.ecx.load_account(*callee)?;
mock_calls(ccx.state, callee, data, Some(msgValue), returnData, InstructionResult::Return);
Ok(Default::default())
}
}

impl Cheatcode for mockCallRevert_0Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self { callee, data, revertData } = self;
Expand Down Expand Up @@ -94,18 +113,31 @@ impl Cheatcode for mockFunctionCall {
}
}

#[allow(clippy::ptr_arg)] // Not public API, doesn't matter
fn mock_call(
state: &mut Cheatcodes,
callee: &Address,
cdata: &Bytes,
value: Option<&U256>,
rdata: &Bytes,
ret_type: InstructionResult,
) {
mock_calls(state, callee, cdata, value, std::slice::from_ref(rdata), ret_type)
}

fn mock_calls(
state: &mut Cheatcodes,
callee: &Address,
cdata: &Bytes,
value: Option<&U256>,
rdata_vec: &[Bytes],
ret_type: InstructionResult,
) {
state.mocked_calls.entry(*callee).or_default().insert(
MockCallDataContext { calldata: Bytes::copy_from_slice(cdata), value: value.copied() },
MockCallReturnData { ret_type, data: Bytes::copy_from_slice(rdata) },
rdata_vec
.iter()
.map(|rdata| MockCallReturnData { ret_type, data: rdata.clone() })
.collect::<VecDeque<_>>(),
);
}

Expand Down
40 changes: 25 additions & 15 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ pub struct Cheatcodes {

/// Mocked calls
// **Note**: inner must a BTreeMap because of special `Ord` impl for `MockCallDataContext`
pub mocked_calls: HashMap<Address, BTreeMap<MockCallDataContext, MockCallReturnData>>,
pub mocked_calls: HashMap<Address, BTreeMap<MockCallDataContext, VecDeque<MockCallReturnData>>>,

/// Mocked functions. Maps target address to be mocked to pair of (calldata, mock address).
pub mocked_functions: HashMap<Address, HashMap<Bytes, Address>>,
Expand Down Expand Up @@ -889,26 +889,36 @@ where {
}

// Handle mocked calls
if let Some(mocks) = self.mocked_calls.get(&call.bytecode_address) {
if let Some(mocks) = self.mocked_calls.get_mut(&call.bytecode_address) {
let ctx =
MockCallDataContext { calldata: call.input.clone(), value: call.transfer_value() };
if let Some(return_data) = mocks.get(&ctx).or_else(|| {
mocks
.iter()

if let Some(return_data_queue) = match mocks.get_mut(&ctx) {
Some(queue) => Some(queue),
None => mocks
.iter_mut()
.find(|(mock, _)| {
call.input.get(..mock.calldata.len()) == Some(&mock.calldata[..]) &&
mock.value.map_or(true, |value| Some(value) == call.transfer_value())
})
.map(|(_, v)| v)
}) {
return Some(CallOutcome {
result: InterpreterResult {
result: return_data.ret_type,
output: return_data.data.clone(),
gas,
},
memory_offset: call.return_memory_offset.clone(),
});
.map(|(_, v)| v),
} {
if let Some(return_data) = if return_data_queue.len() == 1 {
// If the mocked calls stack has a single element in it, don't empty it
return_data_queue.front().map(|x| x.to_owned())
} else {
// Else, we pop the front element
return_data_queue.pop_front()
} {
return Some(CallOutcome {
result: InterpreterResult {
result: return_data.ret_type,
output: return_data.data,
gas,
},
memory_offset: call.return_memory_offset.clone(),
});
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions testdata/cheats/Vm.sol

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions testdata/default/cheats/MockCalls.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT OR Apache-2.0
pragma solidity 0.8.18;

import "ds-test/test.sol";
import "cheats/Vm.sol";

contract MockCallsTest is DSTest {
Vm constant vm = Vm(HEVM_ADDRESS);

function testMockCallsLastShouldPersist() public {
address mockUser = vm.addr(vm.randomUint());
address mockErc20 = vm.addr(vm.randomUint());
bytes memory data = abi.encodeWithSignature("balanceOf(address)", mockUser);
bytes[] memory mocks = new bytes[](2);
mocks[0] = abi.encode(2 ether);
mocks[1] = abi.encode(7.219 ether);
vm.mockCalls(mockErc20, data, mocks);
(, bytes memory ret1) = mockErc20.call(data);
assertEq(abi.decode(ret1, (uint256)), 2 ether);
(, bytes memory ret2) = mockErc20.call(data);
assertEq(abi.decode(ret2, (uint256)), 7.219 ether);
(, bytes memory ret3) = mockErc20.call(data);
assertEq(abi.decode(ret3, (uint256)), 7.219 ether);
}

function testMockCallsWithValue() public {
address mockUser = vm.addr(vm.randomUint());
address mockErc20 = vm.addr(vm.randomUint());
bytes memory data = abi.encodeWithSignature("balanceOf(address)", mockUser);
bytes[] memory mocks = new bytes[](3);
mocks[0] = abi.encode(2 ether);
mocks[1] = abi.encode(1 ether);
mocks[2] = abi.encode(6.423 ether);
vm.mockCalls(mockErc20, 1 ether, data, mocks);
(, bytes memory ret1) = mockErc20.call{value: 1 ether}(data);
assertEq(abi.decode(ret1, (uint256)), 2 ether);
(, bytes memory ret2) = mockErc20.call{value: 1 ether}(data);
assertEq(abi.decode(ret2, (uint256)), 1 ether);
(, bytes memory ret3) = mockErc20.call{value: 1 ether}(data);
assertEq(abi.decode(ret3, (uint256)), 6.423 ether);
}

function testMockCalls() public {
address mockUser = vm.addr(vm.randomUint());
address mockErc20 = vm.addr(vm.randomUint());
bytes memory data = abi.encodeWithSignature("balanceOf(address)", mockUser);
bytes[] memory mocks = new bytes[](3);
mocks[0] = abi.encode(2 ether);
mocks[1] = abi.encode(1 ether);
mocks[2] = abi.encode(6.423 ether);
vm.mockCalls(mockErc20, data, mocks);
(, bytes memory ret1) = mockErc20.call(data);
assertEq(abi.decode(ret1, (uint256)), 2 ether);
(, bytes memory ret2) = mockErc20.call(data);
assertEq(abi.decode(ret2, (uint256)), 1 ether);
(, bytes memory ret3) = mockErc20.call(data);
assertEq(abi.decode(ret3, (uint256)), 6.423 ether);
}
}

0 comments on commit d7d9b40

Please sign in to comment.