From 9aae5f571338e7e25b15ed0b9c433f1d5e8f4a46 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Fri, 4 Oct 2024 11:19:28 +0100 Subject: [PATCH] Fix narrowing of IntEnum and StrEnum types Fix regression in #17866. It should still be possible to narrow IntEnum and StrEnum types, but only when types match or are disjoint. Add more logic to rule out narrowing when types are ambigous. --- mypy/checker.py | 46 +++++++++++++++++- mypy/typeops.py | 8 ---- test-data/unit/check-narrowing.test | 73 ++++++++++++++++++++++++++++- 3 files changed, 117 insertions(+), 10 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index a089576e6ffe1..0d8ae3da1c9f0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5984,7 +5984,9 @@ def has_no_custom_eq_checks(t: Type) -> bool: coerce_only_in_literal_context = True expr_types = [operand_types[i] for i in expr_indices] - should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types)) + should_narrow_by_identity = all( + map(has_no_custom_eq_checks, expr_types) + ) and not is_ambiguous_mix_of_enums(expr_types) if_map: TypeMap = {} else_map: TypeMap = {} @@ -8604,3 +8606,45 @@ def visit_starred_pattern(self, p: StarredPattern) -> None: self.lvalue = True p.capture.accept(self) self.lvalue = False + + +def is_ambiguous_mix_of_enums(types: list[Type]) -> bool: + """Do types have IntEnum/StrEnum types that are potentially overlapping with other types? + + If True, we shouldn't attempt type narrowing based on enum values, as it gets + too ambiguous. + + For example, return True if there's an 'int' type together with an IntEnum literal. + However, IntEnum together with a literal of the same IntEnum type is not ambiguous. + """ + # We need these things for this to be ambiguous: + # (1) an IntEnum or StrEnum type + # (2) either a different IntEnum/StrEnum type or a non-enum type ("") + # + # It would be slightly more correct to calculate this separately for IntEnum and + # StrEnum related types, as an IntEnum can't be confused with a StrEnum. + return len(_ambiguous_enum_variants(types)) > 1 + + +def _ambiguous_enum_variants(types: list[Type]) -> set[str]: + result = set() + for t in types: + t = get_proper_type(t) + if isinstance(t, UnionType): + result.update(_ambiguous_enum_variants(t.items)) + elif isinstance(t, Instance): + if t.last_known_value: + result.update(_ambiguous_enum_variants([t.last_known_value])) + elif t.type.is_enum and any( + base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro + ): + result.add(t.type.fullname) + elif not t.type.is_enum: + # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so + # let's be conservative + result.add("") + elif isinstance(t, LiteralType): + result.update(_ambiguous_enum_variants([t.fallback])) + else: + result.add("") + return result diff --git a/mypy/typeops.py b/mypy/typeops.py index efb5d8fa505fc..0699cda53cfa9 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -1078,14 +1078,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool """ typ = get_proper_type(typ) if isinstance(typ, Instance): - if ( - typ.type.is_enum - and name in ("__eq__", "__ne__") - and any(base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in typ.type.mro) - ): - # IntEnum and StrEnum values have non-straightfoward equality, so treat them - # as if they had custom __eq__ and __ne__ - return True method = typ.type.get(name) if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): if method.node.info: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 169523e61be73..11bbf0a950841 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2135,7 +2135,7 @@ else: # mypy: strict-equality from __future__ import annotations from typing import Any -from enum import IntEnum, StrEnum +from enum import IntEnum class IE(IntEnum): X = 1 @@ -2178,6 +2178,71 @@ def f5(x: int) -> None: reveal_type(x) # N: Revealed type is "builtins.int" else: reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" + +def f6(x: IE) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" + else: + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y]" +[builtins fixtures/primitives.pyi] + +[case testNarrowingWithIntEnum2] +# mypy: strict-equality +from __future__ import annotations +from typing import Any +from enum import IntEnum, Enum + +class MyDecimal: ... + +class IE(IntEnum): + X = 1 + Y = 2 + +class IE2(IntEnum): + X = 1 + Y = 2 + +class E(Enum): + X = 1 + Y = 2 + +def f1(x: IE | MyDecimal) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.MyDecimal]" + else: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.MyDecimal]" + +def f2(x: E | bytes) -> None: + if x == E.X: + reveal_type(x) # N: Revealed type is "Literal[__main__.E.X]" + else: + reveal_type(x) # N: Revealed type is "Union[Literal[__main__.E.Y], builtins.bytes]" + +def f3(x: IE | IE2) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + else: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + +def f4(x: IE | E) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" + elif x == E.X: + reveal_type(x) # N: Revealed type is "Literal[__main__.E.X]" + else: + reveal_type(x) # N: Revealed type is "Union[Literal[__main__.IE.Y], Literal[__main__.E.Y]]" + +def f5(x: E | str | int) -> None: + if x == E.X: + reveal_type(x) # N: Revealed type is "Literal[__main__.E.X]" + else: + reveal_type(x) # N: Revealed type is "Union[Literal[__main__.E.Y], builtins.str, builtins.int]" + +def f6(x: IE | Any) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, Any]" + else: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, Any]" [builtins fixtures/primitives.pyi] [case testNarrowingWithStrEnum] @@ -2205,4 +2270,10 @@ def f3(x: object) -> None: reveal_type(x) # N: Revealed type is "builtins.object" else: reveal_type(x) # N: Revealed type is "builtins.object" + +def f4(x: SE) -> None: + if x == SE.A: + reveal_type(x) # N: Revealed type is "Literal[__main__.SE.A]" + else: + reveal_type(x) # N: Revealed type is "Literal[__main__.SE.B]" [builtins fixtures/primitives.pyi]