Skip to content

Commit

Permalink
feat(StdAssertions): Add assertEqCall (#311)
Browse files Browse the repository at this point in the history
* Add `assertEqCall` to StdAssertions

* Update src/StdAssertions.sol

Co-authored-by: Matt Solomon <[email protected]>

* Update src/StdAssertions.sol

Co-authored-by: Matt Solomon <[email protected]>

* Update StdAssertions.t.sol

---------

Co-authored-by: Matt Solomon <[email protected]>
  • Loading branch information
0xPhaze and mds1 authored Feb 28, 2023
1 parent 65761f0 commit 653eff5
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/StdAssertions.sol
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,55 @@ abstract contract StdAssertions is DSTest {
assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals);
}
}

function assertEqCall(address target, bytes memory callDataA, bytes memory callDataB) internal virtual {
assertEqCall(target, callDataA, target, callDataB, true);
}

function assertEqCall(address targetA, bytes memory callDataA, address targetB, bytes memory callDataB)
internal
virtual
{
assertEqCall(targetA, callDataA, targetB, callDataB, true);
}

function assertEqCall(address target, bytes memory callDataA, bytes memory callDataB, bool strictRevertData)
internal
virtual
{
assertEqCall(target, callDataA, target, callDataB, strictRevertData);
}

function assertEqCall(
address targetA,
bytes memory callDataA,
address targetB,
bytes memory callDataB,
bool strictRevertData
) internal virtual {
(bool successA, bytes memory returnDataA) = address(targetA).call(callDataA);
(bool successB, bytes memory returnDataB) = address(targetB).call(callDataB);

if (successA && successB) {
assertEq(returnDataA, returnDataB, "Call return data does not match");
}

if (!successA && !successB && strictRevertData) {
assertEq(returnDataA, returnDataB, "Call revert data does not match");
}

if (!successA && successB) {
emit log("Error: Calls were not equal");
emit log_named_bytes(" Left call revert data", returnDataA);
emit log_named_bytes(" Right call return data", returnDataB);
fail();
}

if (successA && !successB) {
emit log("Error: Calls were not equal");
emit log_named_bytes(" Left call return data", returnDataA);
emit log_named_bytes(" Right call revert data", returnDataB);
fail();
}
}
}
131 changes: 131 additions & 0 deletions test/StdAssertions.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ contract StdAssertionsTest is Test {
bool constant EXPECT_PASS = false;
bool constant EXPECT_FAIL = true;

bool constant SHOULD_REVERT = true;
bool constant SHOULD_RETURN = false;

bool constant STRICT_REVERT_DATA = true;
bool constant NON_STRICT_REVERT_DATA = false;

TestTest t = new TestTest();

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -607,6 +613,94 @@ contract StdAssertionsTest is Test {
emit log_named_string("Error", CUSTOM_ERROR);
t._assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals, CUSTOM_ERROR, EXPECT_FAIL);
}

/*//////////////////////////////////////////////////////////////////////////
ASSERT_EQ_CALL
//////////////////////////////////////////////////////////////////////////*/

function testAssertEqCall_Return_Pass(
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnData,
bool strictRevertData
) external {
address targetA = address(new TestMockCall(returnData, SHOULD_RETURN));
address targetB = address(new TestMockCall(returnData, SHOULD_RETURN));

t._assertEqCall(targetA, targetB, callDataA, callDataB, returnData, returnData, strictRevertData, EXPECT_PASS);
}

function testAssertEqCall_Return_Fail(
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnDataA,
bytes memory returnDataB,
bool strictRevertData
) external {
vm.assume(keccak256(returnDataA) != keccak256(returnDataB));

address targetA = address(new TestMockCall(returnDataA, SHOULD_RETURN));
address targetB = address(new TestMockCall(returnDataB, SHOULD_RETURN));

vm.expectEmit(true, true, true, true);
emit log_named_string("Error", "Call return data does not match");
t._assertEqCall(targetA, targetB, callDataA, callDataB, returnDataA, returnDataB, strictRevertData, EXPECT_FAIL);
}

function testAssertEqCall_Revert_Pass(
bytes memory callDataA,
bytes memory callDataB,
bytes memory revertDataA,
bytes memory revertDataB
) external {
address targetA = address(new TestMockCall(revertDataA, SHOULD_REVERT));
address targetB = address(new TestMockCall(revertDataB, SHOULD_REVERT));

t._assertEqCall(
targetA, targetB, callDataA, callDataB, revertDataA, revertDataB, NON_STRICT_REVERT_DATA, EXPECT_PASS
);
}

function testAssertEqCall_Revert_Fail(
bytes memory callDataA,
bytes memory callDataB,
bytes memory revertDataA,
bytes memory revertDataB
) external {
vm.assume(keccak256(revertDataA) != keccak256(revertDataB));

address targetA = address(new TestMockCall(revertDataA, SHOULD_REVERT));
address targetB = address(new TestMockCall(revertDataB, SHOULD_REVERT));

vm.expectEmit(true, true, true, true);
emit log_named_string("Error", "Call revert data does not match");
t._assertEqCall(
targetA, targetB, callDataA, callDataB, revertDataA, revertDataB, STRICT_REVERT_DATA, EXPECT_FAIL
);
}

function testAssertEqCall_Fail(
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnDataA,
bytes memory returnDataB,
bool strictRevertData
) external {
address targetA = address(new TestMockCall(returnDataA, SHOULD_RETURN));
address targetB = address(new TestMockCall(returnDataB, SHOULD_REVERT));

vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Left call return data", returnDataA);
vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Right call revert data", returnDataB);
t._assertEqCall(targetA, targetB, callDataA, callDataB, returnDataA, returnDataB, strictRevertData, EXPECT_FAIL);

vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Left call revert data", returnDataB);
vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Right call return data", returnDataA);
t._assertEqCall(targetB, targetA, callDataB, callDataA, returnDataB, returnDataA, strictRevertData, EXPECT_FAIL);
}
}

contract TestTest is Test {
Expand Down Expand Up @@ -820,4 +914,41 @@ contract TestTest is Test {
) external expectFailure(expectFail) {
assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals, err);
}

function _assertEqCall(
address targetA,
address targetB,
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnDataA,
bytes memory returnDataB,
bool strictRevertData,
bool expectFail
) external expectFailure(expectFail) {
assertEqCall(targetA, callDataA, targetB, callDataB, strictRevertData);
}
}

contract TestMockCall {
bytes returnData;
bool shouldRevert;

constructor(bytes memory returnData_, bool shouldRevert_) {
returnData = returnData_;
shouldRevert = shouldRevert_;
}

fallback() external payable {
bytes memory returnData_ = returnData;

if (shouldRevert) {
assembly {
revert(add(returnData_, 0x20), mload(returnData_))
}
} else {
assembly {
return(add(returnData_, 0x20), mload(returnData_))
}
}
}
}

0 comments on commit 653eff5

Please sign in to comment.