From 653eff5c5ab930cea5e69b009d103810850edff6 Mon Sep 17 00:00:00 2001 From: 0xPhaze <103113487+0xPhaze@users.noreply.github.com> Date: Tue, 28 Feb 2023 17:21:08 -0400 Subject: [PATCH] feat(StdAssertions): Add `assertEqCall` (#311) * Add `assertEqCall` to StdAssertions * Update src/StdAssertions.sol Co-authored-by: Matt Solomon * Update src/StdAssertions.sol Co-authored-by: Matt Solomon * Update StdAssertions.t.sol --------- Co-authored-by: Matt Solomon --- src/StdAssertions.sol | 51 +++++++++++++++ test/StdAssertions.t.sol | 131 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+) diff --git a/src/StdAssertions.sol b/src/StdAssertions.sol index 4edecf77..198a1bab 100644 --- a/src/StdAssertions.sol +++ b/src/StdAssertions.sol @@ -322,4 +322,55 @@ abstract contract StdAssertions is DSTest { assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals); } } + + function assertEqCall(address target, bytes memory callDataA, bytes memory callDataB) internal virtual { + assertEqCall(target, callDataA, target, callDataB, true); + } + + function assertEqCall(address targetA, bytes memory callDataA, address targetB, bytes memory callDataB) + internal + virtual + { + assertEqCall(targetA, callDataA, targetB, callDataB, true); + } + + function assertEqCall(address target, bytes memory callDataA, bytes memory callDataB, bool strictRevertData) + internal + virtual + { + assertEqCall(target, callDataA, target, callDataB, strictRevertData); + } + + function assertEqCall( + address targetA, + bytes memory callDataA, + address targetB, + bytes memory callDataB, + bool strictRevertData + ) internal virtual { + (bool successA, bytes memory returnDataA) = address(targetA).call(callDataA); + (bool successB, bytes memory returnDataB) = address(targetB).call(callDataB); + + if (successA && successB) { + assertEq(returnDataA, returnDataB, "Call return data does not match"); + } + + if (!successA && !successB && strictRevertData) { + assertEq(returnDataA, returnDataB, "Call revert data does not match"); + } + + if (!successA && successB) { + emit log("Error: Calls were not equal"); + emit log_named_bytes(" Left call revert data", returnDataA); + emit log_named_bytes(" Right call return data", returnDataB); + fail(); + } + + if (successA && !successB) { + emit log("Error: Calls were not equal"); + emit log_named_bytes(" Left call return data", returnDataA); + emit log_named_bytes(" Right call revert data", returnDataB); + fail(); + } + } } diff --git a/test/StdAssertions.t.sol b/test/StdAssertions.t.sol index 6222654a..2922780c 100644 --- a/test/StdAssertions.t.sol +++ b/test/StdAssertions.t.sol @@ -9,6 +9,12 @@ contract StdAssertionsTest is Test { bool constant EXPECT_PASS = false; bool constant EXPECT_FAIL = true; + bool constant SHOULD_REVERT = true; + bool constant SHOULD_RETURN = false; + + bool constant STRICT_REVERT_DATA = true; + bool constant NON_STRICT_REVERT_DATA = false; + TestTest t = new TestTest(); /*////////////////////////////////////////////////////////////////////////// @@ -607,6 +613,94 @@ contract StdAssertionsTest is Test { emit log_named_string("Error", CUSTOM_ERROR); t._assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals, CUSTOM_ERROR, EXPECT_FAIL); } + + /*////////////////////////////////////////////////////////////////////////// + ASSERT_EQ_CALL + //////////////////////////////////////////////////////////////////////////*/ + + function testAssertEqCall_Return_Pass( + bytes memory callDataA, + bytes memory callDataB, + bytes memory returnData, + bool strictRevertData + ) external { + address targetA = address(new TestMockCall(returnData, SHOULD_RETURN)); + address targetB = address(new TestMockCall(returnData, SHOULD_RETURN)); + + t._assertEqCall(targetA, targetB, callDataA, callDataB, returnData, returnData, strictRevertData, EXPECT_PASS); + } + + function testAssertEqCall_Return_Fail( + bytes memory callDataA, + bytes memory callDataB, + bytes memory returnDataA, + bytes memory returnDataB, + bool strictRevertData + ) external { + vm.assume(keccak256(returnDataA) != keccak256(returnDataB)); + + address targetA = address(new TestMockCall(returnDataA, SHOULD_RETURN)); + address targetB = address(new TestMockCall(returnDataB, SHOULD_RETURN)); + + vm.expectEmit(true, true, true, true); + emit log_named_string("Error", "Call return data does not match"); + t._assertEqCall(targetA, targetB, callDataA, callDataB, returnDataA, returnDataB, strictRevertData, EXPECT_FAIL); + } + + function testAssertEqCall_Revert_Pass( + bytes memory callDataA, + bytes memory callDataB, + bytes memory revertDataA, + bytes memory revertDataB + ) external { + address targetA = address(new TestMockCall(revertDataA, SHOULD_REVERT)); + address targetB = address(new TestMockCall(revertDataB, SHOULD_REVERT)); + + t._assertEqCall( + targetA, targetB, callDataA, callDataB, revertDataA, revertDataB, NON_STRICT_REVERT_DATA, EXPECT_PASS + ); + } + + function testAssertEqCall_Revert_Fail( + bytes memory callDataA, + bytes memory callDataB, + bytes memory revertDataA, + bytes memory revertDataB + ) external { + vm.assume(keccak256(revertDataA) != keccak256(revertDataB)); + + address targetA = address(new TestMockCall(revertDataA, SHOULD_REVERT)); + address targetB = address(new TestMockCall(revertDataB, SHOULD_REVERT)); + + vm.expectEmit(true, true, true, true); + emit log_named_string("Error", "Call revert data does not match"); + t._assertEqCall( + targetA, targetB, callDataA, callDataB, revertDataA, revertDataB, STRICT_REVERT_DATA, EXPECT_FAIL + ); + } + + function testAssertEqCall_Fail( + bytes memory callDataA, + bytes memory callDataB, + bytes memory returnDataA, + bytes memory returnDataB, + bool strictRevertData + ) external { + address targetA = address(new TestMockCall(returnDataA, SHOULD_RETURN)); + address targetB = address(new TestMockCall(returnDataB, SHOULD_REVERT)); + + vm.expectEmit(true, true, true, true); + emit log_named_bytes(" Left call return data", returnDataA); + vm.expectEmit(true, true, true, true); + emit log_named_bytes(" Right call revert data", returnDataB); + t._assertEqCall(targetA, targetB, callDataA, callDataB, returnDataA, returnDataB, strictRevertData, EXPECT_FAIL); + + vm.expectEmit(true, true, true, true); + emit log_named_bytes(" Left call revert data", returnDataB); + vm.expectEmit(true, true, true, true); + emit log_named_bytes(" Right call return data", returnDataA); + t._assertEqCall(targetB, targetA, callDataB, callDataA, returnDataB, returnDataA, strictRevertData, EXPECT_FAIL); + } } contract TestTest is Test { @@ -820,4 +914,41 @@ contract TestTest is Test { ) external expectFailure(expectFail) { assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals, err); } + + function _assertEqCall( + address targetA, + address targetB, + bytes memory callDataA, + bytes memory callDataB, + bytes memory returnDataA, + bytes memory returnDataB, + bool strictRevertData, + bool expectFail + ) external expectFailure(expectFail) { + assertEqCall(targetA, callDataA, targetB, callDataB, strictRevertData); + } +} + +contract TestMockCall { + bytes returnData; + bool shouldRevert; + + constructor(bytes memory returnData_, bool shouldRevert_) { + returnData = returnData_; + shouldRevert = shouldRevert_; + } + + fallback() external payable { + bytes memory returnData_ = returnData; + + if (shouldRevert) { + assembly { + revert(add(returnData_, 0x20), mload(returnData_)) + } + } else { + assembly { + return(add(returnData_, 0x20), mload(returnData_)) + } + } + } }