diff --git a/examples/scripts/data_dependency.sol b/examples/scripts/data_dependency.sol index 40653a2125..a4f64870af 100644 --- a/examples/scripts/data_dependency.sol +++ b/examples/scripts/data_dependency.sol @@ -102,3 +102,16 @@ contract PropagateThroughArguments { var_not_tainted = y; } } + +contract PropagateThroughReturnValue { + uint var_dependant; + uint var_state; + + function foo() public { + var_dependant = bar(); + } + + function bar() internal returns (uint) { + return (var_state); + } +} diff --git a/slither/__main__.py b/slither/__main__.py index 71c1a3b6bc..0810f34ffd 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -118,6 +118,20 @@ def process_truffle(dirname, args, detector_classes, printer_classes): return _process(slither, detector_classes, printer_classes) +def process_embark(dirname, args, detector_classes, printer_classes): + + slither = Slither(dirname, + solc=args.solc, + disable_solc_warnings=args.disable_solc_warnings, + solc_arguments=args.solc_args, + is_truffle=False, + is_embark=True, + embark_overwrite_config=args.embark_overwrite_config, + filter_paths=parse_filter_paths(args), + triage_mode=args.triage_mode) + + return _process(slither, detector_classes, printer_classes) + def process_files(filenames, args, detector_classes, printer_classes): all_contracts = [] @@ -287,6 +301,7 @@ def parse_filter_paths(args): 'filter_paths': None, 'ignore_truffle_compile': False, 'truffle_build_directory': 'build/contracts', + 'embark_overwrite_config': False, 'legacy_ast': False } @@ -442,6 +457,11 @@ def parse_args(detector_classes, printer_classes): action='store_true', default=False) + group_misc.add_argument('--embark-overwrite-config', + help=argparse.SUPPRESS, + action='store_true', + default=defaults_flag_in_config['embark_overwrite_config']) + parser.add_argument('--wiki-detectors', help=argparse.SUPPRESS, action=OutputWiki, @@ -574,6 +594,9 @@ def main_impl(all_detector_classes, all_printer_classes): elif os.path.isfile(os.path.join(filename, 'truffle.js')) or os.path.isfile(os.path.join(filename, 'truffle-config.js')): (results, number_contracts) = process_truffle(filename, args, detector_classes, printer_classes) + elif os.path.isfile(os.path.join(filename, 'embark.json')): + (results, number_contracts) = process_embark(filename, args, detector_classes, printer_classes) + elif os.path.isdir(filename) or len(globbed_filenames) > 0: extension = "*.sol" if not args.solc_ast else "*.json" filenames = glob.glob(os.path.join(filename, extension)) diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index f439c4b12d..e9f6f2c68d 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -4,7 +4,7 @@ from slither.core.declarations import (Contract, Enum, Function, SolidityFunction, SolidityVariable, SolidityVariableComposed, Structure) -from slither.slithir.operations import Index, OperationWithLValue +from slither.slithir.operations import Index, OperationWithLValue, InternalCall from slither.slithir.variables import (Constant, LocalIRVariable, ReferenceVariable, ReferenceVariableSSA, StateIRVariable, TemporaryVariable, @@ -232,6 +232,8 @@ def add_dependency(lvalue, function, ir, is_protected): function.context[KEY_SSA_UNPROTECTED][lvalue] = set() if isinstance(ir, Index): read = [ir.variable_left] + elif isinstance(ir, InternalCall): + read = ir.function.return_values_ssa else: read = ir.read [function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 483a618af1..5c8849880a 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -9,10 +9,8 @@ from slither.core.declarations.solidity_variables import (SolidityFunction, SolidityVariable, SolidityVariableComposed) -from slither.core.expressions.identifier import Identifier -from slither.core.expressions.index_access import IndexAccess -from slither.core.expressions.member_access import MemberAccess -from slither.core.expressions.unary_operation import UnaryOperation +from slither.core.expressions import (Identifier, IndexAccess, MemberAccess, + UnaryOperation) from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.state_variable import StateVariable @@ -43,6 +41,8 @@ def __init__(self): self._parameters_ssa = [] self._returns = [] self._returns_ssa = [] + self._return_values = None + self._return_values_ssa = None self._vars_read = [] self._vars_written = [] self._state_vars_read = [] @@ -468,6 +468,38 @@ def expressions(self): self._expressions = expressions return self._expressions + @property + def return_values(self): + """ + list(Return Values): List of the return values + """ + from slither.core.cfg.node import NodeType + from slither.slithir.operations import Return + from slither.slithir.variables import Constant + + if self._return_values is None: + return_values = list() + returns = [n for n in self.nodes if n.type == NodeType.RETURN] + [return_values.extend(ir.values) for node in returns for ir in node.irs if isinstance(ir, Return)] + self._return_values = list(set([x for x in return_values if not isinstance(x, Constant)])) + return self._return_values + + @property + def return_values_ssa(self): + """ + list(Return Values in SSA form): List of the return values in ssa form + """ + from slither.core.cfg.node import NodeType + from slither.slithir.operations import Return + from slither.slithir.variables import Constant + + if self._return_values_ssa is None: + return_values_ssa = list() + returns = [n for n in self.nodes if n.type == NodeType.RETURN] + [return_values_ssa.extend(ir.values) for node in returns for ir in node.irs_ssa if isinstance(ir, Return)] + self._return_values_ssa = list(set([x for x in return_values_ssa if not isinstance(x, Constant)])) + return self._return_values_ssa + # endregion ################################################################################### ################################################################################### diff --git a/slither/core/expressions/__init__.py b/slither/core/expressions/__init__.py index 23690aca03..42554bf0bc 100644 --- a/slither/core/expressions/__init__.py +++ b/slither/core/expressions/__init__.py @@ -6,6 +6,7 @@ from .identifier import Identifier from .index_access import IndexAccess from .literal import Literal +from .member_access import MemberAccess from .new_array import NewArray from .new_contract import NewContract from .new_elementary_type import NewElementaryType diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index 2a0ffa6a65..c4df7ff249 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -19,12 +19,12 @@ def _compute_line(source_code, start, length): Not done in an efficient way """ total_length = len(source_code) - source_code = source_code.split('\n') + source_code = source_code.splitlines(True) counter = 0 i = 0 lines = [] while counter < total_length: - counter += len(source_code[i]) +1 + counter += len(source_code[i]) i = i+1 if counter > start: lines.append(i) diff --git a/slither/slither.py b/slither/slither.py index c705c5ba4e..c24ec05a76 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -30,15 +30,23 @@ def __init__(self, contract, **kwargs): ast_format (str): ast format (default '--ast-compact-json') is_truffle (bool): is a truffle directory (default false) truffle_build_directory (str): build truffle directory (default 'build/contracts') + is_embark (bool): is an embark directory (default false) + embark_overwrite_config (bool): overwrite original config file (default false) filter_paths (list(str)): list of path to filter (default []) triage_mode (bool): if true, switch to triage mode (default false) ''' is_truffle = kwargs.get('is_truffle', False) + is_embark = kwargs.get('is_embark', False) + embark_overwrite_config = kwargs.get('embark_overwrite_config', False) + # truffle directory if is_truffle: self._init_from_truffle(contract, kwargs.get('truffle_build_directory', 'build/contracts')) + # embark directory + elif is_embark: + self._init_from_embark(contract, embark_overwrite_config) # list of files provided (see --splitted option) elif isinstance(contract, list): self._init_from_list(contract) @@ -58,6 +66,46 @@ def __init__(self, contract, **kwargs): self._analyze_contracts() + def _init_from_embark(self, contract, embark_overwrite_config): + super(Slither, self).__init__('') + plugin_name = '@trailofbits/embark-contract-info' + with open('embark.json') as f: + embark_json = json.load(f) + if embark_overwrite_config: + write_embark_json = False + if (not 'plugins' in embark_json): + embark_json['plugins'] = {plugin_name:{'flags':""}} + write_embark_json = True + elif (not plugin_name in embark_json['plugins']): + embark_json['plugins'][plugin_name] = {'flags':""} + write_embark_json = True + if write_embark_json: + process = subprocess.Popen(['npm','install', plugin_name]) + _, stderr = process.communicate() + with open('embark.json', 'w') as outfile: + json.dump(embark_json, outfile, indent=2) + else: + if (not 'plugins' in embark_json) or (not 'embark-contract-info' in embark_json['plugins']): + logger.error(red('embark-contract-info plugin was found in embark.json. Please install the plugin (see https://github.com/crytic/slither/wiki/Usage#embark), or use --embark-overwrite-config.')) + sys.exit(-1) + + process = subprocess.Popen(['embark','build','--contracts'],stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + logger.info("%s\n"%stdout.decode()) + if stderr: + # Embark might return information to stderr, but compile without issue + logger.error("%s"%stderr.decode()) + infile = os.path.join(contract, 'crytic-export', 'contracts.json') + print(infile) + if not os.path.isfile(infile): + logger.error(red('Embark did not generate the AST file. Is Embark installed (npm install -g embark)? Is embark-contract-info installed? (npm install -g embark).')) + sys.exit(-1) + with open(infile, 'r') as f: + contracts_loaded = json.load(f) + contracts_loaded = contracts_loaded['asts'] + for contract_loaded in contracts_loaded: + self._parse_contracts_from_loaded_json(contract_loaded, contract_loaded['absolutePath']) + def _init_from_truffle(self, contract, build_directory): if not os.path.isdir(os.path.join(contract, build_directory)): logger.info(red('No truffle build directory found, did you run `truffle compile`?')) diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 4fd59a5aa4..7b9607df0d 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -29,6 +29,7 @@ from slither.slithir.variables import (Constant, ReferenceVariable, TemporaryVariable) from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR +from slither.utils.function import get_function_id logger = logging.getLogger('ConvertToIR') @@ -385,7 +386,14 @@ def propagate_types(ir, node): return length if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType): return Balance(ir.variable_left, ir.lvalue) + if ir.variable_right == 'selector' and isinstance(ir.variable_left.type, Function): + assignment = Assignment(ir.lvalue, + Constant(str(get_function_id(ir.variable_left.type.full_name))), + ElementaryType('bytes4')) + assignment.lvalue.set_type(ElementaryType('bytes4')) + return assignment left = ir.variable_left + t = None if isinstance(left, (Variable, SolidityVariable)): t = ir.variable_left.type elif isinstance(left, (Contract, Enum, Structure)): @@ -404,6 +412,14 @@ def propagate_types(ir, node): ir.lvalue.set_type(elems[elem].type) else: assert isinstance(type_t, Contract) + # Allow type propagtion as a Function + # Only for reference variables + # This allows to track the selector keyword + # We dont need to check for function collision, as solc prevents the use of selector + # if there are multiple functions with the same name + f = next((f for f in type_t.functions if f.name == ir.variable_right), None) + if f: + ir.lvalue.set_type(f) elif isinstance(ir, NewArray): ir.lvalue.set_type(ir.array_type) elif isinstance(ir, NewContract): diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 837a99f762..f4f0ec2861 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -2,7 +2,7 @@ from .variable import SlithIRVariable from slither.core.children.child_node import ChildNode from slither.core.variables.variable import Variable -from slither.core.declarations import Contract, Enum, SolidityVariable +from slither.core.declarations import Contract, Enum, SolidityVariable, Function class ReferenceVariable(ChildNode, Variable): @@ -56,5 +56,14 @@ def points_to(self, points_to): def name(self): return 'REF_{}'.format(self.index) + # overide of core.variables.variables + # reference can have Function has a type + # to handle the function selector + def set_type(self, t): + if not isinstance(t, Function): + super(ReferenceVariable, self).set_type(t) + else: + self._type = t + def __str__(self): return self.name diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 12d53f4e04..39b2334d69 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -319,6 +319,18 @@ def parse_super_name(expression, is_compact_ast): return base_name+arguments +def _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context): + # nop exression + # uint; + if is_compact_ast: + value = expression['typeName'] + else: + assert 'children' not in expression + value = expression['attributes']['value'] + t = parse_type(UnknownType(value), caller_context) + + return ElementaryTypeNameExpression(t) + def parse_expression(expression, caller_context): """ @@ -520,6 +532,12 @@ def parse_expression(expression, caller_context): assert len(children) == 2 left = children[0] right = children[1] + # IndexAccess is used to describe ElementaryTypeNameExpression + # if abi.decode is used + # For example, abi.decode(data, ...(uint[]) ) + if right is None: + return _parse_elementary_type_name_expression(left, is_compact_ast, caller_context) + left_expression = parse_expression(left, caller_context) right_expression = parse_expression(right, caller_context) index = IndexAccess(left_expression, right_expression, index_type) @@ -559,16 +577,7 @@ def parse_expression(expression, caller_context): return member_access elif name == 'ElementaryTypeNameExpression': - # nop exression - # uint; - if is_compact_ast: - value = expression['typeName'] - else: - assert 'children' not in expression - value = expression['attributes']['value'] - t = parse_type(UnknownType(value), caller_context) - - return ElementaryTypeNameExpression(t) + return _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context) # NewExpression is not a root expression, it's always the child of another expression diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index c2bffdd4c2..5fa3c160e3 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -89,7 +89,7 @@ def _parse_contracts_from_loaded_json(self, data_loaded, filename): if 'sourcePaths' in data_loaded: for sourcePath in data_loaded['sourcePaths']: if os.path.isfile(sourcePath): - with open(sourcePath, encoding='utf8') as f: + with open(sourcePath, encoding='utf8', newline='') as f: source_code = f.read() self.source_code[sourcePath] = source_code @@ -152,13 +152,13 @@ def _parse_source_unit(self, data, filename): self._source_units[sourceUnit] = name if os.path.isfile(name) and not name in self.source_code: - with open(name, encoding='utf8') as f: + with open(name, encoding='utf8', newline='') as f: source_code = f.read() self.source_code[name] = source_code else: lib_name = os.path.join('node_modules', name) if os.path.isfile(lib_name) and not name in self.source_code: - with open(lib_name, encoding='utf8') as f: + with open(lib_name, encoding='utf8', newline='') as f: source_code = f.read() self.source_code[name] = source_code