Skip to content

Commit

Permalink
Merge pull request #1574 from crytic/contract-error-selector
Browse files Browse the repository at this point in the history
resolve error referenced as member of contract
  • Loading branch information
montyly authored Jan 10, 2023
2 parents 1c63aa1 + e21d6eb commit b8ff0b0
Show file tree
Hide file tree
Showing 28 changed files with 163 additions and 8 deletions.
25 changes: 21 additions & 4 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def integrate_value_gas(result):
###################################################################################


def propagate_type_and_convert_call(result, node):
def propagate_type_and_convert_call(result: List[Operation], node: "Node") -> List[Operation]:
"""
Propagate the types variables and convert tmp call to real call operation
"""
Expand Down Expand Up @@ -664,7 +664,24 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
if ir.variable_right == "selector" and isinstance(ir.variable_left, (CustomError)):
assignment = Assignment(
ir.lvalue,
Constant(str(get_function_id(ir.variable_left.solidity_signature))),
Constant(
str(get_function_id(ir.variable_left.solidity_signature)),
ElementaryType("bytes4"),
),
ElementaryType("bytes4"),
)
assignment.set_expression(ir.expression)
assignment.set_node(ir.node)
assignment.lvalue.set_type(ElementaryType("bytes4"))
return assignment

if isinstance(ir.variable_right, (CustomError)):
assignment = Assignment(
ir.lvalue,
Constant(
str(get_function_id(ir.variable_left.solidity_signature)),
ElementaryType("bytes4"),
),
ElementaryType("bytes4"),
)
assignment.set_expression(ir.expression)
Expand Down Expand Up @@ -736,7 +753,7 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
if f:
ir.lvalue.set_type(f)
else:
# Allow propgation for variable access through contract's nale
# Allow propgation for variable access through contract's name
# like Base_contract.my_variable
v = next(
(
Expand Down Expand Up @@ -1819,7 +1836,7 @@ def _find_source_mapping_references(irs: List[Operation]):
###################################################################################


def apply_ir_heuristics(irs, node):
def apply_ir_heuristics(irs: List[Operation], node: "Node"):
"""
Apply a set of heuristic to improve slithIR
"""
Expand Down
12 changes: 8 additions & 4 deletions slither/visitors/slithir/expression_to_slithir.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,14 +460,18 @@ def _post_member_access(self, expression):
set_val(expression, expr)
return

# Early lookup to detect user defined types from other contracts definitions
# contract A { type MyInt is int}
# contract B { function f() public{ A.MyInt test = A.MyInt.wrap(1);}}
# The logic is handled by _post_call_expression
if isinstance(expr, Contract):
# Early lookup to detect user defined types from other contracts definitions
# contract A { type MyInt is int}
# contract B { function f() public{ A.MyInt test = A.MyInt.wrap(1);}}
# The logic is handled by _post_call_expression
if expression.member_name in expr.file_scope.user_defined_types:
set_val(expression, expr.file_scope.user_defined_types[expression.member_name])
return
# Lookup errors referred to as member of contract e.g. Test.myError.selector
if expression.member_name in expr.custom_errors_as_dict:
set_val(expression, expr.custom_errors_as_dict[expression.member_name])
return

val = ReferenceVariable(self._node)
member = Member(expr, Constant(expression.member_name), val)
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
13 changes: 13 additions & 0 deletions tests/ast-parsing/custom-error-selector.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
contract Test {
error myError();
}

interface VM {
function expectRevert(bytes4) external;
function expectRevert(bytes calldata) external;
}
contract A {
function b(address c) public {
VM(c).expectRevert(Test.myError.selector);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"Test": {},
"VM": {
"expectRevert(bytes4)": "digraph{\n}\n",
"expectRevert(bytes)": "digraph{\n}\n"
},
"A": {
"b(address)": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n"
}
}
1 change: 1 addition & 0 deletions tests/test_ast_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def make_version(minor: int, patch_min: int, patch_max: int) -> List[str]:
ALL_VERSIONS,
),
Test("custom_error-0.8.4.sol", make_version(8, 4, 15)),
Test("custom-error-selector.sol", make_version(8, 4, 15)),
Test(
"top-level-0.4.0.sol",
VERSIONS_04 + VERSIONS_05 + VERSIONS_06 + ["0.7.0"],
Expand Down

0 comments on commit b8ff0b0

Please sign in to comment.