From d447b1ca9fd5b1f668d63eefa392062758555e82 Mon Sep 17 00:00:00 2001 From: Danny Cooper Date: Mon, 13 Nov 2023 17:55:58 +0000 Subject: [PATCH] Correctly resolve mro for __getattr__ in cached properties --- src/attr/_make.py | 56 +++++++++++--- tests/test_slots.py | 177 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 220 insertions(+), 13 deletions(-) diff --git a/src/attr/_make.py b/src/attr/_make.py index a0246137f..c02d05556 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -599,18 +599,46 @@ def _transform_attrs( return _Attributes((AttrsClass(attrs), base_attrs, base_attr_map)) -def _make_cached_property_getattr(cached_properties, original_getattr=None): - def __getattr__(instance, item: str): - func = cached_properties.get(item) - if func is not None: - result = func(instance) - _obj_setattr(instance, item, result) - return result - if original_getattr is not None: - return original_getattr(instance, item) - raise AttributeError(item) +def _make_cached_property_getattr( + cached_properties, + cls, +): + lines = [ + # Wrapped to get `__class__` into closure cell for super() + # (It will be replaced with the newly constructed class after construction). + "def wrapper(_cls, cached_properties, _cached_setattr_get):", + " __class__ = _cls", + " def __getattr__(self, item):", + " func = cached_properties.get(item)", + " if func is not None:", + " result = func(self)", + " _setter = _cached_setattr_get(self)", + " _setter(item, result)", + " return result", + " if '__attrs_original_getattr__' in vars(__class__):", + " return __class__.__attrs_original_getattr__(self, item)", + " if hasattr(super(), '__getattr__'):", + " return super().__getattr__(item)", + " original_error = f\"'{self.__class__.__name__}' object has no attribute '{item}'\"", + " raise AttributeError(original_error)", + " return __getattr__", + "__getattr__ = wrapper(_cls, cached_properties, _cached_setattr_get)", + ] - return __getattr__ + unique_filename = _generate_unique_filename(cls, "getattr") + + glob = { + "cached_properties": cached_properties, + "_cached_setattr_get": _obj_setattr.__get__, + "_cls": cls, + } + + return _make_method( + "__getattr__", + "\n".join(lines), + unique_filename, + glob, + ) def _frozen_setattrs(self, name, value): @@ -898,8 +926,12 @@ def _create_slots_class(self): if annotation is not inspect.Parameter.empty: cd["__annotations__"][name] = annotation + original_getattr = cd.get("__getattr__") + if original_getattr is not None: + cd["__attrs_original_getattr__"] = original_getattr + cd["__getattr__"] = _make_cached_property_getattr( - cached_properties, cd.get("__getattr__") + cached_properties, self._cls ) # We only add the names of attributes that aren't inherited. diff --git a/tests/test_slots.py b/tests/test_slots.py index c8e3cefb7..26365ab0d 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -751,6 +751,23 @@ def f(self): assert "__dict__" not in dir(A) +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slots_cached_property_works_on_frozen_isntances(): + """ + Infers type of cached property. + """ + + @attrs.frozen(slots=True) + class A: + x: int + + @functools.cached_property + def f(self) -> int: + return self.x + + assert A(x=1).f == 1 + + @pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") def test_slots_cached_property_infers_type(): """ @@ -768,10 +785,168 @@ def f(self) -> int: assert A.__annotations__ == {"x": int, "f": int} +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slots_cached_property_with_empty_getattr_raises_attribute_error_of_requested(): + """ + Ensures error information is not lost. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + a = A(1) + with pytest.raises( + AttributeError, match="'A' object has no attribute 'z'" + ): + a.z + + +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slots_cached_property_with_getattr_calls_getattr_for_missing_attributes(): + """ + Ensure __getattr__ implementation is maintained for non cached_properties. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + def __getattr__(self, item): + return item + + a = A(1) + assert a.f == 1 + assert a.z == "z" + + +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slots_getattr_in_superclass__is_called_for_missing_attributes_when_cached_property_present(): + """ + Ensure __getattr__ implementation is maintained in subclass. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + def __getattr__(self, item): + return item + + @attr.s(slots=True) + class B(A): + @functools.cached_property + def f(self): + return self.x + + b = B(1) + assert b.f == 1 + assert b.z == "z" + + +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slots_getattr_in_subclass_gets_superclass_cached_property(): + """ + Ensure super() in __getattr__ is not broken through cached_property re-write. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + def __getattr__(self, item): + return item + + @attr.s(slots=True) + class B(A): + @functools.cached_property + def g(self): + return self.x + + def __getattr__(self, item): + return super().__getattr__(item) + + b = B(1) + assert b.f == 1 + assert b.z == "z" + + +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slots_sub_class_with_independent_cached_properties_both_work(): + """ + Subclassing shouldn't break cached properties. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + @attr.s(slots=True) + class B(A): + @functools.cached_property + def g(self): + return self.x * 2 + + assert B(1).f == 1 + assert B(1).g == 2 + + +@pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") +def test_slots_with_multiple_cached_property_subclasses_works(): + """ + Multiple sub-classes shouldn't break cached properties. + """ + + @attr.s(slots=True) + class A: + x = attr.ib(kw_only=True) + + @functools.cached_property + def f(self): + return self.x + + @attr.s(slots=False) + class B: + @functools.cached_property + def g(self): + return self.x * 2 + + def __getattr__(self, item): + if hasattr(super(), "__getattr__"): + return super().__getattr__(item) + return item + + @attr.s(slots=True) + class AB(A, B): + pass + + ab = AB(x=1) + + assert ab.f == 1 + assert ab.g == 2 + assert ab.h == "h" + + @pytest.mark.skipif(not PY_3_8_PLUS, reason="cached_property is 3.8+") def test_slots_sub_class_avoids_duplicated_slots(): """ - Duplicating the slots is a wast of memory. + Duplicating the slots is a waste of memory. """ @attr.s(slots=True)