From acc9eae2fb38c836983582e09af21765f906b40a Mon Sep 17 00:00:00 2001 From: Shebin John Date: Mon, 28 Oct 2024 14:02:11 +0100 Subject: [PATCH 1/3] Extensible Fallback Handler --- .../handler/ExtensibleFallbackHandler.sol | 28 + contracts/handler/HandlerContext.sol | 1 + .../handler/extensible/ERC165Handler.sol | 113 +++ .../handler/extensible/ExtensibleBase.sol | 86 ++ .../handler/extensible/FallbackHandler.sol | 42 + contracts/handler/extensible/MarshalLib.sol | 44 + .../extensible/SignatureVerifierMuxer.sol | 175 ++++ .../handler/extensible/TokenCallbacks.sol | 47 ++ contracts/test/TestMarshalLib.sol | 25 + contracts/test/TestSafeSignatureVerifier.sol | 31 + src/deploy/deploy_handlers.ts | 7 + .../ExtensibleFallbackHandler.spec.ts | 790 ++++++++++++++++++ test/handlers/HandlerContext.spec.ts | 15 + test/utils/extensible.ts | 69 ++ test/utils/setup.ts | 11 + 15 files changed, 1484 insertions(+) create mode 100644 contracts/handler/ExtensibleFallbackHandler.sol create mode 100644 contracts/handler/extensible/ERC165Handler.sol create mode 100644 contracts/handler/extensible/ExtensibleBase.sol create mode 100644 contracts/handler/extensible/FallbackHandler.sol create mode 100644 contracts/handler/extensible/MarshalLib.sol create mode 100644 contracts/handler/extensible/SignatureVerifierMuxer.sol create mode 100644 contracts/handler/extensible/TokenCallbacks.sol create mode 100644 contracts/test/TestMarshalLib.sol create mode 100644 contracts/test/TestSafeSignatureVerifier.sol create mode 100644 test/handlers/ExtensibleFallbackHandler.spec.ts create mode 100644 test/utils/extensible.ts diff --git a/contracts/handler/ExtensibleFallbackHandler.sol b/contracts/handler/ExtensibleFallbackHandler.sol new file mode 100644 index 000000000..7843af4b2 --- /dev/null +++ b/contracts/handler/ExtensibleFallbackHandler.sol @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {ERC165Handler} from "./extensible/ERC165Handler.sol"; +import {IFallbackHandler, FallbackHandler} from "./extensible/FallbackHandler.sol"; +import {ERC1271, ISignatureVerifierMuxer, SignatureVerifierMuxer} from "./extensible/SignatureVerifierMuxer.sol"; +import {ERC721TokenReceiver, ERC1155TokenReceiver, TokenCallbacks} from "./extensible/TokenCallbacks.sol"; + +/** + * @title ExtensibleFallbackHandler - A fully extensible fallback handler for Safes + * @dev Designed to be used with Safe >= 1.3.0. + * @author mfw78 + */ +contract ExtensibleFallbackHandler is FallbackHandler, SignatureVerifierMuxer, TokenCallbacks, ERC165Handler { + /** + * Specify specific interfaces (ERC721 + ERC1155) that this contract supports. + * @param interfaceId The interface ID to check for support + */ + function _supportsInterface(bytes4 interfaceId) internal pure override returns (bool) { + return + interfaceId == type(ERC1271).interfaceId || + interfaceId == type(ISignatureVerifierMuxer).interfaceId || + interfaceId == type(ERC165Handler).interfaceId || + interfaceId == type(IFallbackHandler).interfaceId || + interfaceId == type(ERC721TokenReceiver).interfaceId || + interfaceId == type(ERC1155TokenReceiver).interfaceId; + } +} diff --git a/contracts/handler/HandlerContext.sol b/contracts/handler/HandlerContext.sol index 8274cf69a..ec19f3f82 100644 --- a/contracts/handler/HandlerContext.sol +++ b/contracts/handler/HandlerContext.sol @@ -18,6 +18,7 @@ abstract contract HandlerContext { * @return sender Original caller address. */ function _msgSender() internal pure returns (address sender) { + require(msg.data.length >= 20, "Invalid calldata length"); // The assembly code is more direct than the Solidity version using `abi.decode`. /* solhint-disable no-inline-assembly */ /// @solidity memory-safe-assembly diff --git a/contracts/handler/extensible/ERC165Handler.sol b/contracts/handler/extensible/ERC165Handler.sol new file mode 100644 index 000000000..a5f85b9cc --- /dev/null +++ b/contracts/handler/extensible/ERC165Handler.sol @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {IERC165} from "../../interfaces/IERC165.sol"; +import {Safe, MarshalLib, ExtensibleBase} from "./ExtensibleBase.sol"; + +interface IERC165Handler { + function safeInterfaces(Safe safe, bytes4 interfaceId) external view returns (bool); + + function setSupportedInterface(bytes4 interfaceId, bool supported) external; + + function addSupportedInterfaceBatch(bytes4 interfaceId, bytes32[] calldata handlerWithSelectors) external; + + function removeSupportedInterfaceBatch(bytes4 interfaceId, bytes4[] calldata selectors) external; +} + +abstract contract ERC165Handler is ExtensibleBase, IERC165Handler { + // --- events --- + + event AddedInterface(Safe indexed safe, bytes4 interfaceId); + event RemovedInterface(Safe indexed safe, bytes4 interfaceId); + + // --- storage --- + + mapping(Safe => mapping(bytes4 => bool)) public override safeInterfaces; + + // --- setters --- + + /** + * Setter to indicate if an interface is supported (and thus reported by ERC165 supportsInterface) + * @param interfaceId The interface id whose support is to be set + * @param supported True if the interface is supported, false otherwise + */ + function setSupportedInterface(bytes4 interfaceId, bool supported) public override onlySelf { + Safe safe = Safe(payable(_manager())); + // invalid interface id per ERC165 spec + require(interfaceId != 0xffffffff, "invalid interface id"); + bool current = safeInterfaces[safe][interfaceId]; + if (supported && !current) { + safeInterfaces[safe][interfaceId] = true; + emit AddedInterface(safe, interfaceId); + } else if (!supported && current) { + delete safeInterfaces[safe][interfaceId]; + emit RemovedInterface(safe, interfaceId); + } + } + + /** + * Batch add selectors for an interface. + * @param _interfaceId The interface id to set + * @param handlerWithSelectors The handlers encoded with the 4-byte selectors of the methods + */ + function addSupportedInterfaceBatch(bytes4 _interfaceId, bytes32[] calldata handlerWithSelectors) external override onlySelf { + Safe safe = Safe(payable(_msgSender())); + bytes4 interfaceId; + for (uint256 i = 0; i < handlerWithSelectors.length; i++) { + (bool isStatic, bytes4 selector, address handlerAddress) = MarshalLib.decodeWithSelector(handlerWithSelectors[i]); + _setSafeMethod(safe, selector, MarshalLib.encode(isStatic, handlerAddress)); + if (i > 0) { + interfaceId ^= selector; + } else { + interfaceId = selector; + } + } + + require(interfaceId == _interfaceId, "interface id mismatch"); + setSupportedInterface(_interfaceId, true); + } + + /** + * Batch remove selectors for an interface. + * @param _interfaceId the interface id to remove + * @param selectors The selectors of the methods to remove + */ + function removeSupportedInterfaceBatch(bytes4 _interfaceId, bytes4[] calldata selectors) external override onlySelf { + Safe safe = Safe(payable(_msgSender())); + bytes4 interfaceId; + for (uint256 i = 0; i < selectors.length; i++) { + _setSafeMethod(safe, selectors[i], bytes32(0)); + if (i > 0) { + interfaceId ^= selectors[i]; + } else { + interfaceId = selectors[i]; + } + } + + require(interfaceId == _interfaceId, "interface id mismatch"); + setSupportedInterface(_interfaceId, false); + } + + /** + * @notice Implements ERC165 interface detection for the supported interfaces + * @dev Inheriting contracts should override `_supportsInterface` to add support for additional interfaces + * @param interfaceId The ERC165 interface id to check + * @return True if the interface is supported + */ + function supportsInterface(bytes4 interfaceId) external view returns (bool) { + return + interfaceId == type(IERC165).interfaceId || + interfaceId == type(IERC165Handler).interfaceId || + _supportsInterface(interfaceId) || + safeInterfaces[Safe(payable(_manager()))][interfaceId]; + } + + // --- internal --- + + /** + * A stub function to be overridden by inheriting contracts to add support for additional interfaces + * @param interfaceId The interface id to check support for + * @return True if the interface is supported + */ + function _supportsInterface(bytes4 interfaceId) internal view virtual returns (bool); +} diff --git a/contracts/handler/extensible/ExtensibleBase.sol b/contracts/handler/extensible/ExtensibleBase.sol new file mode 100644 index 000000000..9aa9d301e --- /dev/null +++ b/contracts/handler/extensible/ExtensibleBase.sol @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {Safe} from "../../Safe.sol"; +import {HandlerContext} from "../HandlerContext.sol"; +import {MarshalLib} from "./MarshalLib.sol"; + +interface IFallbackMethod { + function handle(Safe safe, address sender, uint256 value, bytes calldata data) external returns (bytes memory result); +} + +interface IStaticFallbackMethod { + function handle(Safe safe, address sender, uint256 value, bytes calldata data) external view returns (bytes memory result); +} + +/** + * @title Base contract for Extensible Fallback Handlers + * @dev This contract provides the base for storage and modifiers for extensible fallback handlers + * @author mfw78 + */ +abstract contract ExtensibleBase is HandlerContext { + // --- events --- + event AddedSafeMethod(Safe indexed safe, bytes4 selector, bytes32 method); + event ChangedSafeMethod(Safe indexed safe, bytes4 selector, bytes32 oldMethod, bytes32 newMethod); + event RemovedSafeMethod(Safe indexed safe, bytes4 selector); + + // --- storage --- + + // A mapping of Safe => selector => method + // The method is a bytes32 that is encoded as follows: + // - The first byte is 0x00 if the method is static and 0x01 if the method is not static + // - The last 20 bytes are the address of the handler contract + // The method is encoded / decoded using the MarshalLib + mapping(Safe => mapping(bytes4 => bytes32)) public safeMethods; + + // --- modifiers --- + modifier onlySelf() { + // Use the `HandlerContext._msgSender()` to get the caller of the fallback function + // Use the `HandlerContext._manager()` to get the manager, which should be the Safe + // Require that the caller is the Safe itself + require(_msgSender() == _manager(), "only safe can call this method"); + _; + } + + // --- internal --- + + function _setSafeMethod(Safe safe, bytes4 selector, bytes32 newMethod) internal { + (, address newHandler) = MarshalLib.decode(newMethod); + bytes32 oldMethod = safeMethods[safe][selector]; + (, address oldHandler) = MarshalLib.decode(oldMethod); + + if (address(newHandler) == address(0) && address(oldHandler) != address(0)) { + delete safeMethods[safe][selector]; + emit RemovedSafeMethod(safe, selector); + } else { + safeMethods[safe][selector] = newMethod; + if (address(oldHandler) == address(0)) { + emit AddedSafeMethod(safe, selector, newMethod); + } else { + emit ChangedSafeMethod(safe, selector, oldMethod, newMethod); + } + } + } + + /** + * Dry code to get the Safe and the original `msg.sender` from the FallbackManager + * @return safe The safe whose FallbackManager is making this call + * @return sender The original `msg.sender` (as received by the FallbackManager) + */ + function _getContext() internal view returns (Safe safe, address sender) { + safe = Safe(payable(_manager())); + sender = _msgSender(); + } + + /** + * Get the context and the method handler applicable to the current call + * @return safe The safe whose FallbackManager is making this call + * @return sender The original `msg.sender` (as received by the FallbackManager) + * @return isStatic Whether the method is static (`view`) or not + * @return handler the address of the handler contract + */ + function _getContextAndHandler() internal view returns (Safe safe, address sender, bool isStatic, address handler) { + (safe, sender) = _getContext(); + (isStatic, handler) = MarshalLib.decode(safeMethods[safe][msg.sig]); + } +} diff --git a/contracts/handler/extensible/FallbackHandler.sol b/contracts/handler/extensible/FallbackHandler.sol new file mode 100644 index 000000000..8565c30d5 --- /dev/null +++ b/contracts/handler/extensible/FallbackHandler.sol @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {Safe, IStaticFallbackMethod, IFallbackMethod, ExtensibleBase} from "./ExtensibleBase.sol"; + +interface IFallbackHandler { + function setSafeMethod(bytes4 selector, bytes32 newMethod) external; +} + +/** + * @title FallbackHandler - A fully extensible fallback handler for Safes + * @dev This contract provides a fallback handler for Safes that can be extended with custom fallback handlers + * for specific methods. + * @author mfw78 + */ +abstract contract FallbackHandler is ExtensibleBase, IFallbackHandler { + // --- setters --- + + /** + * Setter for custom method handlers + * @param selector The `bytes4` selector of the method to set the handler for + * @param newMethod A contract that implements the `IFallbackMethod` or `IStaticFallbackMethod` interface + */ + function setSafeMethod(bytes4 selector, bytes32 newMethod) public override onlySelf { + _setSafeMethod(Safe(payable(_msgSender())), selector, newMethod); + } + + // --- fallback --- + + // solhint-disable-next-line + fallback(bytes calldata) external returns (bytes memory result) { + require(msg.data.length >= 24, "invalid method selector"); + (Safe safe, address sender, bool isStatic, address handler) = _getContextAndHandler(); + require(handler != address(0), "method handler not set"); + + if (isStatic) { + result = IStaticFallbackMethod(handler).handle(safe, sender, 0, msg.data[:msg.data.length - 20]); + } else { + result = IFallbackMethod(handler).handle(safe, sender, 0, msg.data[:msg.data.length - 20]); + } + } +} diff --git a/contracts/handler/extensible/MarshalLib.sol b/contracts/handler/extensible/MarshalLib.sol new file mode 100644 index 000000000..161aeeaac --- /dev/null +++ b/contracts/handler/extensible/MarshalLib.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +library MarshalLib { + /** + * Encode a method handler into a `bytes32` value + * @dev The first byte of the `bytes32` value is set to 0x01 if the method is not static (`view`) + * @dev The last 20 bytes of the `bytes32` value are set to the address of the handler contract + * @param isStatic Whether the method is static (`view`) or not + * @param handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface + */ + function encode(bool isStatic, address handler) internal pure returns (bytes32 data) { + data = bytes32(uint256(uint160(handler)) | (isStatic ? 0 : (1 << 248))); + } + + function encodeWithSelector(bool isStatic, bytes4 selector, address handler) internal pure returns (bytes32 data) { + data = bytes32(uint256(uint160(handler)) | (isStatic ? 0 : (1 << 248)) | (uint256(uint32(selector)) << 216)); + } + + /** + * Given a `bytes32` value, decode it into a method handler and return it + * @param data The packed data to decode + * @return isStatic Whether the method is static (`view`) or not + * @return handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface + */ + function decode(bytes32 data) internal pure returns (bool isStatic, address handler) { + // solhint-disable-next-line no-inline-assembly + assembly { + // set isStatic to true if the left-most byte of the data is 0x00 + isStatic := iszero(shr(248, data)) + handler := shr(96, shl(96, data)) + } + } + + function decodeWithSelector(bytes32 data) internal pure returns (bool isStatic, bytes4 selector, address handler) { + // solhint-disable-next-line no-inline-assembly + assembly { + // set isStatic to true if the left-most byte of the data is 0x00 + isStatic := iszero(shr(248, data)) + handler := shr(96, shl(96, data)) + selector := shl(168, shr(160, data)) + } + } +} diff --git a/contracts/handler/extensible/SignatureVerifierMuxer.sol b/contracts/handler/extensible/SignatureVerifierMuxer.sol new file mode 100644 index 000000000..e5391ad1a --- /dev/null +++ b/contracts/handler/extensible/SignatureVerifierMuxer.sol @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {Safe, ExtensibleBase} from "./ExtensibleBase.sol"; + +interface ERC1271 { + function isValidSignature(bytes32 hash, bytes calldata signature) external view returns (bytes4 magicValue); +} + +/** + * @title Safe Signature Verifier Interface + * @author mfw78 + * @notice This interface provides an standard for external contracts that are verifying signatures + * for a Safe. + */ +interface ISafeSignatureVerifier { + /** + * @dev If called by `SignatureVerifierMuxer`, the following has already been checked: + * _hash = h(abi.encodePacked("\x19\x01", domainSeparator, h(typeHash || encodeData))); + * @param safe The Safe that has delegated the signature verification + * @param sender The address that originally called the Safe's `isValidSignature` method + * @param _hash The EIP-712 hash whose signature will be verified + * @param domainSeparator The EIP-712 domainSeparator + * @param typeHash The EIP-712 typeHash + * @param encodeData The EIP-712 encoded data + * @param payload An arbitrary payload that can be used to pass additional data to the verifier + * @return magic The magic value that should be returned if the signature is valid (0x1626ba7e) + */ + function isValidSafeSignature( + Safe safe, + address sender, + bytes32 _hash, + bytes32 domainSeparator, + bytes32 typeHash, + bytes calldata encodeData, + bytes calldata payload + ) external view returns (bytes4 magic); +} + +interface ISignatureVerifierMuxer { + function domainVerifiers(Safe safe, bytes32 domainSeparator) external view returns (ISafeSignatureVerifier); + + function setDomainVerifier(bytes32 domainSeparator, ISafeSignatureVerifier verifier) external; +} + +/** + * @title ERC-1271 Signature Verifier Multiplexer (Muxer) + * @author mfw78 + * @notice Allows delegating EIP-712 domains to an arbitrary `ISafeSignatureVerifier` + * @dev This multiplexer enforces a strict authorisation per domainSeparator. This is to prevent a malicious + * `ISafeSignatureVerifier` from being able to verify signatures for any domainSeparator. This does not prevent + * an `ISafeSignatureVerifier` from being able to verify signatures for multiple domainSeparators, however + * each domainSeparator requires specific approval by Safe. + */ +abstract contract SignatureVerifierMuxer is ExtensibleBase, ERC1271, ISignatureVerifierMuxer { + // --- constants --- + // keccak256("SafeMessage(bytes message)"); + bytes32 private constant SAFE_MSG_TYPEHASH = 0x60b3cbf8b4a223d68d641b3b6ddf9a298e7f33710cf3d3a9d1146b5a6150fbca; + // keccak256("safeSignature(bytes32,bytes32,bytes,bytes)"); + bytes4 private constant SAFE_SIGNATURE_MAGIC_VALUE = 0x5fd7e97d; + + // --- storage --- + mapping(Safe => mapping(bytes32 => ISafeSignatureVerifier)) public override domainVerifiers; + + // --- events --- + event AddedDomainVerifier(Safe indexed safe, bytes32 domainSeparator, ISafeSignatureVerifier verifier); + event ChangedDomainVerifier( + Safe indexed safe, + bytes32 domainSeparator, + ISafeSignatureVerifier oldVerifier, + ISafeSignatureVerifier newVerifier + ); + event RemovedDomainVerifier(Safe indexed safe, bytes32 domainSeparator); + + /** + * Setter for the signature muxer + * @param domainSeparator The domainSeparator authorised for the `ISafeSignatureVerifier` + * @param newVerifier A contract that implements `ISafeSignatureVerifier` + */ + function setDomainVerifier(bytes32 domainSeparator, ISafeSignatureVerifier newVerifier) public override onlySelf { + Safe safe = Safe(payable(_msgSender())); + ISafeSignatureVerifier oldVerifier = domainVerifiers[safe][domainSeparator]; + if (address(newVerifier) == address(0) && address(oldVerifier) != address(0)) { + delete domainVerifiers[safe][domainSeparator]; + emit RemovedDomainVerifier(safe, domainSeparator); + } else { + domainVerifiers[safe][domainSeparator] = newVerifier; + if (address(oldVerifier) == address(0)) { + emit AddedDomainVerifier(safe, domainSeparator, newVerifier); + } else { + emit ChangedDomainVerifier(safe, domainSeparator, oldVerifier, newVerifier); + } + } + } + + /** + * @notice Implements ERC1271 interface for smart contract EIP-712 signature validation + * @dev The signature format is the same as the one used by the Safe contract + * @param _hash Hash of the data that is signed + * @param signature The signature to be verified + * @return magic Standardised ERC1271 return value + */ + function isValidSignature(bytes32 _hash, bytes calldata signature) external view override returns (bytes4 magic) { + (Safe safe, address sender) = _getContext(); + + // Check if the signature is for an `ISafeSignatureVerifier` and if it is valid for the domain. + if (signature.length >= 4) { + bytes4 sigSelector; + // solhint-disable-next-line no-inline-assembly + assembly { + sigSelector := shl(224, shr(224, calldataload(signature.offset))) + } + + // Guard against short signatures that would cause abi.decode to revert. + if (sigSelector == SAFE_SIGNATURE_MAGIC_VALUE && signature.length >= 68) { + // Signature is for an `ISafeSignatureVerifier` - decode the signature. + // Layout of the `signature`: + // 0x00 - 0x04: selector + // 0x04 - 0x36: domainSeparator + // 0x36 - 0x68: typeHash + // 0x68 - 0x6C: encodeData length + // 0x6C - 0x6C + encodeData length: encodeData + // 0x6C + encodeData length - 0x6C + encodeData length + 0x20: payload length + // 0x6C + encodeData length + 0x20 - end: payload + // + // Get the domainSeparator from the signature. + (bytes32 domainSeparator, bytes32 typeHash) = abi.decode(signature[4:68], (bytes32, bytes32)); + + ISafeSignatureVerifier verifier = domainVerifiers[safe][domainSeparator]; + // Check if there is an `ISafeSignatureVerifier` for the domain. + if (address(verifier) != address(0)) { + (, , bytes memory encodeData, bytes memory payload) = abi.decode(signature[4:], (bytes32, bytes32, bytes, bytes)); + + // Check that the signature is valid for the domain. + if (keccak256(EIP712.encodeMessageData(domainSeparator, typeHash, encodeData)) == _hash) { + // Preserving the context, call the Safe's authorised `ISafeSignatureVerifier` to verify. + return verifier.isValidSafeSignature(safe, sender, _hash, domainSeparator, typeHash, encodeData, payload); + } + } + } + } + + // domainVerifier doesn't exist or the signature is invalid for the domain - fall back to the default + return defaultIsValidSignature(safe, _hash, signature); + } + + /** + * Default Safe signature validation (approved hashes / threshold signatures) + * @param safe The safe being asked to validate the signature + * @param _hash Hash of the data that is signed + * @param signature The signature to be verified + */ + function defaultIsValidSignature(Safe safe, bytes32 _hash, bytes memory signature) internal view returns (bytes4 magic) { + bytes memory messageData = EIP712.encodeMessageData( + safe.domainSeparator(), + SAFE_MSG_TYPEHASH, + abi.encode(keccak256(abi.encode(_hash))) + ); + bytes32 messageHash = keccak256(messageData); + if (signature.length == 0) { + // approved hashes + require(safe.signedMessages(messageHash) != 0, "Hash not approved"); + } else { + // threshold signatures + safe.checkSignatures(messageHash, messageData, signature); + } + magic = ERC1271.isValidSignature.selector; + } +} + +library EIP712 { + function encodeMessageData(bytes32 domainSeparator, bytes32 typeHash, bytes memory message) internal pure returns (bytes memory) { + return abi.encodePacked(bytes1(0x19), bytes1(0x01), domainSeparator, keccak256(abi.encodePacked(typeHash, message))); + } +} diff --git a/contracts/handler/extensible/TokenCallbacks.sol b/contracts/handler/extensible/TokenCallbacks.sol new file mode 100644 index 000000000..9bb069fe9 --- /dev/null +++ b/contracts/handler/extensible/TokenCallbacks.sol @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {ERC1155TokenReceiver} from "../../interfaces/ERC1155TokenReceiver.sol"; +import {ERC721TokenReceiver} from "../../interfaces/ERC721TokenReceiver.sol"; + +import {ExtensibleBase} from "./ExtensibleBase.sol"; + +/** + * @title TokenCallbacks - ERC1155 and ERC721 token callbacks for Safes + * @author mfw78 + * @notice Refactored from https://github.com/safe-global/safe-contracts/blob/3c3fc80f7f9aef1d39aaae2b53db5f4490051b0d/contracts/handler/TokenCallbackHandler.sol + */ +abstract contract TokenCallbacks is ExtensibleBase, ERC1155TokenReceiver, ERC721TokenReceiver { + /** + * @notice Handles ERC1155 Token callback. + * return Standardized onERC1155Received return value. + */ + function onERC1155Received(address, address, uint256, uint256, bytes calldata) external pure override returns (bytes4) { + // Else return the standard value + return 0xf23a6e61; + } + + /** + * @notice Handles ERC1155 Token batch callback. + * return Standardized onERC1155BatchReceived return value. + */ + function onERC1155BatchReceived( + address, + address, + uint256[] calldata, + uint256[] calldata, + bytes calldata + ) external pure override returns (bytes4) { + // Else return the standard value + return 0xbc197c81; + } + + /** + * @notice Handles ERC721 Token callback. + * return Standardized onERC721Received return value. + */ + function onERC721Received(address, address, uint256, bytes calldata) external pure override returns (bytes4) { + // Else return the standard value + return 0x150b7a02; + } +} diff --git a/contracts/test/TestMarshalLib.sol b/contracts/test/TestMarshalLib.sol new file mode 100644 index 000000000..b227a90a3 --- /dev/null +++ b/contracts/test/TestMarshalLib.sol @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {MarshalLib} from "../handler/extensible/MarshalLib.sol"; + +/** + * @title TestMarshalLib - A test contract for MarshalLib + */ +contract TestMarshalLib { + function encode(bool isStatic, address handler) external pure returns (bytes32 data) { + return MarshalLib.encode(isStatic, handler); + } + + function encodeWithSelector(bool isStatic, bytes4 selector, address handler) external pure returns (bytes32 data) { + return MarshalLib.encodeWithSelector(isStatic, selector, handler); + } + + function decode(bytes32 data) external pure returns (bool isStatic, address handler) { + return MarshalLib.decode(data); + } + + function decodeWithSelector(bytes32 data) external pure returns (bool isStatic, bytes4 selector, address handler) { + return MarshalLib.decodeWithSelector(data); + } +} diff --git a/contracts/test/TestSafeSignatureVerifier.sol b/contracts/test/TestSafeSignatureVerifier.sol new file mode 100644 index 000000000..5b8020358 --- /dev/null +++ b/contracts/test/TestSafeSignatureVerifier.sol @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: LGPL-3.0-only +pragma solidity >=0.7.0 <0.9.0; + +import {Safe, EIP712, ISafeSignatureVerifier} from "../handler/extensible/SignatureVerifierMuxer.sol"; + +/** + * @title TestSafeSignatureVerifier - A simple test contract that implements the ISafeSignatureVerifier interface + */ +contract TestSafeSignatureVerifier is ISafeSignatureVerifier { + /** + * Validates a signature for a Safe. + * @param _hash of the message to verify + * @param domainSeparator of the message to verify + * @param typeHash of the message to verify + * @param encodeData of the message to verify + * @return magic The magic value that should be returned if the signature is valid (0x1626ba7e) + */ + function isValidSafeSignature( + Safe, + address, + bytes32 _hash, + bytes32 domainSeparator, + bytes32 typeHash, + bytes calldata encodeData, + bytes calldata + ) external pure override returns (bytes4 magic) { + if (_hash == keccak256(EIP712.encodeMessageData(domainSeparator, typeHash, encodeData))) { + return 0x1626ba7e; + } + } +} diff --git a/src/deploy/deploy_handlers.ts b/src/deploy/deploy_handlers.ts index c3aa59fda..55095f92c 100644 --- a/src/deploy/deploy_handlers.ts +++ b/src/deploy/deploy_handlers.ts @@ -20,6 +20,13 @@ const deploy: DeployFunction = async function (hre: HardhatRuntimeEnvironment) { log: true, deterministicDeployment: true, }); + + await deploy("ExtensibleFallbackHandler", { + from: deployerAccount, + args: [], + log: true, + deterministicDeployment: true, + }); }; deploy.tags = ["handlers", "l2-suite", "main-suite"]; diff --git a/test/handlers/ExtensibleFallbackHandler.spec.ts b/test/handlers/ExtensibleFallbackHandler.spec.ts new file mode 100644 index 000000000..df7cee4e6 --- /dev/null +++ b/test/handlers/ExtensibleFallbackHandler.spec.ts @@ -0,0 +1,790 @@ +import { expect } from "chai"; +import hre, { deployments, ethers } from "hardhat"; +import { AddressZero, HashZero } from "@ethersproject/constants"; +import { deployContractFromSource, getExtensibleFallbackHandler, getSafe } from "../utils/setup"; +import { buildSignatureBytes, executeContractCallWithSigners, EIP712_SAFE_MESSAGE_TYPE } from "../../src/utils/execution"; +import { chainId } from "../utils/encoding"; +import { encodeHandler, decodeHandler, encodeCustomVerifier, encodeHandlerFunction } from "../utils/extensible"; +import { killLibContract } from "../utils/contracts"; + +describe("ExtensibleFallbackHandler", async () => { + const [user1, user2] = await hre.ethers.getSigners(); + + const setupTests = deployments.createFixture(async ({ deployments }) => { + await deployments.fixture(); + const signLib = await (await hre.ethers.getContractFactory("SignMessageLib")).deploy(); + const handler = await getExtensibleFallbackHandler(); + const handlerAddress = await handler.getAddress(); + const signerSafe = await getSafe({ owners: [user1.address], threshold: 1, fallbackHandler: handlerAddress }); + const signerSafeAddress = await signerSafe.getAddress(); + const safe = await getSafe({ + owners: [user1.address, user2.address, signerSafeAddress], + threshold: 2, + fallbackHandler: handlerAddress, + }); + const validator = await getExtensibleFallbackHandler(await safe.getAddress()); + const otherSafe = await getSafe({ + owners: [user1.address, user2.address, signerSafeAddress], + threshold: 2, + fallbackHandler: handlerAddress, + }); + const preconfiguredValidator = await getExtensibleFallbackHandler(await otherSafe.getAddress()); + const testVerifier = await (await hre.ethers.getContractFactory("TestSafeSignatureVerifier")).deploy(); + const testMarshalLib = await (await hre.ethers.getContractFactory("TestMarshalLib")).deploy(); + const killLib = await killLibContract(user1, hre.network.zksync); + + const mirrorSource = ` + contract Mirror { + function handle(address safe, address sender, uint256 value, bytes calldata data) external returns (bytes memory result) { + return msg.data; + } + function lookAtMe() public returns (bytes memory) { + return msg.data; + } + function nowLookAtYou(address you, string memory howYouLikeThat) public returns (bytes memory) { + return msg.data; + } + }`; + + const counterSource = ` + contract Counter { + uint256 public count = 0; + + function handle(address, address, uint256, bytes calldata) external returns (bytes memory result) { + bytes4 selector; + assembly { + selector := calldataload(164) + } + + require(selector == 0xdeadbeef, "Invalid data"); + count = count + 1; + } + }`; + + const revertVerifierSource = ` + contract RevertVerifier { + function iToHex(bytes memory buffer) public pure returns (string memory) { + // Fixed buffer size for hexadecimal convertion + bytes memory converted = new bytes(buffer.length * 2); + bytes memory _base = "0123456789abcdef"; + for (uint256 i = 0; i < buffer.length; i++) { + converted[i * 2] = _base[uint8(buffer[i]) / _base.length]; + converted[i * 2 + 1] = _base[uint8(buffer[i]) % _base.length]; + } + return string(abi.encodePacked("0x", converted)); + } + function isValidSafeSignature(address safe, address sender, bytes32 _hash, bytes32 domainSeparator, bytes32 typeHash, bytes calldata encodeData, bytes calldata payload) external view returns (bytes4) { + revert(iToHex(abi.encodePacked(msg.data))); + } + }`; + + const mirror = await deployContractFromSource(user1, mirrorSource); + const revertVerifier = await deployContractFromSource(user1, revertVerifierSource); + const counter = await deployContractFromSource(user1, counterSource); + + // Set up the mirror on the preconfigured validator + // Check the event when changing + await executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setSafeMethod", + ["0x7f8dc53c", encodeHandler(true, (await mirror.getAddress()).toLowerCase())], + [user1, user2], + ); + + const domainHash = ethers.keccak256("0xdeadbeef"); + + // setup the test verifier on the other safe + await executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setDomainVerifier", + [domainHash, await testVerifier.getAddress()], + [user1, user2], + ); + + await executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setSupportedInterface", + ["0xdeadbeef", true], + [user1, user2], + ); + + return { + safe, + validator, + otherSafe, + preconfiguredValidator, + handler, + killLib, + signLib, + signerSafe, + mirror, + counter, + testVerifier, + revertVerifier, + testMarshalLib, + }; + }); + + describe("Token Callbacks", async () => { + describe("ERC1155", async () => { + it("to handle onERC1155Received", async () => { + const { handler } = await setupTests(); + expect(await handler.onERC1155Received.staticCall(AddressZero, AddressZero, 0, 0, "0x")).to.be.eq("0xf23a6e61"); + }); + + it("to handle onERC1155BatchReceived", async () => { + const { handler } = await setupTests(); + expect(await handler.onERC1155BatchReceived.staticCall(AddressZero, AddressZero, [], [], "0x")).to.be.eq("0xbc197c81"); + }); + + it("should return true when queried for ERC1155 support", async () => { + const { handler } = await setupTests(); + expect(await handler.supportsInterface.staticCall("0x4e2312e0")).to.be.eq(true); + }); + }); + + describe("ERC721", async () => { + it("to handle onERC721Received", async () => { + const { handler } = await setupTests(); + expect(await handler.onERC721Received.staticCall(AddressZero, AddressZero, 0, "0x")).to.be.eq("0x150b7a02"); + }); + + it("should return true when queried for ERC721 support", async () => { + const { handler } = await setupTests(); + expect(await handler.supportsInterface.staticCall("0x150b7a02")).to.be.eq(true); + }); + }); + }); + + describe("Fallback Handler", async () => { + describe("fallback()", async () => { + it("should revert if call to safe is less than 4 bytes (method selector)", async () => { + const { validator } = await setupTests(); + + const tx = { + to: await validator.getAddress(), + data: "0x112233", + }; + + // Confirm method handler is not set (call should revert) + await expect(user1.call(tx)).to.be.revertedWith("invalid method selector"); + }); + }); + }); + + describe("Custom methods", async () => { + describe("setSafeMethod(bytes4,bytes32)", async () => { + it("should revert if called by non-safe", async () => { + const { handler, mirror } = await setupTests(); + await expect(handler.setSafeMethod("0xdeadbeef", encodeHandler(true, await mirror.getAddress()))).to.be.revertedWith( + "only safe can call this method", + ); + }); + + it("should emit event when setting a new method", async () => { + const { safe, handler, validator, mirror } = await setupTests(); + const safeAddress = await safe.getAddress(); + const newHandler = encodeHandler(true, await mirror.getAddress()); + await expect(executeContractCallWithSigners(safe, validator, "setSafeMethod", ["0xdededede", newHandler], [user1, user2])) + .to.emit(handler, "AddedSafeMethod") + .withArgs(safeAddress, "0xdededede", newHandler.toLowerCase()); + + // Check that the method is actually set + expect(await handler.safeMethods.staticCall(safeAddress, "0xdededede")).to.be.eq(newHandler); + }); + + it("should emit event when updating a method", async () => { + const { otherSafe, handler, preconfiguredValidator, mirror } = await setupTests(); + const otherSafeAddress = await otherSafe.getAddress(); + const oldHandler = encodeHandler(true, await mirror.getAddress()); + const newHandler = encodeHandler(true, "0xdeAdDeADDEaDdeaDdEAddEADDEAdDeadDEADDEaD"); + await expect( + executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setSafeMethod", + ["0x7f8dc53c", newHandler], + [user1, user2], + ), + ) + .to.emit(handler, "ChangedSafeMethod") + .withArgs(otherSafeAddress, "0x7f8dc53c", oldHandler.toLowerCase(), newHandler.toLowerCase()); + + // Check that the method is actually updated + expect(await handler.safeMethods.staticCall(otherSafeAddress, "0x7f8dc53c")).to.be.eq(newHandler); + }); + + it("should emit event when removing a method", async () => { + const { otherSafe, handler, preconfiguredValidator } = await setupTests(); + const otherSafeAddress = await otherSafe.getAddress(); + await expect( + executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setSafeMethod", + ["0x7f8dc53c", HashZero], + [user1, user2], + ), + ) + .to.emit(handler, "RemovedSafeMethod") + .withArgs(otherSafeAddress, "0x7f8dc53c"); + + // Check that the method is actually removed + expect(await handler.safeMethods.staticCall(otherSafeAddress, "0x7f8dc53c")).to.be.eq(HashZero); + }); + + it("is correctly set", async () => { + const { safe, validator, mirror } = await setupTests(); + const safeAddress = await safe.getAddress(); + const tx = { + to: safeAddress, + data: mirror.interface.encodeFunctionData("lookAtMe"), + }; + + // Confirm method handler is not set (call should revert) + await expect(user1.call(tx)).to.be.reverted; + + // Setup the method handler + await executeContractCallWithSigners( + safe, + validator, + "setSafeMethod", + ["0x7f8dc53c", encodeHandler(true, await mirror.getAddress())], + [user1, user2], + ); + + // Check that the method handler is called + expect(await user1.call(tx)).to.be.eq( + "0x" + + // function selector for `handle(address,address,uint256,bytes)` + "25d6803f" + + "000000000000000000000000" + + safeAddress.slice(2).toLowerCase() + + "000000000000000000000000" + + user1.address.slice(2).toLowerCase() + + "0000000000000000000000000000000000000000000000000000000000000000" + // uint256(0) + "0000000000000000000000000000000000000000000000000000000000000080" + + "0000000000000000000000000000000000000000000000000000000000000004" + + // function selector for `lookAtMe()` + "7f8dc53c" + + "00000000000000000000000000000000000000000000000000000000", + ); + }); + + it("should allow calling non-static methods", async () => { + const { safe, validator, counter } = await setupTests(); + + const tx = { + to: await safe.getAddress(), + data: "0xdeadbeef", + }; + + // Confirm that the count is 0 + expect(await counter.count.staticCall()).to.be.eq(0); + + // Setup the method handler + await executeContractCallWithSigners( + safe, + validator, + "setSafeMethod", + ["0xdeadbeef", encodeHandler(false, await counter.getAddress())], + [user1, user2], + ); + + // Check that the method handler is called + await user1.sendTransaction(tx); + + // Check that the count is updated + expect(await counter.count.staticCall()).to.be.eq(1); + }); + }); + + describe("MarshalLib", async () => { + it("should correctly encode a handler and static flag", async () => { + const { testMarshalLib } = await setupTests(); + const handler = "0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead"; + const isStatic = true; + + const encoded = "0x000000000000000000000000deaddeaddeaddeaddeaddeaddeaddeaddeaddead"; + expect(await testMarshalLib.encode.staticCall(isStatic, handler)).to.be.eq(encoded); + expect(encoded).to.be.eq(encodeHandler(isStatic, handler)); + + const nonStaticHandler = "0xdeaddeaddeaddeaddeaddeaddeaddeaddeadbeef"; + const nonStaticResult = "0x010000000000000000000000deaddeaddeaddeaddeaddeaddeaddeaddeadbeef"; + expect(await testMarshalLib.encode.staticCall(false, nonStaticHandler)).to.be.eq(nonStaticResult); + expect(nonStaticResult).to.be.eq(encodeHandler(false, nonStaticHandler)); + }); + + it("should correctly decode a handler and static flag", async () => { + const { testMarshalLib } = await setupTests(); + + const encoded = "0x000000000000000000000000deaddeaddeaddeaddeaddeaddeaddeaddeaddead"; + expect(await testMarshalLib.decode.staticCall(encoded)).to.be.deep.eq([true, "0xdeaDDeADDEaDdeaDdEAddEADDEAdDeadDEADDEaD"]); + + expect(decodeHandler(encoded)).to.be.deep.eq([true, "0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead"]); + + const nonStaticEncoded = "0x010000000000000000000000deaddeaddeaddeaddeaddeaddeaddeaddeadbeef"; + expect(await testMarshalLib.decode.staticCall(nonStaticEncoded)).to.be.deep.eq([ + false, + "0xDEADdEAddeaDdEAdDeadDeaDDeaddeaDDEadbEeF", + ]); + + expect(decodeHandler(nonStaticEncoded)).to.be.deep.eq([false, "0xdeaddeaddeaddeaddeaddeaddeaddeaddeadbeef"]); + }); + + it("should correctly encode a handler, selector and static flag", async () => { + const { testMarshalLib } = await setupTests(); + const handler = "0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead"; + const selector = "0xdeadbeef"; + const isStatic = true; + + const encoded = "0x00deadbeef00000000000000deaddeaddeaddeaddeaddeaddeaddeaddeaddead"; + + expect(await testMarshalLib.encodeWithSelector.staticCall(isStatic, selector, handler)).to.be.eq(encoded); + }); + it("should correctly decode a handler, selector and static flag", async () => { + const { testMarshalLib } = await setupTests(); + const encoded = "0x00deadbeef00000000000000deaddeaddeaddeaddeaddeaddeaddeaddeaddead"; + + expect(await testMarshalLib.decodeWithSelector.staticCall(encoded)).to.be.deep.eq([ + true, + "0xdeadbeef", + "0xdeaDDeADDEaDdeaDdEAddEADDEAdDeadDEADDEaD", + ]); + }); + }); + }); + + describe("Signature Verifier Muxer", async () => { + describe("supportsInterface(bytes4)", async () => { + it("should return true for supporting ERC1271", async () => { + const { handler } = await setupTests(); + expect(await handler.supportsInterface.staticCall("0x1626ba7e")).to.be.eq(true); + }); + }); + + describe("setDomainVerifier(bytes32,address)", async () => { + it("should revert if called by non-safe", async () => { + const { handler, mirror } = await setupTests(); + const domainSeparator = ethers.keccak256("0xdeadbeef"); + await expect(handler.setDomainVerifier(domainSeparator, await mirror.getAddress())).to.be.revertedWith( + "only safe can call this method", + ); + }); + + it("should emit event when setting a new domain verifier", async () => { + const { safe, handler, validator, testVerifier } = await setupTests(); + const safeAddress = await safe.getAddress(); + const testVerifierAddress = await testVerifier.getAddress(); + const domainSeparator = ethers.keccak256("0xdeadbeef"); + await expect( + executeContractCallWithSigners( + safe, + validator, + "setDomainVerifier", + [domainSeparator, testVerifierAddress], + [user1, user2], + ), + ) + .to.emit(handler, "AddedDomainVerifier") + .withArgs(safeAddress, domainSeparator, testVerifierAddress); + + expect(await handler.domainVerifiers(safeAddress, domainSeparator)).to.be.eq(testVerifierAddress); + }); + + it("should emit event when updating a domain verifier", async () => { + const { otherSafe, handler, preconfiguredValidator, mirror } = await setupTests(); + const otherSafeAddress = await otherSafe.getAddress(); + const mirrorAddress = await mirror.getAddress(); + const domainSeparator = ethers.keccak256("0xdeadbeef"); + const oldVerifier = await handler.domainVerifiers(otherSafeAddress, domainSeparator); + + await expect( + await executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setDomainVerifier", + [domainSeparator, mirrorAddress], + [user1, user2], + ), + ) + .to.emit(handler, "ChangedDomainVerifier") + .withArgs(otherSafeAddress, domainSeparator, oldVerifier, mirrorAddress); + + expect(await handler.domainVerifiers(otherSafeAddress, domainSeparator)).to.be.eq(mirrorAddress); + }); + + it("should emit event when removing a domain verifier", async () => { + const { otherSafe, handler, preconfiguredValidator } = await setupTests(); + const otherSafeAddress = await otherSafe.getAddress(); + const domainSeparator = ethers.keccak256("0xdeadbeef"); + await expect( + executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setDomainVerifier", + [domainSeparator, AddressZero], + [user1, user2], + ), + ) + .to.emit(handler, "RemovedDomainVerifier") + .withArgs(otherSafeAddress, domainSeparator); + + expect(await handler.domainVerifiers(otherSafeAddress, domainSeparator)).to.be.eq(AddressZero); + }); + }); + + describe("isValidSignature(bytes32,bytes)", async () => { + it("should revert if called directly", async () => { + const { handler } = await setupTests(); + const dataHash = ethers.keccak256("0xbaddad"); + await expect(handler.isValidSignature.staticCall(dataHash, "0x")).to.be.reverted; + }); + + it("should revert if message was not signed", async () => { + const { validator } = await setupTests(); + const dataHash = ethers.keccak256("0xbaddad"); + await expect(validator.isValidSignature.staticCall(dataHash, "0x")).to.be.revertedWith("Hash not approved"); + }); + + it("should revert if signature is not valid", async () => { + const { validator } = await setupTests(); + const dataHash = ethers.keccak256("0xbaddad"); + await expect(validator.isValidSignature.staticCall(dataHash, "0xdeaddeaddeaddead")).to.be.reverted; + }); + + it("should revert through default flow if signature is short", async () => { + const { validator } = await setupTests(); + const dataHash = ethers.keccak256("0xbaddad"); + await expect(validator.isValidSignature.staticCall(dataHash, "0x5fd7e97ddead")).to.be.revertedWith("GS020"); + }); + + it("should return magic value if message was signed", async () => { + const { safe, validator, signLib } = await setupTests(); + const dataHash = ethers.keccak256("0xbaddad"); + await executeContractCallWithSigners(safe, signLib, "signMessage", [dataHash], [user1, user2], true); + expect(await validator.isValidSignature.staticCall(dataHash, "0x")).to.be.eq("0x1626ba7e"); + }); + + it("should return magic value if enough owners signed with typed signatures", async () => { + const { validator } = await setupTests(); + const validatorAddress = await validator.getAddress(); + const dataHash = ethers.keccak256("0xbaddad"); + const typedDataSig = { + signer: user1.address, + data: await user1.signTypedData( + { verifyingContract: validatorAddress, chainId: await chainId() }, + EIP712_SAFE_MESSAGE_TYPE, + { message: dataHash }, + ), + }; + const typedDataSig2 = { + signer: user2.address, + data: await user2.signTypedData( + { verifyingContract: validatorAddress, chainId: await chainId() }, + EIP712_SAFE_MESSAGE_TYPE, + { message: dataHash }, + ), + }; + + expect(await validator.isValidSignature.staticCall(dataHash, buildSignatureBytes([typedDataSig, typedDataSig2]))).to.be.eq( + "0x1626ba7e", + ); + }); + + it("should send EIP-712 context to custom verifier", async () => { + const { safe, validator, revertVerifier } = await setupTests(); + const domainSeparator = ethers.keccak256("0xdeadbeef"); + const typeHash = ethers.keccak256("0xbaddad"); + // abi encode the message + const encodeData = ethers.solidityPacked( + ["bytes32", "bytes32"], + [ + ethers.keccak256("0xbaddadbaddadbaddadbaddadbaddadbaddad"), + ethers.keccak256("0xdeadbeefdeadbeefdeadbeefdeadbeefdead"), + ], + ); + + // set the revert verifier for the domain separator + await executeContractCallWithSigners( + safe, + validator, + "setDomainVerifier", + [domainSeparator, await revertVerifier.getAddress()], + [user1, user2], + ); + + const [dataHash, encodedMessage] = encodeCustomVerifier(encodeData, domainSeparator, typeHash, "0xdeadbeef"); + + // Test with a domain verifier - should revert with `GS021` + await expect(validator.isValidSignature.staticCall(dataHash, encodedMessage)).to.be.revertedWith( + "0x" + + // function call for isValidSafeSignature + "53f00b14" + + "000000000000000000000000" + + (await safe.getAddress()).slice(2).toLowerCase() + + "000000000000000000000000" + + user1.address.slice(2).toLowerCase() + + dataHash.slice(2) + + domainSeparator.slice(2) + + typeHash.slice(2) + + "00000000000000000000000000000000000000000000000000000000000000e0" + + "0000000000000000000000000000000000000000000000000000000000000140" + + hre.ethers.AbiCoder.defaultAbiCoder().encode(["bytes"], [encodeData]).slice(66) + + "0000000000000000000000000000000000000000000000000000000000000004" + + "deadbeef00000000000000000000000000000000000000000000000000000000", + ); + }); + + it("should revert it trying to forge the domain separator", async () => { + const { preconfiguredValidator } = await setupTests(); + const domainSeparator = ethers.keccak256("0xdeadbeef"); + const forgedDomainSeparator = ethers.keccak256("0xdeadbeefdeadbeef"); + const typeHash = ethers.keccak256("0xbaddad"); + // abi encode the message + const encodeData = ethers.solidityPacked( + ["bytes32", "bytes32"], + [ + ethers.keccak256("0xbaddadbaddadbaddadbaddadbaddadbaddad"), + ethers.keccak256("0xdeadbeefdeadbeefdeadbeefdeadbeefdead"), + ], + ); + + // calculate the hash of the message + const dataHash = ethers.keccak256( + ethers.solidityPacked( + ["bytes1", "bytes1", "bytes32", "bytes32"], + [ + "0x19", + "0x01", + forgedDomainSeparator, + ethers.keccak256(ethers.solidityPacked(["bytes32", "bytes"], [typeHash, encodeData])), + ], + ), + ); + + // create the function fragment for the `safeSignature(bytes32,bytes32,bytes,bytes)` function + const safeSignatureFragment = new ethers.Interface([`function safeSignature(bytes32,bytes32,bytes,bytes)`]); + const encodedMessage = safeSignatureFragment.encodeFunctionData("safeSignature(bytes32,bytes32,bytes,bytes)", [ + domainSeparator, + typeHash, + encodeData, + "0x", + ]); + + // Test with a domain verifier - should return magic value + await expect(preconfiguredValidator.isValidSignature.staticCall(dataHash, encodedMessage)).to.be.revertedWith("GS026"); + }); + + it("should return magic value if signed by a domain verifier", async () => { + const { validator, preconfiguredValidator } = await setupTests(); + const domainSeparator = ethers.keccak256("0xdeadbeef"); + const typeHash = ethers.keccak256("0xbaddad"); + // abi encode the message + const encodeData = hre.ethers.AbiCoder.defaultAbiCoder().encode( + ["bytes32"], + [ethers.keccak256("0xbaddadbaddadbaddadbaddadbaddadbaddad")], + ); + + const [dataHash, encodedMessage] = encodeCustomVerifier(encodeData, domainSeparator, typeHash, "0x"); + + // Test without a domain verifier - should revert with `GS026` + await expect(validator.isValidSignature.staticCall(dataHash, encodedMessage)).to.be.revertedWith("GS026"); + + // Test with a domain verifier - should return magic value + expect(await preconfiguredValidator.isValidSignature.staticCall(dataHash, encodedMessage)).to.be.eq("0x1626ba7e"); + }); + }); + }); + + describe("IERC165", async () => { + describe("supportsInterface(bytes4)", async () => { + it("should return true for ERC165", async () => { + const { validator } = await setupTests(); + expect(await validator.supportsInterface.staticCall("0x01ffc9a7")).to.be.true; + }); + }); + + describe("setSupportedInterface(bytes4,bool)", async () => { + it("should revert if called by non-safe", async () => { + const { handler } = await setupTests(); + await expect(handler.setSupportedInterface("0xdeadbeef", true)).to.be.revertedWith("only safe can call this method"); + }); + + it("should revert if trying to set an invalid interface", async () => { + const { validator, safe } = await setupTests(); + await expect( + executeContractCallWithSigners(safe, validator, "setSupportedInterface", ["0xffffffff", true], [user1, user2]), + ).to.be.revertedWith("invalid interface id"); + }); + + it("should emit event when adding a newly supported interface", async () => { + const { validator, safe, handler } = await setupTests(); + await expect(executeContractCallWithSigners(safe, validator, "setSupportedInterface", ["0xdeadbeef", true], [user1, user2])) + .to.emit(handler, "AddedInterface") + .withArgs(await safe.getAddress(), "0xdeadbeef"); + }); + + it("should emit event when removing a supported interface", async () => { + const { handler, otherSafe, preconfiguredValidator } = await setupTests(); + + await expect( + executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setSupportedInterface", + ["0xdeadbeef", false], + [user1, user2], + ), + ) + .to.emit(handler, "RemovedInterface") + .withArgs(await otherSafe.getAddress(), "0xdeadbeef"); + }); + + it("should not emit event when removing an unsupported interface", async () => { + const { handler, otherSafe, preconfiguredValidator } = await setupTests(); + + await expect( + executeContractCallWithSigners( + otherSafe, + preconfiguredValidator, + "setSupportedInterface", + ["0xbeafdead", false], + [user1, user2], + ), + ).to.not.emit(handler, "RemovedInterface"); + }); + }); + + describe("addSupportedInterfaceBatch(bytes4, bytes32[]", async () => { + it("should revert if called by non-safe", async () => { + const { handler } = await setupTests(); + await expect(handler.addSupportedInterfaceBatch("0xdeadbeef", [HashZero])).to.be.revertedWith( + "only safe can call this method", + ); + }); + + it("should revert if batch contains an invalid interface", async () => { + const { validator, safe } = await setupTests(); + await expect( + executeContractCallWithSigners( + safe, + validator, + "addSupportedInterfaceBatch", + ["0xffffffff", [HashZero]], + [user1, user2], + ), + ).to.be.revertedWith("interface id mismatch"); + }); + + it("should add all handlers in batch", async () => { + const { validator, safe, handler, mirror } = await setupTests(); + const safeAddress = await safe.getAddress(); + + // calculate the selector for each function + const selector1 = "0xabababab"; + const selector2 = "0xcdcdcdcd"; + const selector3 = "0xefefefef"; + + // calculate the interface id which is the xor of all selectors + const interfaceId = ethers.hexlify(ethers.toBeHex(BigInt(selector1) ^ BigInt(selector2) ^ BigInt(selector3))); + + // create the batch + const mirrorAddress = await mirror.getAddress(); + const batch = [selector1, selector2, selector3].map((selector) => encodeHandlerFunction(true, selector, mirrorAddress)); + + await expect( + executeContractCallWithSigners(safe, validator, "addSupportedInterfaceBatch", [interfaceId, batch], [user1, user2]), + ) + .to.emit(handler, "AddedSafeMethod") + .withArgs(safeAddress, "0xabababab", encodeHandler(true, mirrorAddress)) + .to.emit(handler, "AddedSafeMethod") + .withArgs(safeAddress, "0xcdcdcdcd", encodeHandler(true, mirrorAddress)) + .to.emit(handler, "AddedSafeMethod") + .withArgs(safeAddress, "0xefefefef", encodeHandler(true, mirrorAddress)) + .to.emit(handler, "AddedInterface") + .withArgs(safeAddress, interfaceId); + + // check that the interface is supported + expect(await validator.supportsInterface(interfaceId)).to.be.true; + }); + }); + + describe("removeSupportedInterfaceBatch(bytes4, bytes4[]", async () => { + it("should revert if called by non-safe", async () => { + const { handler } = await setupTests(); + await expect(handler.removeSupportedInterfaceBatch("0xdeadbeef", ["0xdeadbeef"])).to.be.revertedWith( + "only safe can call this method", + ); + }); + + it("should remove all methods in a batch", async () => { + const { validator, safe, handler, mirror } = await setupTests(); + const safeAddress = await safe.getAddress(); + + // calculate the selector for each function + const selector1 = "0xabababab"; + const selector2 = "0xcdcdcdcd"; + const selector3 = "0xefefefef"; + + // calculate the interface id which is the xor of all selectors + const interfaceId = ethers.hexlify(ethers.toBeHex(BigInt(selector1) ^ BigInt(selector2) ^ BigInt(selector3))); + + // create the batch + const mirrorAddress = await mirror.getAddress(); + const batch = [selector1, selector2, selector3].map((selector) => encodeHandlerFunction(true, selector, mirrorAddress)); + + await expect( + executeContractCallWithSigners(safe, validator, "addSupportedInterfaceBatch", [interfaceId, batch], [user1, user2]), + ) + .to.emit(handler, "AddedSafeMethod") + .withArgs(safeAddress, "0xabababab", encodeHandler(true, mirrorAddress)) + .to.emit(handler, "AddedSafeMethod") + .withArgs(safeAddress, "0xcdcdcdcd", encodeHandler(true, mirrorAddress)) + .to.emit(handler, "AddedSafeMethod") + .withArgs(safeAddress, "0xefefefef", encodeHandler(true, mirrorAddress)) + .to.emit(handler, "AddedInterface") + .withArgs(safeAddress, interfaceId); + + // check that the interface is supported + expect(await validator.supportsInterface(interfaceId)).to.be.true; + + // remove the interface with the incorrect interfaceId + await expect( + executeContractCallWithSigners( + safe, + validator, + "removeSupportedInterfaceBatch", + ["0xdeadbeef", [selector1, selector2, selector3]], + [user1, user2], + ), + ).to.be.revertedWith("interface id mismatch"); + + // remove the interface + await expect( + executeContractCallWithSigners( + safe, + validator, + "removeSupportedInterfaceBatch", + [interfaceId, [selector1, selector2, selector3]], + [user1, user2], + ), + ) + .to.emit(handler, "RemovedSafeMethod") + .withArgs(safeAddress, "0xabababab") + .to.emit(handler, "RemovedSafeMethod") + .withArgs(safeAddress, "0xcdcdcdcd") + .to.emit(handler, "RemovedSafeMethod") + .withArgs(safeAddress, "0xefefefef") + .to.emit(handler, "RemovedInterface") + .withArgs(safeAddress, interfaceId); + + // check that the interface is no longer supported + expect(await validator.supportsInterface(interfaceId)).to.be.false; + }); + }); + }); +}); diff --git a/test/handlers/HandlerContext.spec.ts b/test/handlers/HandlerContext.spec.ts index 5c7c89af2..31c72bcd8 100644 --- a/test/handlers/HandlerContext.spec.ts +++ b/test/handlers/HandlerContext.spec.ts @@ -47,4 +47,19 @@ describe("HandlerContext", () => { expect(handler.interface.decodeFunctionResult("dudududu", response)).to.be.deep.eq([user1.address, safeAddress]); }); + + it("reverts if calldata is less than 20 bytes", async () => { + const { + handler, + signers: [user1], + } = await setup(); + + const handlerAddress = await handler.getAddress(); + await expect( + user1.call({ + to: handlerAddress, + data: handler.interface.encodeFunctionData("dudududu"), + }), + ).to.be.revertedWith("Invalid calldata length"); + }); }); diff --git a/test/utils/extensible.ts b/test/utils/extensible.ts new file mode 100644 index 000000000..65c8c500f --- /dev/null +++ b/test/utils/extensible.ts @@ -0,0 +1,69 @@ +import { ethers } from "ethers"; + +// Given whether the handler is static or not, and the handler address, return the encoded bytes +// The encoded handler is a bytes32, so we need to encode the handler address and the isStatic flag +// into a single bytes32. +// The first 1 byte is the isStatic flag, and the remaining 31 bytes are the handler address, +// zero left padded. +export const encodeHandler = (isStatic: boolean, handler: string): string => { + const isStaticBytes = ethers.hexlify(isStatic ? "0x00" : "0x01"); + const handlerBytes = ethers.zeroPadValue(handler, 31); + return ethers.hexlify(ethers.concat([isStaticBytes, handlerBytes])); +}; + +// Given the encoded handler, return the isStatic flag and the handler address. +// The handler address has been zero left padded, so we need to remove the padding. +export const decodeHandler = (encodedHandler: string): [boolean, string] => { + const isStatic = ethers.dataSlice(encodedHandler, 0, 1) === "0x00"; + const handler = ethers.dataSlice(encodedHandler, 12); + return [isStatic, handler]; +}; + +// Given: +// - whether the handler is static or not +// - the 4byte selector of the function to call +// - the handler address +// Encode all into a single bytes32. +// The first 1 byte is the isStatic flag, the next 4 bytes are the selector, and the remaining 27 bytes are the handler address, +// zero left padded. +export const encodeHandlerFunction = (isStatic: boolean, selector: string, handler: string): string => { + const isStaticBytes = ethers.hexlify(isStatic ? "0x00" : "0x01"); + const selectorBytes = ethers.hexlify(selector); + const handlerBytes = ethers.zeroPadValue(handler, 27); + return ethers.hexlify(ethers.concat([isStaticBytes, selectorBytes, handlerBytes])); +}; + +// Given the encoded handler function, return the isStatic flag, the selector, and the handler address. +// The handler address has been zero left padded, so we need to remove the padding. +export const decodeHandlerFunction = (encodedHandlerFunction: string): [boolean, string, string] => { + const isStatic = ethers.dataSlice(encodedHandlerFunction, 0, 1) === "0x00"; + const selector = ethers.dataSlice(encodedHandlerFunction, 1, 5); + const handler = ethers.dataSlice(encodedHandlerFunction, 12); + return [isStatic, selector, handler]; +}; + +export const encodeCustomVerifier = ( + encodeData: string, + domainSeparator: string, + typeHash: string, + signature: string, +): [string, string] => { + // calculate the hash of the message + const dataHash = ethers.keccak256( + ethers.solidityPacked( + ["bytes1", "bytes1", "bytes32", "bytes32"], + ["0x19", "0x01", domainSeparator, ethers.keccak256(ethers.solidityPacked(["bytes32", "bytes"], [typeHash, encodeData]))], + ), + ); + + // create the function fragment for the `safeSignature(bytes)` function + const safeSignatureFragment = new ethers.Interface([`function safeSignature(bytes32,bytes32,bytes,bytes)`]); + const encodedMessage = safeSignatureFragment.encodeFunctionData("safeSignature(bytes32,bytes32,bytes,bytes)", [ + domainSeparator, + typeHash, + encodeData, + signature, + ]); + + return [dataHash, encodedMessage]; +}; diff --git a/test/utils/setup.ts b/test/utils/setup.ts index 699d0204b..6eb4039f1 100644 --- a/test/utils/setup.ts +++ b/test/utils/setup.ts @@ -156,6 +156,17 @@ export const getCompatFallbackHandler = async (address?: string) => { return fallbackHandler; }; +export const getExtensibleFallbackHandler = async (address?: string) => { + if (!address) { + const extensibleFallbackHandlerAddress = await deployments.get("ExtensibleFallbackHandler"); + address = extensibleFallbackHandlerAddress.address; + } + + const extensibleFallbackHandler = await hre.ethers.getContractAt("ExtensibleFallbackHandler", address); + + return extensibleFallbackHandler; +}; + export const getSafeProxyRuntimeCode = async (): Promise => { const proxyArtifact = await hre.artifacts.readArtifact("SafeProxy"); From 99481e486ede6cbda51485360e6d7e618fc0d17e Mon Sep 17 00:00:00 2001 From: Shebin John Date: Tue, 29 Oct 2024 10:09:49 +0100 Subject: [PATCH 2/3] Disable one contract per file --- contracts/handler/extensible/SignatureVerifierMuxer.sol | 1 + 1 file changed, 1 insertion(+) diff --git a/contracts/handler/extensible/SignatureVerifierMuxer.sol b/contracts/handler/extensible/SignatureVerifierMuxer.sol index e5391ad1a..692f8dec2 100644 --- a/contracts/handler/extensible/SignatureVerifierMuxer.sol +++ b/contracts/handler/extensible/SignatureVerifierMuxer.sol @@ -1,4 +1,5 @@ // SPDX-License-Identifier: LGPL-3.0-only +// solhint-disable one-contract-per-file pragma solidity >=0.7.0 <0.9.0; import {Safe, ExtensibleBase} from "./ExtensibleBase.sol"; From b69c2a0ece852aadc8e3fc8a1dc9947def6209cf Mon Sep 17 00:00:00 2001 From: Shebin John Date: Tue, 29 Oct 2024 10:10:09 +0100 Subject: [PATCH 3/3] Natspec for MarshaLib --- contracts/handler/extensible/MarshalLib.sol | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/contracts/handler/extensible/MarshalLib.sol b/contracts/handler/extensible/MarshalLib.sol index 161aeeaac..ca33cf678 100644 --- a/contracts/handler/extensible/MarshalLib.sol +++ b/contracts/handler/extensible/MarshalLib.sol @@ -13,6 +13,15 @@ library MarshalLib { data = bytes32(uint256(uint160(handler)) | (isStatic ? 0 : (1 << 248))); } + /** + * Encode a method handler into a `bytes32` value with a selector + * @dev The first byte of the `bytes32` value is set to 0x01 if the method is not static (`view`) + * @dev The next 4 bytes of the `bytes32` value are set to the selector of the method + * @dev The last 20 bytes of the `bytes32` value are set to the address of the handler contract + * @param isStatic Whether the method is static (`view`) or not + * @param selector The selector of the method + * @param handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface + */ function encodeWithSelector(bool isStatic, bytes4 selector, address handler) internal pure returns (bytes32 data) { data = bytes32(uint256(uint160(handler)) | (isStatic ? 0 : (1 << 248)) | (uint256(uint32(selector)) << 216)); } @@ -32,6 +41,13 @@ library MarshalLib { } } + /** + * Given a `bytes32` value, decode it into a method handler and return it + * @param data The packed data to decode + * @return isStatic Whether the method is static (`view`) or not + * @return selector The selector of the method + * @return handler The address of the handler contract implementing the `IFallbackMethod` or `IStaticFallbackMethod` interface + */ function decodeWithSelector(bytes32 data) internal pure returns (bool isStatic, bytes4 selector, address handler) { // solhint-disable-next-line no-inline-assembly assembly {