diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py new file mode 100644 index 0000000000..951bf4702c --- /dev/null +++ b/slither/utils/code_generation.py @@ -0,0 +1,104 @@ +# Functions for generating Solidity code +from typing import TYPE_CHECKING, Optional + +from slither.utils.type import convert_type_for_solidity_signature_to_string + +if TYPE_CHECKING: + from slither.core.declarations import FunctionContract, Structure, Contract + + +def generate_interface(contract: "Contract") -> str: + """ + Generates code for a Solidity interface to the contract. + Args: + contract: A Contract object + + Returns: + A string with the code for an interface, with function stubs for all public or external functions and + state variables, as well as any events, custom errors and/or structs declared in the contract. + """ + interface = f"interface I{contract.name} {{\n" + for event in contract.events: + name, args = event.signature + interface += f" event {name}({', '.join(args)});\n" + for error in contract.custom_errors: + args = [ + convert_type_for_solidity_signature_to_string(arg.type) + .replace("(", "") + .replace(")", "") + for arg in error.parameters + ] + interface += f" error {error.name}({', '.join(args)});\n" + for enum in contract.enums: + interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" + for struct in contract.structures: + interface += generate_struct_interface_str(struct) + for var in contract.state_variables_entry_points: + interface += f" function {var.signature_str.replace('returns', 'external returns ')};\n" + for func in contract.functions_entry_points: + if func.is_constructor or func.is_fallback or func.is_receive: + continue + interface += f" function {generate_interface_function_signature(func)};\n" + interface += "}\n\n" + return interface + + +def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]: + """ + Generates a string of the form: + func_name(type1,type2) external {payable/view/pure} returns (type3) + + Args: + func: A FunctionContract object + + Returns: + The function interface as a str (contains the return values). + Returns None if the function is private or internal, or is a constructor/fallback/receive. + """ + + name, parameters, return_vars = func.signature + if ( + func not in func.contract.functions_entry_points + or func.is_constructor + or func.is_fallback + or func.is_receive + ): + return None + view = " view" if func.view else "" + pure = " pure" if func.pure else "" + payable = " payable" if func.payable else "" + returns = [ + convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "") + for ret in func.returns + ] + parameters = [ + convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "") + for param in func.parameters + ] + _interface_signature_str = ( + name + "(" + ",".join(parameters) + ") external" + payable + pure + view + ) + if len(return_vars) > 0: + _interface_signature_str += " returns (" + ",".join(returns) + ")" + return _interface_signature_str + + +def generate_struct_interface_str(struct: "Structure") -> str: + """ + Generates code for a structure declaration in an interface of the form: + struct struct_name { + elem1_type elem1_name; + elem2_type elem2_name; + ... ... + } + Args: + struct: A Structure object + + Returns: + The structure declaration code as a string. + """ + definition = f" struct {struct.name} {{\n" + for elem in struct.elems_ordered: + definition += f" {elem.type} {elem.name};\n" + definition += " }\n" + return definition diff --git a/tests/code_generation/CodeGeneration.sol b/tests/code_generation/CodeGeneration.sol new file mode 100644 index 0000000000..c15017abd5 --- /dev/null +++ b/tests/code_generation/CodeGeneration.sol @@ -0,0 +1,56 @@ +pragma solidity ^0.8.4; +interface I { + enum SomeEnum { ONE, TWO, THREE } + error ErrorWithEnum(SomeEnum e); +} + +contract TestContract is I { + uint public stateA; + uint private stateB; + address public immutable owner = msg.sender; + mapping(address => mapping(uint => St)) public structs; + + event NoParams(); + event Anonymous() anonymous; + event OneParam(address addr); + event OneParamIndexed(address indexed addr); + + error ErrorSimple(); + error ErrorWithArgs(uint, uint); + error ErrorWithStruct(St s); + + struct St{ + uint v; + } + + function err0() public { + revert ErrorSimple(); + } + function err1() public { + St memory s; + revert ErrorWithStruct(s); + } + function err2(uint a, uint b) public { + revert ErrorWithArgs(a, b); + revert ErrorWithArgs(uint(SomeEnum.ONE), uint(SomeEnum.ONE)); + } + function err3() internal { + revert('test'); + } + function err4() private { + revert ErrorWithEnum(SomeEnum.ONE); + } + + function newSt(uint x) public returns (St memory) { + St memory st; + st.v = x; + structs[msg.sender][x] = st; + return st; + } + function getSt(uint x) public view returns (St memory) { + return structs[msg.sender][x]; + } + function removeSt(St memory st) public { + delete structs[msg.sender][st.v]; + } +} \ No newline at end of file diff --git a/tests/code_generation/TEST_generated_code.sol b/tests/code_generation/TEST_generated_code.sol new file mode 100644 index 0000000000..62e08bd74c --- /dev/null +++ b/tests/code_generation/TEST_generated_code.sol @@ -0,0 +1,24 @@ +interface ITestContract { + event NoParams(); + event Anonymous(); + event OneParam(address); + event OneParamIndexed(address); + error ErrorWithEnum(uint8); + error ErrorSimple(); + error ErrorWithArgs(uint256, uint256); + error ErrorWithStruct(uint256); + enum SomeEnum { ONE, TWO, THREE } + struct St { + uint256 v; + } + function stateA() external returns (uint256); + function owner() external returns (address); + function structs(address,uint256) external returns (uint256); + function err0() external; + function err1() external; + function err2(uint256,uint256) external; + function newSt(uint256) external returns (uint256); + function getSt(uint256) external view returns (uint256); + function removeSt(uint256) external; +} + diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py new file mode 100644 index 0000000000..13d1c8fb0f --- /dev/null +++ b/tests/test_code_generation.py @@ -0,0 +1,25 @@ +import os + +from solc_select import solc_select + +from slither import Slither +from slither.utils.code_generation import ( + generate_interface, +) + +SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +CODE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "code_generation") + + +def test_interface_generation() -> None: + solc_select.switch_global_version("0.8.4", always_install=True) + + sl = Slither(os.path.join(CODE_TEST_ROOT, "CodeGeneration.sol")) + + actual = generate_interface(sl.get_contract_from_name("TestContract")[0]) + expected_path = os.path.join(CODE_TEST_ROOT, "TEST_generated_code.sol") + + with open(expected_path, "r", encoding="utf-8") as file: + expected = file.read() + + assert actual == expected