Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(StdAssertions): Add assertEqCall #311

Merged
merged 4 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/StdAssertions.sol
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,55 @@ abstract contract StdAssertions is DSTest {
assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals);
}
}

function assertEqCall(address target, bytes memory callDataA, bytes memory callDataB) internal virtual {
assertEqCall(target, callDataA, target, callDataB, true);
}

function assertEqCall(address targetA, bytes memory callDataA, address targetB, bytes memory callDataB)
internal
virtual
{
assertEqCall(targetA, callDataA, targetB, callDataB, true);
}

function assertEqCall(address target, bytes memory callDataA, bytes memory callDataB, bool strictRevertData)
internal
virtual
{
assertEqCall(target, callDataA, target, callDataB, strictRevertData);
}

function assertEqCall(
address targetA,
bytes memory callDataA,
address targetB,
bytes memory callDataB,
bool strictRevertData
) internal virtual {
(bool successA, bytes memory returnDataA) = address(targetA).call(callDataA);
(bool successB, bytes memory returnDataB) = address(targetB).call(callDataB);

if (successA && successB) {
assertEq(returnDataA, returnDataB, "Call return data does not match");
}

if (!successA && !successB && strictRevertData) {
assertEq(returnDataA, returnDataB, "Call revert data does not match");
}

if (!successA && successB) {
emit log("Error: Calls were not equal");
emit log_named_bytes(" Left call revert data", returnDataA);
emit log_named_bytes(" Right call return data", returnDataB);
fail();
}

if (successA && !successB) {
emit log("Error: Calls were not equal");
emit log_named_bytes(" Left call return data", returnDataA);
emit log_named_bytes(" Right call revert data", returnDataB);
fail();
}
}
}
131 changes: 131 additions & 0 deletions test/StdAssertions.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ contract StdAssertionsTest is Test {
bool constant EXPECT_PASS = false;
bool constant EXPECT_FAIL = true;

bool constant SHOULD_REVERT = true;
bool constant SHOULD_RETURN = false;

bool constant STRICT_REVERT_DATA = true;
bool constant NON_STRICT_REVERT_DATA = false;

TestTest t = new TestTest();

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -607,6 +613,94 @@ contract StdAssertionsTest is Test {
emit log_named_string("Error", CUSTOM_ERROR);
t._assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals, CUSTOM_ERROR, EXPECT_FAIL);
}

/*//////////////////////////////////////////////////////////////////////////
ASSERT_EQ_CALL
//////////////////////////////////////////////////////////////////////////*/

function testAssertEqCall_Return_Pass(
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnData,
bool strictRevertData
) external {
address targetA = address(new TestMockCall(returnData, SHOULD_RETURN));
address targetB = address(new TestMockCall(returnData, SHOULD_RETURN));

t._assertEqCall(targetA, targetB, callDataA, callDataB, returnData, returnData, strictRevertData, EXPECT_PASS);
}

function testAssertEqCall_Return_Fail(
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnDataA,
bytes memory returnDataB,
bool strictRevertData
) external {
vm.assume(keccak256(returnDataA) != keccak256(returnDataB));

address targetA = address(new TestMockCall(returnDataA, SHOULD_RETURN));
address targetB = address(new TestMockCall(returnDataB, SHOULD_RETURN));

vm.expectEmit(true, true, true, true);
emit log_named_string("Error", "Call return data does not match");
t._assertEqCall(targetA, targetB, callDataA, callDataB, returnDataA, returnDataB, strictRevertData, EXPECT_FAIL);
}

function testAssertEqCall_Revert_Pass(
bytes memory callDataA,
bytes memory callDataB,
bytes memory revertDataA,
bytes memory revertDataB
) external {
address targetA = address(new TestMockCall(revertDataA, SHOULD_REVERT));
address targetB = address(new TestMockCall(revertDataB, SHOULD_REVERT));

t._assertEqCall(
targetA, targetB, callDataA, callDataB, revertDataA, revertDataB, NON_STRICT_REVERT_DATA, EXPECT_PASS
);
}

function testAssertEqCall_Revert_Fail(
bytes memory callDataA,
bytes memory callDataB,
bytes memory revertDataA,
bytes memory revertDataB
) external {
vm.assume(keccak256(revertDataA) != keccak256(revertDataB));

address targetA = address(new TestMockCall(revertDataA, SHOULD_REVERT));
address targetB = address(new TestMockCall(revertDataB, SHOULD_REVERT));

vm.expectEmit(true, true, true, true);
emit log_named_string("Error", "Call revert data does not match");
t._assertEqCall(
targetA, targetB, callDataA, callDataB, revertDataA, revertDataB, STRICT_REVERT_DATA, EXPECT_FAIL
);
}

function testAssertEqCall_Fail(
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnDataA,
bytes memory returnDataB,
bool strictRevertData
) external {
address targetA = address(new TestMockCall(returnDataA, SHOULD_RETURN));
address targetB = address(new TestMockCall(returnDataB, SHOULD_REVERT));

vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Left call return data", returnDataA);
vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Right call revert data", returnDataB);
t._assertEqCall(targetA, targetB, callDataA, callDataB, returnDataA, returnDataB, strictRevertData, EXPECT_FAIL);

vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Left call revert data", returnDataB);
vm.expectEmit(true, true, true, true);
emit log_named_bytes(" Right call return data", returnDataA);
t._assertEqCall(targetB, targetA, callDataB, callDataA, returnDataB, returnDataA, strictRevertData, EXPECT_FAIL);
}
}

contract TestTest is Test {
Expand Down Expand Up @@ -820,4 +914,41 @@ contract TestTest is Test {
) external expectFailure(expectFail) {
assertApproxEqRelDecimal(a, b, maxPercentDelta, decimals, err);
}

function _assertEqCall(
address targetA,
address targetB,
bytes memory callDataA,
bytes memory callDataB,
bytes memory returnDataA,
bytes memory returnDataB,
bool strictRevertData,
bool expectFail
) external expectFailure(expectFail) {
assertEqCall(targetA, callDataA, targetB, callDataB, strictRevertData);
}
}

contract TestMockCall {
bytes returnData;
bool shouldRevert;

constructor(bytes memory returnData_, bool shouldRevert_) {
returnData = returnData_;
shouldRevert = shouldRevert_;
}

fallback() external payable {
bytes memory returnData_ = returnData;

if (shouldRevert) {
assembly {
revert(add(returnData_, 0x20), mload(returnData_))
}
} else {
assembly {
return(add(returnData_, 0x20), mload(returnData_))
}
}
}
}