-
Notifications
You must be signed in to change notification settings - Fork 996
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1730 from webthethird/dev-generate-interface-code
Generate interface code in new `slither.utils.code_generation`
- Loading branch information
Showing
4 changed files
with
209 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Functions for generating Solidity code | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from slither.utils.type import convert_type_for_solidity_signature_to_string | ||
|
||
if TYPE_CHECKING: | ||
from slither.core.declarations import FunctionContract, Structure, Contract | ||
|
||
|
||
def generate_interface(contract: "Contract") -> str: | ||
""" | ||
Generates code for a Solidity interface to the contract. | ||
Args: | ||
contract: A Contract object | ||
Returns: | ||
A string with the code for an interface, with function stubs for all public or external functions and | ||
state variables, as well as any events, custom errors and/or structs declared in the contract. | ||
""" | ||
interface = f"interface I{contract.name} {{\n" | ||
for event in contract.events: | ||
name, args = event.signature | ||
interface += f" event {name}({', '.join(args)});\n" | ||
for error in contract.custom_errors: | ||
args = [ | ||
convert_type_for_solidity_signature_to_string(arg.type) | ||
.replace("(", "") | ||
.replace(")", "") | ||
for arg in error.parameters | ||
] | ||
interface += f" error {error.name}({', '.join(args)});\n" | ||
for enum in contract.enums: | ||
interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" | ||
for struct in contract.structures: | ||
interface += generate_struct_interface_str(struct) | ||
for var in contract.state_variables_entry_points: | ||
interface += f" function {var.signature_str.replace('returns', 'external returns ')};\n" | ||
for func in contract.functions_entry_points: | ||
if func.is_constructor or func.is_fallback or func.is_receive: | ||
continue | ||
interface += f" function {generate_interface_function_signature(func)};\n" | ||
interface += "}\n\n" | ||
return interface | ||
|
||
|
||
def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]: | ||
""" | ||
Generates a string of the form: | ||
func_name(type1,type2) external {payable/view/pure} returns (type3) | ||
Args: | ||
func: A FunctionContract object | ||
Returns: | ||
The function interface as a str (contains the return values). | ||
Returns None if the function is private or internal, or is a constructor/fallback/receive. | ||
""" | ||
|
||
name, parameters, return_vars = func.signature | ||
if ( | ||
func not in func.contract.functions_entry_points | ||
or func.is_constructor | ||
or func.is_fallback | ||
or func.is_receive | ||
): | ||
return None | ||
view = " view" if func.view else "" | ||
pure = " pure" if func.pure else "" | ||
payable = " payable" if func.payable else "" | ||
returns = [ | ||
convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "") | ||
for ret in func.returns | ||
] | ||
parameters = [ | ||
convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "") | ||
for param in func.parameters | ||
] | ||
_interface_signature_str = ( | ||
name + "(" + ",".join(parameters) + ") external" + payable + pure + view | ||
) | ||
if len(return_vars) > 0: | ||
_interface_signature_str += " returns (" + ",".join(returns) + ")" | ||
return _interface_signature_str | ||
|
||
|
||
def generate_struct_interface_str(struct: "Structure") -> str: | ||
""" | ||
Generates code for a structure declaration in an interface of the form: | ||
struct struct_name { | ||
elem1_type elem1_name; | ||
elem2_type elem2_name; | ||
... ... | ||
} | ||
Args: | ||
struct: A Structure object | ||
Returns: | ||
The structure declaration code as a string. | ||
""" | ||
definition = f" struct {struct.name} {{\n" | ||
for elem in struct.elems_ordered: | ||
definition += f" {elem.type} {elem.name};\n" | ||
definition += " }\n" | ||
return definition |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
pragma solidity ^0.8.4; | ||
interface I { | ||
enum SomeEnum { ONE, TWO, THREE } | ||
error ErrorWithEnum(SomeEnum e); | ||
} | ||
|
||
contract TestContract is I { | ||
uint public stateA; | ||
uint private stateB; | ||
address public immutable owner = msg.sender; | ||
mapping(address => mapping(uint => St)) public structs; | ||
|
||
event NoParams(); | ||
event Anonymous() anonymous; | ||
event OneParam(address addr); | ||
event OneParamIndexed(address indexed addr); | ||
|
||
error ErrorSimple(); | ||
error ErrorWithArgs(uint, uint); | ||
error ErrorWithStruct(St s); | ||
|
||
struct St{ | ||
uint v; | ||
} | ||
|
||
function err0() public { | ||
revert ErrorSimple(); | ||
} | ||
function err1() public { | ||
St memory s; | ||
revert ErrorWithStruct(s); | ||
} | ||
function err2(uint a, uint b) public { | ||
revert ErrorWithArgs(a, b); | ||
revert ErrorWithArgs(uint(SomeEnum.ONE), uint(SomeEnum.ONE)); | ||
} | ||
function err3() internal { | ||
revert('test'); | ||
} | ||
function err4() private { | ||
revert ErrorWithEnum(SomeEnum.ONE); | ||
} | ||
|
||
function newSt(uint x) public returns (St memory) { | ||
St memory st; | ||
st.v = x; | ||
structs[msg.sender][x] = st; | ||
return st; | ||
} | ||
function getSt(uint x) public view returns (St memory) { | ||
return structs[msg.sender][x]; | ||
} | ||
function removeSt(St memory st) public { | ||
delete structs[msg.sender][st.v]; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
interface ITestContract { | ||
event NoParams(); | ||
event Anonymous(); | ||
event OneParam(address); | ||
event OneParamIndexed(address); | ||
error ErrorWithEnum(uint8); | ||
error ErrorSimple(); | ||
error ErrorWithArgs(uint256, uint256); | ||
error ErrorWithStruct(uint256); | ||
enum SomeEnum { ONE, TWO, THREE } | ||
struct St { | ||
uint256 v; | ||
} | ||
function stateA() external returns (uint256); | ||
function owner() external returns (address); | ||
function structs(address,uint256) external returns (uint256); | ||
function err0() external; | ||
function err1() external; | ||
function err2(uint256,uint256) external; | ||
function newSt(uint256) external returns (uint256); | ||
function getSt(uint256) external view returns (uint256); | ||
function removeSt(uint256) external; | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import os | ||
|
||
from solc_select import solc_select | ||
|
||
from slither import Slither | ||
from slither.utils.code_generation import ( | ||
generate_interface, | ||
) | ||
|
||
SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
CODE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "code_generation") | ||
|
||
|
||
def test_interface_generation() -> None: | ||
solc_select.switch_global_version("0.8.4", always_install=True) | ||
|
||
sl = Slither(os.path.join(CODE_TEST_ROOT, "CodeGeneration.sol")) | ||
|
||
actual = generate_interface(sl.get_contract_from_name("TestContract")[0]) | ||
expected_path = os.path.join(CODE_TEST_ROOT, "TEST_generated_code.sol") | ||
|
||
with open(expected_path, "r", encoding="utf-8") as file: | ||
expected = file.read() | ||
|
||
assert actual == expected |