Skip to content

Commit

Permalink
Merge pull request #1890 from crytic/fp/strict-equality
Browse files Browse the repository at this point in the history
reduce false positives for incorrect-equality detector
  • Loading branch information
montyly authored May 5, 2023
2 parents c1ae06a + abcef30 commit c915055
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 8 deletions.
7 changes: 2 additions & 5 deletions slither/core/declarations/custom_error.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, TYPE_CHECKING, Optional, Type

from slither.core.solidity_types import UserDefinedType
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.local_variable import LocalVariable
from slither.utils.type import is_underlying_type_address

if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
Expand Down Expand Up @@ -43,10 +43,7 @@ def compilation_unit(self) -> "SlitherCompilationUnit":

@staticmethod
def _convert_type_for_solidity_signature(t: Optional[Type]) -> str:
# pylint: disable=import-outside-toplevel
from slither.core.declarations import Contract

if isinstance(t, UserDefinedType) and isinstance(t.type, Contract):
if is_underlying_type_address(t):
return "address"
return str(t)

Expand Down
23 changes: 20 additions & 3 deletions slither/detectors/statements/incorrect_strict_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from slither.slithir.variables.local_variable import LocalIRVariable
from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA
from slither.utils.output import Output
from slither.utils.type import is_underlying_type_address


class IncorrectStrictEquality(AbstractDetector):
Expand Down Expand Up @@ -72,6 +73,19 @@ class IncorrectStrictEquality(AbstractDetector):
def is_direct_comparison(ir: Operation) -> bool:
return isinstance(ir, Binary) and ir.type == BinaryType.EQUAL

@staticmethod
def is_not_comparing_addresses(ir: Binary) -> bool:
"""
Comparing addresses strictly should not be flagged.
"""

if is_underlying_type_address(ir.variable_left.type) and is_underlying_type_address(
ir.variable_right.type
):
return False

return True

@staticmethod
def is_any_tainted(
variables: List[
Expand Down Expand Up @@ -108,7 +122,6 @@ def taint_balance_equalities(
):
taints.append(ir.lvalue)
if isinstance(ir, HighLevelCall):
# print(ir.function.full_name)
if (
isinstance(ir.function, Function)
and ir.function.full_name == "balanceOf(address)"
Expand All @@ -125,7 +138,6 @@ def taint_balance_equalities(
if isinstance(ir, Assignment):
if ir.rvalue in self.sources_taint:
taints.append(ir.lvalue)

return taints

# Retrieve all tainted (node, function) pairs
Expand All @@ -145,7 +157,12 @@ def tainted_equality_nodes(
for ir in node.irs_ssa:

# Filter to only tainted equality (==) comparisons
if self.is_direct_comparison(ir) and self.is_any_tainted(ir.used, taints, func):
if (
self.is_direct_comparison(ir)
# Filter out address comparisons which may occur due to lack of field sensitivity in data dependency
and self.is_not_comparing_addresses(ir)
and self.is_any_tainted(ir.used, taints, func)
):
if func not in results:
results[func] = []
results[func].append(node)
Expand Down
15 changes: 15 additions & 0 deletions slither/utils/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,18 @@ def export_return_type_from_variable(
return ret

return [variable_or_type.type]


def is_underlying_type_address(t: "Type") -> bool:
"""
Return true if the underlying type is an address
i.e. if the type is an address or a contract
"""
# pylint: disable=import-outside-toplevel
from slither.core.declarations.contract import Contract

if t == ElementaryType("address"):
return True
if isinstance(t, UserDefinedType) and isinstance(t.type, Contract):
return True
return False
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,27 @@ contract TestSolidityKeyword{

}

interface Receiver {

}
contract A {
mapping(address => Info) data;

struct Info {
uint a;
address b;
uint c;
}
function good(address b) public payable {
data[msg.sender] = Info(block.timestamp, b, msg.value);
if (data[msg.sender].b == address(0)) {
payable(msg.sender).transfer(data[msg.sender].c);
}
}
function good2(address b) public payable {
data[msg.sender] = Info(block.timestamp, b, msg.value);
if (Receiver(data[msg.sender].b) == Receiver(address(0))) {
payable(msg.sender).transfer(data[msg.sender].c);
}
}
}
Binary file not shown.

0 comments on commit c915055

Please sign in to comment.