diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 0958f59450..aa02597d76 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -9,9 +9,8 @@ from crytic_compile.platform import Type as PlatformType from slither.core.cfg.scope import Scope -from slither.core.solidity_types.type import Type from slither.core.source_mapping.source_mapping import SourceMapping - +from slither.utils.using_for import USING_FOR, _merge_using_for from slither.core.declarations.function import Function, FunctionType, FunctionLanguage from slither.utils.erc import ( ERC20_signatures, @@ -50,9 +49,6 @@ LOGGER = logging.getLogger("Contract") -USING_FOR_KEY = Union[str, Type] -USING_FOR_ITEM = List[Union[Type, Function]] - class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ @@ -87,8 +83,8 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope self._type_aliases: Dict[str, "TypeAliasContract"] = {} # The only str is "*" - self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {} - self._using_for_complete: Optional[Dict[USING_FOR_KEY, USING_FOR_ITEM]] = None + self._using_for: USING_FOR = {} + self._using_for_complete: Optional[USING_FOR] = None self._kind: Optional[str] = None self._is_interface: bool = False self._is_library: bool = False @@ -333,24 +329,15 @@ def events_as_dict(self) -> Dict[str, "EventContract"]: ################################################################################### @property - def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: + def using_for(self) -> USING_FOR: return self._using_for @property - def using_for_complete(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: + def using_for_complete(self) -> USING_FOR: """ - Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive + USING_FOR: Dict of merged local using for directive with top level directive """ - def _merge_using_for( - uf1: Dict[USING_FOR_KEY, USING_FOR_ITEM], uf2: Dict[USING_FOR_KEY, USING_FOR_ITEM] - ) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: - result = {**uf1, **uf2} - for key, value in result.items(): - if key in uf1 and key in uf2: - result[key] = value + uf1[key] - return result - if self._using_for_complete is None: result = self.using_for top_level_using_for = self.file_scope.using_for_directives diff --git a/slither/core/declarations/function_top_level.py b/slither/core/declarations/function_top_level.py index 407a8d0458..e0dcd25579 100644 --- a/slither/core/declarations/function_top_level.py +++ b/slither/core/declarations/function_top_level.py @@ -1,10 +1,11 @@ """ Function module """ -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, TYPE_CHECKING, Optional from slither.core.declarations import Function from slither.core.declarations.top_level import TopLevel +from slither.utils.using_for import USING_FOR, _merge_using_for if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit @@ -16,11 +17,25 @@ class FunctionTopLevel(Function, TopLevel): def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self._scope: "FileScope" = scope + self._using_for_complete: Optional[USING_FOR] = None @property def file_scope(self) -> "FileScope": return self._scope + @property + def using_for_complete(self) -> USING_FOR: + """ + USING_FOR: Dict of top level directive + """ + + if self._using_for_complete is None: + result = {} + for uftl in self.file_scope.using_for_directives: + result = _merge_using_for(result, uftl.using_for) + self._using_for_complete = result + return self._using_for_complete + @property def canonical_name(self) -> str: """ diff --git a/slither/core/declarations/using_for_top_level.py b/slither/core/declarations/using_for_top_level.py index edf846a5b1..ca73777e55 100644 --- a/slither/core/declarations/using_for_top_level.py +++ b/slither/core/declarations/using_for_top_level.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING, List, Dict, Union -from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM from slither.core.solidity_types.type import Type from slither.core.declarations.top_level import TopLevel +from slither.utils.using_for import USING_FOR if TYPE_CHECKING: from slither.core.scope.scope import FileScope @@ -15,5 +15,5 @@ def __init__(self, scope: "FileScope") -> None: self.file_scope: "FileScope" = scope @property - def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: + def using_for(self) -> USING_FOR: return self._using_for diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index e9d446e141..fcb6c3afa7 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, List, TYPE_CHECKING, Union, Optional, Dict +from typing import Any, List, TYPE_CHECKING, Union, Optional # pylint: disable= too-many-lines,import-outside-toplevel,too-many-branches,too-many-statements,too-many-nested-blocks from slither.core.declarations import ( @@ -13,7 +13,6 @@ SolidityVariableComposed, Structure, ) -from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM from slither.core.declarations.custom_error import CustomError from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_top_level import FunctionTopLevel @@ -84,6 +83,7 @@ from slither.slithir.variables import TupleVariable from slither.utils.function import get_function_id from slither.utils.type import export_nested_types_from_variable +from slither.utils.using_for import USING_FOR from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR if TYPE_CHECKING: @@ -594,11 +594,13 @@ def _convert_type_contract(ir: Member) -> Assignment: def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-locals # propagate the type node_function = node.function - using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = ( - node_function.contract.using_for_complete - if isinstance(node_function, FunctionContract) - else {} - ) + + using_for: USING_FOR = {} + if isinstance(node_function, FunctionContract): + using_for = node_function.contract.using_for_complete + elif isinstance(node_function, FunctionTopLevel): + using_for = node_function.using_for_complete + if isinstance(ir, OperationWithLValue) and ir.lvalue: # Force assignment in case of missing previous correct type if not ir.lvalue.type: @@ -1531,9 +1533,10 @@ def look_for_library_or_top_level( internalcall.lvalue = None return internalcall + lib_contract = None if isinstance(destination, FunctionContract) and destination.contract.is_library: lib_contract = destination.contract - else: + elif not isinstance(destination, FunctionTopLevel): lib_contract = contract.file_scope.get_contract_from_name(str(destination)) if lib_contract: lib_call = LibraryCall( @@ -1561,7 +1564,9 @@ def convert_to_library_or_top_level( # We use contract_declarer, because Solidity resolve the library # before resolving the inheritance. # Though we could use .contract as libraries cannot be shadowed - contract = node.function.contract_declarer + contract = ( + node.function.contract_declarer if isinstance(node.function, FunctionContract) else None + ) t = ir.destination.type if t in using_for: new_ir = look_for_library_or_top_level(contract, ir, using_for, t) diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 2d9d9a39e4..660aab1767 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -9,7 +9,7 @@ StructureContract, Function, ) -from slither.core.declarations.contract import Contract, USING_FOR_KEY +from slither.core.declarations.contract import Contract from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types import ElementaryType, TypeAliasContract @@ -23,6 +23,7 @@ from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.solidity_types.type_parsing import parse_type from slither.solc_parsing.variables.state_variable import StateVariableSolc +from slither.utils.using_for import USING_FOR_KEY LOGGER = logging.getLogger("ContractSolcParsing") diff --git a/slither/utils/using_for.py b/slither/utils/using_for.py new file mode 100644 index 0000000000..d8e6481ebf --- /dev/null +++ b/slither/utils/using_for.py @@ -0,0 +1,17 @@ +from typing import Dict, List, TYPE_CHECKING, Union +from slither.core.solidity_types.type import Type + +if TYPE_CHECKING: + from slither.core.declarations import Function + +USING_FOR_KEY = Union[str, Type] +USING_FOR_ITEM = List[Union[Type, "Function"]] +USING_FOR = Dict[USING_FOR_KEY, USING_FOR_ITEM] + + +def _merge_using_for(uf1: USING_FOR, uf2: USING_FOR) -> USING_FOR: + result = {**uf1, **uf2} + for key, value in result.items(): + if key in uf1 and key in uf2: + result[key] = value + uf1[key] + return result diff --git a/tests/unit/slithir/test_data/top_level_using_for.sol b/tests/unit/slithir/test_data/top_level_using_for.sol new file mode 100644 index 0000000000..0dcf003add --- /dev/null +++ b/tests/unit/slithir/test_data/top_level_using_for.sol @@ -0,0 +1,15 @@ +pragma solidity 0.8.24; + +library Lib { + function a(uint q) public {} +} +function c(uint z) {} + +using {Lib.a} for uint; +using {c} for uint; + +function b(uint y) { + Lib.a(4); + y.c(); + y.a(); +} \ No newline at end of file diff --git a/tests/unit/slithir/test_top_level_using_for.py b/tests/unit/slithir/test_top_level_using_for.py new file mode 100644 index 0000000000..c1b531a760 --- /dev/null +++ b/tests/unit/slithir/test_top_level_using_for.py @@ -0,0 +1,44 @@ +from pathlib import Path +from slither import Slither +from slither.core.declarations.contract import Contract +from slither.slithir.operations import LibraryCall, InternalCall + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_top_level_using_for(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.24") + slither = Slither(Path(TEST_DATA_DIR, "top_level_using_for.sol").as_posix(), solc=solc_path) + + function = slither.compilation_units[0].functions_top_level[1] + assert function.name == "b" + + # LIBRARY_CALL, dest:Lib, function:Lib.a(uint256), arguments:['4'] + first_ir = function.slithir_operations[0] + assert ( + isinstance(first_ir, LibraryCall) + and isinstance(first_ir.destination, Contract) + and first_ir.destination.name == "Lib" + and first_ir.function_name == "a" + and len(first_ir.arguments) == 1 + ) + + # INTERNAL_CALL, c(uint256)(y) + second_ir = function.slithir_operations[1] + assert ( + isinstance(second_ir, InternalCall) + and second_ir.function_name == "c" + and len(second_ir.arguments) == 1 + and second_ir.arguments[0].name == "y" + ) + + # LIBRARY_CALL, dest:Lib, function:Lib.a(uint256), arguments:['y'] + third_ir = function.slithir_operations[2] + assert ( + isinstance(third_ir, LibraryCall) + and isinstance(third_ir.destination, Contract) + and third_ir.destination.name == "Lib" + and third_ir.function_name == "a" + and len(third_ir.arguments) == 1 + and third_ir.arguments[0].name == "y" + )