diff --git a/slither/core/cfg/scope.py b/slither/core/cfg/scope.py index d3ac4e8368..3f9959851a 100644 --- a/slither/core/cfg/scope.py +++ b/slither/core/cfg/scope.py @@ -7,8 +7,10 @@ # pylint: disable=too-few-public-methods class Scope: - def __init__(self, is_checked: bool, is_yul: bool, scope: Union["Scope", "Function"]) -> None: + def __init__( + self, is_checked: bool, is_yul: bool, parent_scope: Union["Scope", "Function"] + ) -> None: self.nodes: List["Node"] = [] self.is_checked = is_checked self.is_yul = is_yul - self.father = scope + self.father = parent_scope diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index d853e2332d..35ca51aebe 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -354,7 +354,7 @@ def _new_yul_block( ################################################################################### ################################################################################### - def _parse_if(self, if_statement: Dict, node: NodeSolc) -> NodeSolc: + def _parse_if(self, if_statement: Dict, node: NodeSolc, scope: Scope) -> NodeSolc: # IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )? falseStatement = None @@ -362,21 +362,15 @@ def _parse_if(self, if_statement: Dict, node: NodeSolc) -> NodeSolc: condition = if_statement["condition"] # Note: check if the expression could be directly # parsed here - condition_node = self._new_node( - NodeType.IF, condition["src"], node.underlying_node.scope - ) + condition_node = self._new_node(NodeType.IF, condition["src"], scope) condition_node.add_unparsed_expression(condition) link_underlying_nodes(node, condition_node) - true_scope = Scope( - node.underlying_node.scope.is_checked, False, node.underlying_node.scope - ) + true_scope = Scope(scope.is_checked, False, scope) trueStatement = self._parse_statement( if_statement["trueBody"], condition_node, true_scope ) if "falseBody" in if_statement and if_statement["falseBody"]: - false_scope = Scope( - node.underlying_node.scope.is_checked, False, node.underlying_node.scope - ) + false_scope = Scope(scope.is_checked, False, scope) falseStatement = self._parse_statement( if_statement["falseBody"], condition_node, false_scope ) @@ -385,22 +379,16 @@ def _parse_if(self, if_statement: Dict, node: NodeSolc) -> NodeSolc: condition = children[0] # Note: check if the expression could be directly # parsed here - condition_node = self._new_node( - NodeType.IF, condition["src"], node.underlying_node.scope - ) + condition_node = self._new_node(NodeType.IF, condition["src"], scope) condition_node.add_unparsed_expression(condition) link_underlying_nodes(node, condition_node) - true_scope = Scope( - node.underlying_node.scope.is_checked, False, node.underlying_node.scope - ) + true_scope = Scope(scope.is_checked, False, scope) trueStatement = self._parse_statement(children[1], condition_node, true_scope) if len(children) == 3: - false_scope = Scope( - node.underlying_node.scope.is_checked, False, node.underlying_node.scope - ) + false_scope = Scope(scope.is_checked, False, scope) falseStatement = self._parse_statement(children[2], condition_node, false_scope) - endIf_node = self._new_node(NodeType.ENDIF, if_statement["src"], node.underlying_node.scope) + endIf_node = self._new_node(NodeType.ENDIF, if_statement["src"], scope) link_underlying_nodes(trueStatement, endIf_node) if falseStatement: @@ -409,32 +397,26 @@ def _parse_if(self, if_statement: Dict, node: NodeSolc) -> NodeSolc: link_underlying_nodes(condition_node, endIf_node) return endIf_node - def _parse_while(self, whilte_statement: Dict, node: NodeSolc) -> NodeSolc: + def _parse_while(self, whilte_statement: Dict, node: NodeSolc, scope: Scope) -> NodeSolc: # WhileStatement = 'while' '(' Expression ')' Statement - node_startWhile = self._new_node( - NodeType.STARTLOOP, whilte_statement["src"], node.underlying_node.scope - ) + node_startWhile = self._new_node(NodeType.STARTLOOP, whilte_statement["src"], scope) - body_scope = Scope(node.underlying_node.scope.is_checked, False, node.underlying_node.scope) + body_scope = Scope(scope.is_checked, False, scope) if self.is_compact_ast: node_condition = self._new_node( - NodeType.IFLOOP, whilte_statement["condition"]["src"], node.underlying_node.scope + NodeType.IFLOOP, whilte_statement["condition"]["src"], scope ) node_condition.add_unparsed_expression(whilte_statement["condition"]) statement = self._parse_statement(whilte_statement["body"], node_condition, body_scope) else: children = whilte_statement[self.get_children("children")] expression = children[0] - node_condition = self._new_node( - NodeType.IFLOOP, expression["src"], node.underlying_node.scope - ) + node_condition = self._new_node(NodeType.IFLOOP, expression["src"], scope) node_condition.add_unparsed_expression(expression) statement = self._parse_statement(children[1], node_condition, body_scope) - node_endWhile = self._new_node( - NodeType.ENDLOOP, whilte_statement["src"], node.underlying_node.scope - ) + node_endWhile = self._new_node(NodeType.ENDLOOP, whilte_statement["src"], scope) link_underlying_nodes(node, node_startWhile) link_underlying_nodes(node_startWhile, node_condition) @@ -562,7 +544,7 @@ def has_hint(key): return pre, cond, post, body - def _parse_for(self, statement: Dict, node: NodeSolc) -> NodeSolc: + def _parse_for(self, statement: Dict, node: NodeSolc, scope: Scope) -> NodeSolc: # ForStatement = 'for' '(' (SimpleStatement)? ';' (Expression)? ';' (ExpressionStatement)? ')' Statement if self.is_compact_ast: @@ -570,17 +552,13 @@ def _parse_for(self, statement: Dict, node: NodeSolc) -> NodeSolc: else: pre, cond, post, body = self._parse_for_legacy_ast(statement) - node_startLoop = self._new_node( - NodeType.STARTLOOP, statement["src"], node.underlying_node.scope - ) - node_endLoop = self._new_node( - NodeType.ENDLOOP, statement["src"], node.underlying_node.scope - ) + node_startLoop = self._new_node(NodeType.STARTLOOP, statement["src"], scope) + node_endLoop = self._new_node(NodeType.ENDLOOP, statement["src"], scope) - last_scope = node.underlying_node.scope + last_scope = scope if pre: - pre_scope = Scope(node.underlying_node.scope.is_checked, False, last_scope) + pre_scope = Scope(scope.is_checked, False, last_scope) last_scope = pre_scope node_init_expression = self._parse_statement(pre, node, pre_scope) link_underlying_nodes(node_init_expression, node_startLoop) @@ -588,7 +566,7 @@ def _parse_for(self, statement: Dict, node: NodeSolc) -> NodeSolc: link_underlying_nodes(node, node_startLoop) if cond: - cond_scope = Scope(node.underlying_node.scope.is_checked, False, last_scope) + cond_scope = Scope(scope.is_checked, False, last_scope) last_scope = cond_scope node_condition = self._new_node(NodeType.IFLOOP, cond["src"], cond_scope) node_condition.add_unparsed_expression(cond) @@ -599,7 +577,7 @@ def _parse_for(self, statement: Dict, node: NodeSolc) -> NodeSolc: node_condition = None node_beforeBody = node_startLoop - body_scope = Scope(node.underlying_node.scope.is_checked, False, last_scope) + body_scope = Scope(scope.is_checked, False, last_scope) last_scope = body_scope node_body = self._parse_statement(body, node_beforeBody, body_scope) @@ -619,14 +597,10 @@ def _parse_for(self, statement: Dict, node: NodeSolc) -> NodeSolc: return node_endLoop - def _parse_dowhile(self, do_while_statement: Dict, node: NodeSolc) -> NodeSolc: + def _parse_dowhile(self, do_while_statement: Dict, node: NodeSolc, scope: Scope) -> NodeSolc: - node_startDoWhile = self._new_node( - NodeType.STARTLOOP, do_while_statement["src"], node.underlying_node.scope - ) - condition_scope = Scope( - node.underlying_node.scope.is_checked, False, node.underlying_node.scope - ) + node_startDoWhile = self._new_node(NodeType.STARTLOOP, do_while_statement["src"], scope) + condition_scope = Scope(scope.is_checked, False, scope) if self.is_compact_ast: node_condition = self._new_node( @@ -644,7 +618,7 @@ def _parse_dowhile(self, do_while_statement: Dict, node: NodeSolc) -> NodeSolc: node_condition.add_unparsed_expression(expression) statement = self._parse_statement(children[1], node_condition, condition_scope) - body_scope = Scope(node.underlying_node.scope.is_checked, False, condition_scope) + body_scope = Scope(scope.is_checked, False, condition_scope) node_endDoWhile = self._new_node(NodeType.ENDLOOP, do_while_statement["src"], body_scope) link_underlying_nodes(node, node_startDoWhile) @@ -724,14 +698,12 @@ def _construct_try_expression(self, externalCall: Dict, parameters_list: Dict) - return ret - def _parse_try_catch(self, statement: Dict, node: NodeSolc) -> NodeSolc: + def _parse_try_catch(self, statement: Dict, node: NodeSolc, scope: Scope) -> NodeSolc: externalCall = statement.get("externalCall", None) if externalCall is None: raise ParsingError(f"Try/Catch not correctly parsed by Slither {statement}") - catch_scope = Scope( - node.underlying_node.scope.is_checked, False, node.underlying_node.scope - ) + catch_scope = Scope(scope.is_checked, False, scope) new_node = self._new_node(NodeType.TRY, statement["src"], catch_scope) clauses = statement.get("clauses", []) # the first clause is the try scope @@ -748,18 +720,20 @@ def _parse_try_catch(self, statement: Dict, node: NodeSolc) -> NodeSolc: # we set the parameters (e.g. data in this case. catch(string memory data) ...) # to be initialized so they are not reported by the uninitialized-local-variables detector if index >= 1: - self._parse_catch(clause, node, True) + self._parse_catch(clause, node, catch_scope, True) else: # the parameters for the try scope were already added in _construct_try_expression - self._parse_catch(clause, node, False) + self._parse_catch(clause, node, catch_scope, False) return node - def _parse_catch(self, statement: Dict, node: NodeSolc, add_param: bool) -> NodeSolc: + def _parse_catch( + self, statement: Dict, node: NodeSolc, scope: Scope, add_param: bool + ) -> NodeSolc: block = statement.get("block", None) if block is None: raise ParsingError(f"Catch not correctly parsed by Slither {statement}") - try_scope = Scope(node.underlying_node.scope.is_checked, False, node.underlying_node.scope) + try_scope = Scope(scope.is_checked, False, scope) try_node = self._new_node(NodeType.CATCH, statement["src"], try_scope) link_underlying_nodes(node, try_node) @@ -777,7 +751,7 @@ def _parse_catch(self, statement: Dict, node: NodeSolc, add_param: bool) -> Node return self._parse_statement(block, try_node, try_scope) - def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSolc: + def _parse_variable_definition(self, statement: Dict, node: NodeSolc, scope: Scope) -> NodeSolc: try: local_var = LocalVariable() local_var.set_function(self._function) @@ -787,9 +761,7 @@ def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSol self._add_local_variable(local_var_parser) # local_var.analyze(self) - new_node = self._new_node( - NodeType.VARIABLE, statement["src"], node.underlying_node.scope - ) + new_node = self._new_node(NodeType.VARIABLE, statement["src"], scope) new_node.underlying_node.add_variable_declaration(local_var) link_underlying_nodes(node, new_node) return new_node @@ -819,7 +791,7 @@ def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSol "declarations": [variable], "initialValue": init, } - new_node = self._parse_variable_definition(new_statement, new_node) + new_node = self._parse_variable_definition(new_statement, new_node, scope) else: # If we have @@ -841,7 +813,7 @@ def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSol variables.append(variable) new_node = self._parse_variable_definition_init_tuple( - new_statement, i, new_node + new_statement, i, new_node, scope ) i = i + 1 @@ -873,9 +845,7 @@ def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSol "typeDescriptions": {"typeString": "tuple()"}, } node = new_node - new_node = self._new_node( - NodeType.EXPRESSION, statement["src"], node.underlying_node.scope - ) + new_node = self._new_node(NodeType.EXPRESSION, statement["src"], scope) new_node.add_unparsed_expression(expression) link_underlying_nodes(node, new_node) @@ -906,7 +876,7 @@ def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSol self.get_children("children"): [variable, init], } - new_node = self._parse_variable_definition(new_statement, new_node) + new_node = self._parse_variable_definition(new_statement, new_node, scope) else: # If we have # var (a, b) = f() @@ -925,7 +895,7 @@ def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSol variables.append(variable) new_node = self._parse_variable_definition_init_tuple( - new_statement, i, new_node + new_statement, i, new_node, scope ) i = i + 1 var_identifiers = [] @@ -955,16 +925,14 @@ def _parse_variable_definition(self, statement: Dict, node: NodeSolc) -> NodeSol ], } node = new_node - new_node = self._new_node( - NodeType.EXPRESSION, statement["src"], node.underlying_node.scope - ) + new_node = self._new_node(NodeType.EXPRESSION, statement["src"], scope) new_node.add_unparsed_expression(expression) link_underlying_nodes(node, new_node) return new_node def _parse_variable_definition_init_tuple( - self, statement: Dict, index: int, node: NodeSolc + self, statement: Dict, index: int, node: NodeSolc, scope ) -> NodeSolc: local_var = LocalVariableInitFromTuple() local_var.set_function(self._function) @@ -974,7 +942,7 @@ def _parse_variable_definition_init_tuple( self._add_local_variable(local_var_parser) - new_node = self._new_node(NodeType.VARIABLE, statement["src"], node.underlying_node.scope) + new_node = self._new_node(NodeType.VARIABLE, statement["src"], scope) new_node.underlying_node.add_variable_declaration(local_var) link_underlying_nodes(node, new_node) return new_node @@ -995,15 +963,15 @@ def _parse_statement( name = statement[self.get_key()] # SimpleStatement = VariableDefinition | ExpressionStatement if name == "IfStatement": - node = self._parse_if(statement, node) + node = self._parse_if(statement, node, scope) elif name == "WhileStatement": - node = self._parse_while(statement, node) + node = self._parse_while(statement, node, scope) elif name == "ForStatement": - node = self._parse_for(statement, node) + node = self._parse_for(statement, node, scope) elif name == "Block": - node = self._parse_block(statement, node) + node = self._parse_block(statement, node, scope) elif name == "UncheckedBlock": - node = self._parse_unchecked_block(statement, node) + node = self._parse_unchecked_block(statement, node, scope) elif name == "InlineAssembly": # Added with solc 0.6 - the yul code is an AST if "AST" in statement and not self.compilation_unit.core.skip_assembly: @@ -1025,7 +993,7 @@ def _parse_statement( link_underlying_nodes(node, asm_node) node = asm_node elif name == "DoWhileStatement": - node = self._parse_dowhile(statement, node) + node = self._parse_dowhile(statement, node, scope) # For Continue / Break / Return / Throw # The is fixed later elif name == "Continue": @@ -1066,7 +1034,7 @@ def _parse_statement( link_underlying_nodes(node, new_node) node = new_node elif name in ["VariableDefinitionStatement", "VariableDeclarationStatement"]: - node = self._parse_variable_definition(statement, node) + node = self._parse_variable_definition(statement, node, scope) elif name == "ExpressionStatement": # assert len(statement[self.get_children('expression')]) == 1 # assert not 'attributes' in statement @@ -1080,7 +1048,7 @@ def _parse_statement( link_underlying_nodes(node, new_node) node = new_node elif name == "TryStatement": - node = self._parse_try_catch(statement, node) + node = self._parse_try_catch(statement, node, scope) # elif name == 'TryCatchClause': # self._parse_catch(statement, node) elif name == "RevertStatement": @@ -1097,7 +1065,7 @@ def _parse_statement( return node - def _parse_block(self, block: Dict, node: NodeSolc, check_arithmetic: bool = False) -> NodeSolc: + def _parse_block(self, block: Dict, node: NodeSolc, scope: Scope) -> NodeSolc: """ Return: Node @@ -1109,13 +1077,12 @@ def _parse_block(self, block: Dict, node: NodeSolc, check_arithmetic: bool = Fal else: statements = block[self.get_children("children")] - check_arithmetic = check_arithmetic | node.underlying_node.scope.is_checked - new_scope = Scope(check_arithmetic, False, node.underlying_node.scope) + new_scope = Scope(scope.is_checked, False, scope) for statement in statements: node = self._parse_statement(statement, node, new_scope) return node - def _parse_unchecked_block(self, block: Dict, node: NodeSolc): + def _parse_unchecked_block(self, block: Dict, node: NodeSolc, scope): """ Return: Node @@ -1127,7 +1094,8 @@ def _parse_unchecked_block(self, block: Dict, node: NodeSolc): else: statements = block[self.get_children("children")] - new_scope = Scope(False, False, node.underlying_node.scope) + new_scope = Scope(False, False, scope) + for statement in statements: node = self._parse_statement(statement, node, new_scope) return node @@ -1148,8 +1116,7 @@ def _parse_cfg(self, cfg: Dict) -> None: self._function.is_empty = True else: self._function.is_empty = False - check_arithmetic = self.compilation_unit.solc_version >= "0.8.0" - self._parse_block(cfg, node, check_arithmetic=check_arithmetic) + self._parse_block(cfg, node, self.underlying_function) self._remove_incorrect_edges() self._remove_alone_endif() diff --git a/tests/unit/core/test_arithmetic.py b/tests/unit/core/test_arithmetic.py index 6de63d7674..760b6d3976 100644 --- a/tests/unit/core/test_arithmetic.py +++ b/tests/unit/core/test_arithmetic.py @@ -2,7 +2,8 @@ from slither import Slither from slither.utils.arithmetic import unchecked_arithemtic_usage - +from slither.slithir.operations import Binary, Unary, Assignment +from slither.slithir.variables.temporary import TemporaryVariable TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" / "arithmetic_usage" @@ -14,3 +15,21 @@ def test_arithmetic_usage(solc_binary_path) -> None: assert { f.source_mapping.content_hash for f in unchecked_arithemtic_usage(slither.contracts[0]) } == {"2b4bc73cf59d486dd9043e840b5028b679354dd9", "e4ecd4d0fda7e762d29aceb8425f2c5d4d0bf962"} + + +def test_scope_is_checked(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.15") + slither = Slither(Path(TEST_DATA_DIR, "unchecked_scope.sol").as_posix(), solc=solc_path) + func = slither.get_contract_from_name("TestScope")[0].get_function_from_full_name( + "scope(uint256)" + ) + bin_op_is_checked = {} + for node in func.nodes: + for op in node.irs: + if isinstance(op, (Binary, Unary)): + bin_op_is_checked[op.lvalue] = op.node.scope.is_checked + if isinstance(op, Assignment) and isinstance(op.rvalue, TemporaryVariable): + if op.lvalue.name.startswith("checked"): + assert bin_op_is_checked[op.rvalue] is True + else: + assert bin_op_is_checked[op.rvalue] is False diff --git a/tests/unit/core/test_data/arithmetic_usage/unchecked_scope.sol b/tests/unit/core/test_data/arithmetic_usage/unchecked_scope.sol new file mode 100644 index 0000000000..f0cee1f109 --- /dev/null +++ b/tests/unit/core/test_data/arithmetic_usage/unchecked_scope.sol @@ -0,0 +1,17 @@ +contract TestScope { + function scope(uint256 x) public { + uint checked1 = x - x; + unchecked { + uint unchecked1 = x - x; + if (true) { + uint unchecked2 = x - x; + + } + for (uint i = 0; i < 10; i++) { + uint unchecked3 = x - x; + + } + } + uint checked2 = x - x; + } +} \ No newline at end of file