diff --git a/ChangeLog b/ChangeLog index c3781713a1..799268d7c7 100644 --- a/ChangeLog +++ b/ChangeLog @@ -9,6 +9,13 @@ Release date: TBA * Import from ``astroid.node_classes`` and ``astroid.scoped_nodes`` has been deprecated in favor of ``astroid.nodes``. Only the imports from ``astroid.nodes`` will work in astroid 3.0.0. +* Add support for arbitrary Enum subclass hierachies + + Closes PyCQA/pylint#533 + Closes PyCQA/pylint#2224 + Closes PyCQA/pylint#2626 + + What's New in astroid 2.6.7? ============================ Release date: TBA diff --git a/astroid/brain/brain_namedtuple_enum.py b/astroid/brain/brain_namedtuple_enum.py index 15751f47b0..9cbeae1f59 100644 --- a/astroid/brain/brain_namedtuple_enum.py +++ b/astroid/brain/brain_namedtuple_enum.py @@ -28,12 +28,14 @@ import keyword from textwrap import dedent +import astroid from astroid import arguments, inference_tip, nodes, util from astroid.builder import AstroidBuilder, extract_node from astroid.exceptions import ( AstroidTypeError, AstroidValueError, InferenceError, + MroError, UseInferenceDefault, ) from astroid.manager import AstroidManager @@ -354,9 +356,7 @@ def __mul__(self, other): def infer_enum_class(node): """Specific inference for enums.""" - for basename in node.basenames: - # TODO: doesn't handle subclasses yet. This implementation - # is a hack to support enums. + for basename in (b for cls in node.mro() for b in cls.basenames): if basename not in ENUM_BASE_NAMES: continue if node.root().name == "enum": @@ -417,9 +417,9 @@ def name(self): # should result in some nice symbolic execution classdef += INT_FLAG_ADDITION_METHODS.format(name=target.name) - fake = AstroidBuilder(AstroidManager()).string_build(classdef)[ - target.name - ] + fake = AstroidBuilder( + AstroidManager(), apply_transforms=False + ).string_build(classdef)[target.name] fake.parent = target.parent for method in node.mymethods(): fake.locals[method.name] = [method] @@ -544,6 +544,18 @@ def infer_typing_namedtuple(node, context=None): return infer_named_tuple(node, context) +def _is_enum_subclass(cls: astroid.ClassDef) -> bool: + """Return whether cls is a subclass of an Enum.""" + try: + return any( + klass.name in ENUM_BASE_NAMES + and getattr(klass.root(), "name", None) == "enum" + for klass in cls.mro() + ) + except MroError: + return False + + AstroidManager().register_transform( nodes.Call, inference_tip(infer_named_tuple), _looks_like_namedtuple ) @@ -551,11 +563,7 @@ def infer_typing_namedtuple(node, context=None): nodes.Call, inference_tip(infer_enum), _looks_like_enum ) AstroidManager().register_transform( - nodes.ClassDef, - infer_enum_class, - predicate=lambda cls: any( - basename for basename in cls.basenames if basename in ENUM_BASE_NAMES - ), + nodes.ClassDef, infer_enum_class, predicate=_is_enum_subclass ) AstroidManager().register_transform( nodes.ClassDef, inference_tip(infer_typing_namedtuple_class), _has_namedtuple_base diff --git a/astroid/nodes/scoped_nodes.py b/astroid/nodes/scoped_nodes.py index 84fbd09bea..aefc283f68 100644 --- a/astroid/nodes/scoped_nodes.py +++ b/astroid/nodes/scoped_nodes.py @@ -63,7 +63,7 @@ from astroid.interpreter.dunder_lookup import lookup from astroid.interpreter.objectmodel import ClassModel, FunctionModel, ModuleModel from astroid.manager import AstroidManager -from astroid.nodes import node_classes +from astroid.nodes import Const, node_classes ITER_METHODS = ("__iter__", "__getitem__") EXCEPTION_BASE_CLASSES = frozenset({"Exception", "BaseException"}) @@ -2962,7 +2962,12 @@ def _inferred_bases(self, context=None): for stmt in self.bases: try: - baseobj = next(stmt.infer(context=context.clone())) + # Find the first non-None inferred base value + baseobj = next( + b + for b in stmt.infer(context=context.clone()) + if not (isinstance(b, Const) and b.value is None) + ) except (InferenceError, StopIteration): continue if isinstance(baseobj, bases.Instance): diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 6f42f19b3c..4cbed87645 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -791,6 +791,24 @@ def __init__(self, name, enum_list): test = next(enumeration.igetattr("test")) self.assertEqual(test.value, 42) + def test_user_enum_false_positive(self): + # Test that a user-defined class named Enum is not considered a builtin enum. + ast_node = astroid.extract_node( + """ + class Enum: + pass + + class Color(Enum): + red = 1 + + Color.red #@ + """ + ) + inferred = ast_node.inferred() + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], astroid.Const) + self.assertEqual(inferred[0].value, 1) + def test_ignores_with_nodes_from_body_of_enum(self): code = """ import enum @@ -1051,6 +1069,91 @@ def func(self): assert isinstance(inferred, bases.Instance) assert inferred.pytype() == ".TrickyEnum.value" + def test_enum_subclass_member_name(self): + ast_node = astroid.extract_node( + """ + from enum import Enum + + class EnumSubclass(Enum): + pass + + class Color(EnumSubclass): + red = 1 + + Color.red.name #@ + """ + ) + inferred = ast_node.inferred() + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], astroid.Const) + self.assertEqual(inferred[0].value, "red") + + def test_enum_subclass_member_value(self): + ast_node = astroid.extract_node( + """ + from enum import Enum + + class EnumSubclass(Enum): + pass + + class Color(EnumSubclass): + red = 1 + + Color.red.value #@ + """ + ) + inferred = ast_node.inferred() + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], astroid.Const) + self.assertEqual(inferred[0].value, 1) + + def test_enum_subclass_member_method(self): + # See Pylint issue #2626 + ast_node = astroid.extract_node( + """ + from enum import Enum + + class EnumSubclass(Enum): + def hello_pylint(self) -> str: + return self.name + + class Color(EnumSubclass): + red = 1 + + Color.red.hello_pylint() #@ + """ + ) + inferred = ast_node.inferred() + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], astroid.Const) + self.assertEqual(inferred[0].value, "red") + + def test_enum_subclass_different_modules(self): + # See Pylint issue #2626 + astroid.extract_node( + """ + from enum import Enum + + class EnumSubclass(Enum): + pass + """, + "a", + ) + ast_node = astroid.extract_node( + """ + from a import EnumSubclass + + class Color(EnumSubclass): + red = 1 + + Color.red.value #@ + """ + ) + inferred = ast_node.inferred() + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], astroid.Const) + self.assertEqual(inferred[0].value, 1) + @unittest.skipUnless(HAS_DATEUTIL, "This test requires the dateutil library.") class DateutilBrainTest(unittest.TestCase): @@ -1568,7 +1671,7 @@ def test_typing_annotated_subscriptable(self): @test_utils.require_version(minver="3.7") def test_typing_generic_slots(self): - """Test cache reset for slots if Generic subscript is inferred.""" + """Test slots for Generic subclass.""" node = builder.extract_node( """ from typing import Generic, TypeVar @@ -1580,10 +1683,6 @@ def __init__(self, value): """ ) inferred = next(node.infer()) - assert len(inferred.slots()) == 0 - # Only after the subscript base is inferred and the inference tip applied, - # will slots contain the correct value - next(node.bases[0].infer()) slots = inferred.slots() assert len(slots) == 1 assert isinstance(slots[0], nodes.Const)