diff --git a/evm-adapters/src/sputnik/cheatcodes/cheatcode_handler.rs b/evm-adapters/src/sputnik/cheatcodes/cheatcode_handler.rs index f70fd9c77..d0f5aa6df 100644 --- a/evm-adapters/src/sputnik/cheatcodes/cheatcode_handler.rs +++ b/evm-adapters/src/sputnik/cheatcodes/cheatcode_handler.rs @@ -199,7 +199,6 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> SputnikExecutor return (e.into(), Vec::new()), } - // Initialize initial addresses for EIP-2929 // Initialize initial addresses for EIP-2929 if self.config().increase_state_access_gas { let addresses = core::iter::once(caller).chain(core::iter::once(address)); @@ -224,6 +223,20 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> SputnikExecutor { self.state_mut().increment_call_index(); + + // check if all expected calls were made + if let Some((address, expecteds)) = + self.state().expected_calls.iter().find(|(_, expecteds)| !expecteds.is_empty()) + { + return ( + ExitReason::Revert(ExitRevert::Reverted), + ethers::abi::encode(&[Token::String(format!( + "Expected a call to 0x{} with data {}, but got none", + address, + ethers::types::Bytes::from(expecteds[0].clone()) + ))]), + ) + } (s, v) } Capture::Trap(_) => { @@ -723,6 +736,22 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> CheatcodeStackExecutor<'a, 'b, B, P> }; self.state_mut().expected_emits.push(expected_emit); } + HEVMCalls::MockCall(inner) => { + self.add_debug(CheatOp::MOCKCALL); + self.state_mut() + .mocked_calls + .entry(inner.0) + .or_default() + .insert(inner.1.to_vec(), inner.2.to_vec()); + } + HEVMCalls::ClearMockedCalls(_) => { + self.add_debug(CheatOp::CLEARMOCKEDCALLS); + self.state_mut().mocked_calls = Default::default(); + } + HEVMCalls::ExpectCall(inner) => { + self.add_debug(CheatOp::EXPECTCALL); + self.state_mut().expected_calls.entry(inner.0).or_default().push(inner.1.to_vec()); + } }; self.fill_trace(&trace, true, Some(res.clone()), pre_index); @@ -1373,6 +1402,32 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> Handler for CheatcodeStackExecutor<'a } } + // handle expected calls + if let Some(expecteds) = self.state_mut().expected_calls.get_mut(&code_address) { + if let Some(found_match) = + expecteds.iter().position(|expected| expected == &input[..expected.len()]) + { + expecteds.remove(found_match); + } + } + + // handle mocked calls + if let Some(mocks) = self.state().mocked_calls.get(&code_address) { + if let Some(mock_retdata) = mocks.get(&input) { + return Capture::Exit(( + ExitReason::Succeed(ExitSucceed::Returned), + mock_retdata.clone(), + )) + } else if let Some((_, mock_retdata)) = + mocks.iter().find(|(mock, _)| *mock == &input[..mock.len()]) + { + return Capture::Exit(( + ExitReason::Succeed(ExitSucceed::Returned), + mock_retdata.clone(), + )) + } + } + // perform the call let res = self.call_inner( code_address, diff --git a/evm-adapters/src/sputnik/cheatcodes/debugger.rs b/evm-adapters/src/sputnik/cheatcodes/debugger.rs index 92dc4b9cb..968602d8d 100644 --- a/evm-adapters/src/sputnik/cheatcodes/debugger.rs +++ b/evm-adapters/src/sputnik/cheatcodes/debugger.rs @@ -175,6 +175,9 @@ pub enum CheatOp { RECORD, ACCESSES, EXPECTEMIT, + MOCKCALL, + CLEARMOCKEDCALLS, + EXPECTCALL, } impl From for OpCode { @@ -204,6 +207,9 @@ impl CheatOp { CheatOp::RECORD => "VM_RECORD", CheatOp::ACCESSES => "VM_ACCESSES", CheatOp::EXPECTEMIT => "VM_EXPECTEMIT", + CheatOp::MOCKCALL => "VM_MOCKCALL", + CheatOp::CLEARMOCKEDCALLS => "VM_CLEARMOCKEDCALLS", + CheatOp::EXPECTCALL => "VM_EXPECTCALL", } } } diff --git a/evm-adapters/src/sputnik/cheatcodes/memory_stackstate_owned.rs b/evm-adapters/src/sputnik/cheatcodes/memory_stackstate_owned.rs index 5cee85fbf..0694954e0 100644 --- a/evm-adapters/src/sputnik/cheatcodes/memory_stackstate_owned.rs +++ b/evm-adapters/src/sputnik/cheatcodes/memory_stackstate_owned.rs @@ -57,6 +57,8 @@ pub struct MemoryStackStateOwned<'config, B> { pub all_logs: Vec, /// Expected events by end of the next call pub expected_emits: Vec, + pub mocked_calls: BTreeMap, Vec>>, + pub expected_calls: BTreeMap>>, /// Debug enabled pub debug_enabled: bool, /// An arena allocator of DebugNodes for debugging purposes @@ -124,6 +126,8 @@ impl<'config, B: Backend> MemoryStackStateOwned<'config, B> { accesses: None, all_logs: Default::default(), expected_emits: Default::default(), + mocked_calls: Default::default(), + expected_calls: Default::default(), debug_enabled, debug_steps: vec![Default::default()], debug_instruction_pointers: (BTreeMap::new(), BTreeMap::new()), diff --git a/evm-adapters/src/sputnik/cheatcodes/mod.rs b/evm-adapters/src/sputnik/cheatcodes/mod.rs index b8cd5b6b3..a46d1c1a9 100644 --- a/evm-adapters/src/sputnik/cheatcodes/mod.rs +++ b/evm-adapters/src/sputnik/cheatcodes/mod.rs @@ -63,6 +63,9 @@ ethers::contract::abigen!( record() accesses(address)(bytes32[],bytes32[]) expectEmit(bool,bool,bool,bool) + mockCall(address,bytes,bytes) + clearMockedCalls() + expectCall(address,bytes) ]"#, ); pub use hevm_mod::{HEVMCalls, HEVM_ABI}; diff --git a/evm-adapters/testdata/CheatCodes.sol b/evm-adapters/testdata/CheatCodes.sol index 05712007f..38ee46291 100644 --- a/evm-adapters/testdata/CheatCodes.sol +++ b/evm-adapters/testdata/CheatCodes.sol @@ -42,6 +42,16 @@ interface Hevm { // Call this function, then emit an event, then call a function. Internally after the call, we check if // logs were emited in the expected order with the expected topics and data (as specified by the booleans) function expectEmit(bool,bool,bool,bool) external; + // Mocks a call to an address, returning specified data. + // Calldata can either be strict or a partial match, e.g. if you only + // pass a Solidity selector to the expected calldata, then the entire Solidity + // function will be mocked. + function mockCall(address,bytes calldata,bytes calldata) external; + // Clears all mocked calls + function clearMockedCalls() external; + // Expect a call to an address with the specified calldata. + // Calldata can either be strict or a partial match + function expectCall(address,bytes calldata) external; } contract HasStorage { @@ -358,7 +368,154 @@ contract CheatCodes is DSTest { // after expectRevert function testFailExpectRevert3() public { hevm.expectRevert("revert"); - } + } + + function testMockArbitraryCall() public { + hevm.mockCall(address(0xbeef), abi.encode("wowee"), abi.encode("epic")); + (bool ok, bytes memory ret) = address(0xbeef).call(abi.encode("wowee")); + assertTrue(ok); + assertEq(abi.decode(ret, (string)), "epic"); + } + + function testMockContract() public { + MockMe target = new MockMe(); + + // pre-mock + assertEq(target.numberA(), 1); + assertEq(target.numberB(), 2); + + hevm.mockCall( + address(target), + abi.encodeWithSelector(target.numberB.selector), + abi.encode(10) + ); + + // post-mock + assertEq(target.numberA(), 1); + assertEq(target.numberB(), 10); + } + + function testMockInner() public { + MockMe inner = new MockMe(); + MockInner target = new MockInner(address(inner)); + + // pre-mock + assertEq(target.sum(), 3); + + hevm.mockCall( + address(inner), + abi.encodeWithSelector(inner.numberB.selector), + abi.encode(9) + ); + + // post-mock + assertEq(target.sum(), 10); + } + + function testMockSelector() public { + MockMe target = new MockMe(); + assertEq(target.add(5, 5), 10); + + hevm.mockCall( + address(target), + abi.encodeWithSelector(target.add.selector), + abi.encode(11) + ); + + assertEq(target.add(5, 5), 11); + } + + function testMockCalldata() public { + MockMe target = new MockMe(); + assertEq(target.add(5, 5), 10); + assertEq(target.add(6, 4), 10); + + hevm.mockCall( + address(target), + abi.encodeWithSelector(target.add.selector, 5, 5), + abi.encode(11) + ); + + assertEq(target.add(5, 5), 11); + assertEq(target.add(6, 4), 10); + } + + function testClearMockedCalls() public { + MockMe target = new MockMe(); + + hevm.mockCall( + address(target), + abi.encodeWithSelector(target.numberB.selector), + abi.encode(10) + ); + + assertEq(target.numberA(), 1); + assertEq(target.numberB(), 10); + + hevm.clearMockedCalls(); + + assertEq(target.numberA(), 1); + assertEq(target.numberB(), 2); + } + + function testExpectCallWithData() public { + MockMe target = new MockMe(); + hevm.expectCall( + address(target), + abi.encodeWithSelector(target.add.selector, 1, 2) + ); + target.add(1, 2); + } + + function testFailExpectCallWithData() public { + MockMe target = new MockMe(); + hevm.expectCall( + address(target), + abi.encodeWithSelector(target.add.selector, 1, 2) + ); + target.add(3, 3); + } + + function testExpectInnerCall() public { + MockMe inner = new MockMe(); + MockInner target = new MockInner(address(inner)); + + hevm.expectCall( + address(inner), + abi.encodeWithSelector(inner.numberB.selector) + ); + target.sum(); + } + + function testFailExpectInnerCall() public { + MockMe inner = new MockMe(); + MockInner target = new MockInner(address(inner)); + + hevm.expectCall( + address(inner), + abi.encodeWithSelector(inner.numberB.selector) + ); + + // this function does not call inner + target.hello(); + } + + function testExpectSelectorCall() public { + MockMe target = new MockMe(); + hevm.expectCall( + address(target), + abi.encodeWithSelector(target.add.selector) + ); + target.add(5, 5); + } + + function testFailExpectSelectorCall() public { + MockMe target = new MockMe(); + hevm.expectCall( + address(target), + abi.encodeWithSelector(target.add.selector) + ); + } function getCode(address who) internal returns (bytes memory o_code) { assembly { @@ -511,3 +668,32 @@ contract ExpectEmit { } } +contract MockMe { + function numberA() public returns (uint256) { + return 1; + } + + function numberB() public returns (uint256) { + return 2; + } + + function add(uint256 a, uint256 b) public returns (uint256) { + return a + b; + } +} + +contract MockInner { + MockMe private inner; + + constructor(address _inner) { + inner = MockMe(_inner); + } + + function sum() public returns (uint256) { + return inner.numberA() + inner.numberB(); + } + + function hello() public returns (string memory) { + return "hi"; + } +} diff --git a/forge/README.md b/forge/README.md index 7400ddd2e..6434b6d5b 100644 --- a/forge/README.md +++ b/forge/README.md @@ -240,6 +240,16 @@ interface Vm { function expectRevert(bytes calldata) external; // Expects the next emitted event. Params check topic 1, topic 2, topic 3 and data are the same. function expectEmit(bool, bool, bool, bool) external; + // Mocks a call to an address, returning specified data. + // Calldata can either be strict or a partial match, e.g. if you only + // pass a Solidity selector to the expected calldata, then the entire Solidity + // function will be mocked. + function mockCall(address,bytes calldata,bytes calldata) external; + // Clears all mocked calls + function clearMockedCalls() external; + // Expect a call to an address with the specified calldata. + // Calldata can either be strict or a partial match + function expectCall(address,bytes calldata) external; } ``` ### `console.log`