Skip to content

Commit

Permalink
Merge pull request #1730 from webthethird/dev-generate-interface-code
Browse files Browse the repository at this point in the history
Generate interface code in new `slither.utils.code_generation`
  • Loading branch information
montyly authored Mar 22, 2023
2 parents 346d3b6 + 6d68571 commit 3c55228
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 0 deletions.
104 changes: 104 additions & 0 deletions slither/utils/code_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Functions for generating Solidity code
from typing import TYPE_CHECKING, Optional

from slither.utils.type import convert_type_for_solidity_signature_to_string

if TYPE_CHECKING:
from slither.core.declarations import FunctionContract, Structure, Contract


def generate_interface(contract: "Contract") -> str:
"""
Generates code for a Solidity interface to the contract.
Args:
contract: A Contract object
Returns:
A string with the code for an interface, with function stubs for all public or external functions and
state variables, as well as any events, custom errors and/or structs declared in the contract.
"""
interface = f"interface I{contract.name} {{\n"
for event in contract.events:
name, args = event.signature
interface += f" event {name}({', '.join(args)});\n"
for error in contract.custom_errors:
args = [
convert_type_for_solidity_signature_to_string(arg.type)
.replace("(", "")
.replace(")", "")
for arg in error.parameters
]
interface += f" error {error.name}({', '.join(args)});\n"
for enum in contract.enums:
interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n"
for struct in contract.structures:
interface += generate_struct_interface_str(struct)
for var in contract.state_variables_entry_points:
interface += f" function {var.signature_str.replace('returns', 'external returns ')};\n"
for func in contract.functions_entry_points:
if func.is_constructor or func.is_fallback or func.is_receive:
continue
interface += f" function {generate_interface_function_signature(func)};\n"
interface += "}\n\n"
return interface


def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]:
"""
Generates a string of the form:
func_name(type1,type2) external {payable/view/pure} returns (type3)
Args:
func: A FunctionContract object
Returns:
The function interface as a str (contains the return values).
Returns None if the function is private or internal, or is a constructor/fallback/receive.
"""

name, parameters, return_vars = func.signature
if (
func not in func.contract.functions_entry_points
or func.is_constructor
or func.is_fallback
or func.is_receive
):
return None
view = " view" if func.view else ""
pure = " pure" if func.pure else ""
payable = " payable" if func.payable else ""
returns = [
convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "")
for ret in func.returns
]
parameters = [
convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "")
for param in func.parameters
]
_interface_signature_str = (
name + "(" + ",".join(parameters) + ") external" + payable + pure + view
)
if len(return_vars) > 0:
_interface_signature_str += " returns (" + ",".join(returns) + ")"
return _interface_signature_str


def generate_struct_interface_str(struct: "Structure") -> str:
"""
Generates code for a structure declaration in an interface of the form:
struct struct_name {
elem1_type elem1_name;
elem2_type elem2_name;
... ...
}
Args:
struct: A Structure object
Returns:
The structure declaration code as a string.
"""
definition = f" struct {struct.name} {{\n"
for elem in struct.elems_ordered:
definition += f" {elem.type} {elem.name};\n"
definition += " }\n"
return definition
56 changes: 56 additions & 0 deletions tests/code_generation/CodeGeneration.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
pragma solidity ^0.8.4;
interface I {
enum SomeEnum { ONE, TWO, THREE }
error ErrorWithEnum(SomeEnum e);
}

contract TestContract is I {
uint public stateA;
uint private stateB;
address public immutable owner = msg.sender;
mapping(address => mapping(uint => St)) public structs;

event NoParams();
event Anonymous() anonymous;
event OneParam(address addr);
event OneParamIndexed(address indexed addr);

error ErrorSimple();
error ErrorWithArgs(uint, uint);
error ErrorWithStruct(St s);

struct St{
uint v;
}

function err0() public {
revert ErrorSimple();
}
function err1() public {
St memory s;
revert ErrorWithStruct(s);
}
function err2(uint a, uint b) public {
revert ErrorWithArgs(a, b);
revert ErrorWithArgs(uint(SomeEnum.ONE), uint(SomeEnum.ONE));
}
function err3() internal {
revert('test');
}
function err4() private {
revert ErrorWithEnum(SomeEnum.ONE);
}

function newSt(uint x) public returns (St memory) {
St memory st;
st.v = x;
structs[msg.sender][x] = st;
return st;
}
function getSt(uint x) public view returns (St memory) {
return structs[msg.sender][x];
}
function removeSt(St memory st) public {
delete structs[msg.sender][st.v];
}
}
24 changes: 24 additions & 0 deletions tests/code_generation/TEST_generated_code.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
interface ITestContract {
event NoParams();
event Anonymous();
event OneParam(address);
event OneParamIndexed(address);
error ErrorWithEnum(uint8);
error ErrorSimple();
error ErrorWithArgs(uint256, uint256);
error ErrorWithStruct(uint256);
enum SomeEnum { ONE, TWO, THREE }
struct St {
uint256 v;
}
function stateA() external returns (uint256);
function owner() external returns (address);
function structs(address,uint256) external returns (uint256);
function err0() external;
function err1() external;
function err2(uint256,uint256) external;
function newSt(uint256) external returns (uint256);
function getSt(uint256) external view returns (uint256);
function removeSt(uint256) external;
}

25 changes: 25 additions & 0 deletions tests/test_code_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os

from solc_select import solc_select

from slither import Slither
from slither.utils.code_generation import (
generate_interface,
)

SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CODE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "code_generation")


def test_interface_generation() -> None:
solc_select.switch_global_version("0.8.4", always_install=True)

sl = Slither(os.path.join(CODE_TEST_ROOT, "CodeGeneration.sol"))

actual = generate_interface(sl.get_contract_from_name("TestContract")[0])
expected_path = os.path.join(CODE_TEST_ROOT, "TEST_generated_code.sol")

with open(expected_path, "r", encoding="utf-8") as file:
expected = file.read()

assert actual == expected

0 comments on commit 3c55228

Please sign in to comment.