Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of deferred fields on django 1.10+ #317

Merged
merged 18 commits into from
Jul 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ python:
- 3.6
install: pip install tox-travis codecov
# positional args ({posargs}) to pass into tox.ini
script: tox -- --cov
script: tox -- --cov --cov-append
after_success: codecov
deploy:
provider: pypi
Expand Down
2 changes: 2 additions & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@
| Karl Wan Nan Wo <[email protected]>
| zyegfryed
| Radosław Jan Ganczarek <[email protected]>
| Lucas Wiman <[email protected]>
| Jack Cushman <[email protected]>
4 changes: 3 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ CHANGES

master (unreleased)
-------------------
- Fix handling of deferred attributes on Django 1.10+, fixes GH-278
- Fix `FieldTracker.has_changed()` and `FieldTracker.previous()` to return
correct responses for deferred fields.

3.1.2 (2018.05.09)
------------------
* Update InheritanceIterable to inherit from
ModelIterable instead of BaseIterable, fixes GH-277.


3.1.1 (2017.12.17)
------------------

Expand Down
6 changes: 6 additions & 0 deletions docs/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ Returns the value of the given field during the last save:
Returns ``None`` when the model instance isn't saved yet.

If a field is `deferred`_, calling ``previous()`` will load the previous value from the database.

.. _deferred: https://docs.djangoproject.com/en/2.0/ref/models/querysets/#defer


has_changed
~~~~~~~~~~~
Expand All @@ -167,6 +171,8 @@ Returns ``True`` if the given field has changed since the last save. The ``has_c
The ``has_changed`` method relies on ``previous`` to determine whether a
field's values has changed.

If a field is `deferred`_ and has been assigned locally, calling ``has_changed()``
will load the previous value from the database to perform the comparison.

changed
~~~~~~~
Expand Down
30 changes: 6 additions & 24 deletions model_utils/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ def __iter__(self):
class InheritanceQuerySetMixin(object):
def __init__(self, *args, **kwargs):
super(InheritanceQuerySetMixin, self).__init__(*args, **kwargs)
if django.VERSION > (1, 8):
self._iterable_class = InheritanceIterable
self._iterable_class = InheritanceIterable

def select_subclasses(self, *subclasses):
levels = self._get_maximum_depth()
levels = None
calculated_subclasses = self._get_subclasses_recurse(
self.model, levels=levels)
# if none were passed in, we can just short circuit and select all
Expand Down Expand Up @@ -151,12 +150,9 @@ def _get_subclasses_recurse(self, model, levels=None):
recursively, returning a `list` of strings representing the
relations for select_related
"""
if django.VERSION < (1, 8):
related_objects = model._meta.get_all_related_objects()
else:
related_objects = [
f for f in model._meta.get_fields()
if isinstance(f, OneToOneRel)]
related_objects = [
f for f in model._meta.get_fields()
if isinstance(f, OneToOneRel)]

rels = [
rel for rel in related_objects
Expand Down Expand Up @@ -199,10 +195,7 @@ def _get_ancestors_path(self, model, levels=None):
related = parent_link.remote_field
ancestry.insert(0, related.get_accessor_name())
if levels or levels is None:
if django.VERSION < (1, 8):
parent_model = related.parent_model
else:
parent_model = related.model
parent_model = related.model
parent_link = parent_model._meta.get_ancestor_link(
self.model)
else:
Expand Down Expand Up @@ -230,17 +223,6 @@ def _get_sub_obj_recurse(self, obj, s):
def get_subclass(self, *args, **kwargs):
return self.select_subclasses().get(*args, **kwargs)

def _get_maximum_depth(self):
"""
Under Django versions < 1.6, to avoid triggering
https://code.djangoproject.com/ticket/16572 we can only look
as far as children.
"""
levels = None
if django.VERSION < (1, 6, 0):
levels = 1
return levels


class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
pass
Expand Down
132 changes: 100 additions & 32 deletions model_utils/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,73 @@ def _get_field_name(self):
return self.field_name


class DescriptorWrapper(object):

def __init__(self, field_name, descriptor, tracker_attname):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname

def __get__(self, instance, owner):
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
value = self.descriptor.__get__(instance, owner)
if was_deferred:
tracker_instance = getattr(instance, self.tracker_attname)
tracker_instance.saved_data[self.field_name] = deepcopy(value)
return value

def __set__(self, instance, value):
initialized = hasattr(instance, '_instance_intialized')
was_deferred = self.field_name in instance.get_deferred_fields()

# Sentinel attribute to detect whether we are already trying to
# set the attribute higher up the stack. This prevents infinite
# recursion when retrieving deferred values from the database.
recursion_sentinel_attname = '_setting_' + self.field_name
already_setting = hasattr(instance, recursion_sentinel_attname)

if initialized and was_deferred and not already_setting:
setattr(instance, recursion_sentinel_attname, True)
try:
# Retrieve the value to set the saved_data value.
# This will undefer the field
getattr(instance, self.field_name)
finally:
instance.__dict__.pop(recursion_sentinel_attname, None)
if hasattr(self.descriptor, '__set__'):
self.descriptor.__set__(instance, value)
else:
instance.__dict__[self.field_name] = value

@staticmethod
def cls_for_descriptor(descriptor):
if hasattr(descriptor, '__delete__'):
return FullDescriptorWrapper
else:
return DescriptorWrapper


class FullDescriptorWrapper(DescriptorWrapper):
"""
Wrapper for descriptors with all three descriptor methods.
"""
def __delete__(self, obj):
self.descriptor.__delete__(obj)


class FieldInstanceTracker(object):
def __init__(self, instance, fields, field_map):
self.instance = instance
self.fields = fields
self.field_map = field_map
self.init_deferred_fields()
if django.VERSION < (1, 10):
self.init_deferred_fields()

@property
def deferred_fields(self):
return self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields()

def get_field_value(self, field):
return getattr(self.instance, self.field_map[field])
Expand All @@ -54,10 +115,11 @@ def set_saved_fields(self, fields=None):
def current(self, fields=None):
"""Returns dict of current values for all tracked fields"""
if fields is None:
if self.instance._deferred_fields:
deferred_fields = self.deferred_fields
if deferred_fields:
fields = [
field for field in self.fields
if field not in self.instance._deferred_fields
if field not in deferred_fields
]
else:
fields = self.fields
Expand All @@ -67,12 +129,31 @@ def current(self, fields=None):
def has_changed(self, field):
"""Returns ``True`` if field has changed from currently saved value"""
if field in self.fields:
# deferred fields haven't changed
if field in self.deferred_fields and field not in self.instance.__dict__:
return False
return self.previous(field) != self.get_field_value(field)
else:
raise FieldError('field "%s" not tracked' % field)

def previous(self, field):
"""Returns currently saved value of given field"""

# handle deferred fields that have not yet been loaded from the database
if self.instance.pk and field in self.deferred_fields and field not in self.saved_data:

# if the field has not been assigned locally, simply fetch and un-defer the value
if field not in self.instance.__dict__:
self.get_field_value(field)

# if the field has been assigned locally, store the local value, fetch the database value,
# store database value to saved_data, and restore the local value
else:
current_value = self.get_field_value(field)
self.instance.refresh_from_db(fields=[field])
self.saved_data[field] = deepcopy(self.get_field_value(field))
setattr(self.instance, self.field_map[field], current_value)

return self.saved_data.get(field)

def changed(self):
Expand All @@ -97,35 +178,15 @@ class FileDescriptorTracker(DescriptorMixin, FileDescriptor):
def _get_field_name(self):
return self.field.name

if django.VERSION >= (1, 8):
self.instance._deferred_fields = self.instance.get_deferred_fields()
for field in self.instance._deferred_fields:
if django.VERSION >= (1, 10):
field_obj = getattr(self.instance.__class__, field)
else:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, FileDescriptor):
field_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, field_tracker)
else:
field_tracker = DeferredAttributeTracker(
field_obj.field_name, None)
setattr(self.instance.__class__, field, field_tracker)
else:
for field in self.fields:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, DeferredAttribute):
self.instance._deferred_fields.add(field)

# Django 1.4
if django.VERSION >= (1, 5):
model = None
else:
model = field_obj.model_ref()

field_tracker = DeferredAttributeTracker(
field_obj.field_name, model)
setattr(self.instance.__class__, field, field_tracker)
self.instance._deferred_fields = self.instance.get_deferred_fields()
for field in self.instance._deferred_fields:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, FileDescriptor):
field_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, field_tracker)
else:
field_tracker = DeferredAttributeTracker(field, type(self.instance))
setattr(self.instance.__class__, field, field_tracker)


class FieldTracker(object):
Expand All @@ -152,6 +213,12 @@ def finalize_class(self, sender, **kwargs):
if self.fields is None:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
if django.VERSION >= (1, 10):
for field_name in self.fields:
descriptor = getattr(sender, field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
self.field_map = self.get_field_map(sender)
models.signals.post_init.connect(self.initialize_tracker)
self.model_class = sender
Expand All @@ -164,6 +231,7 @@ def initialize_tracker(self, sender, instance, **kwargs):
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
self.patch_save(instance)
instance._instance_intialized = True

def patch_save(self, instance):
original_save = instance.save
Expand Down
42 changes: 42 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import unicode_literals, absolute_import

import django
from django.db import models
from django.db.models.query_utils import DeferredAttribute
from django.db.models import Manager
from django.utils.encoding import python_2_unicode_compatible
from django.utils.translation import ugettext_lazy as _
Expand Down Expand Up @@ -331,3 +333,43 @@ class CustomSoftDelete(SoftDeletableModel):
is_read = models.BooleanField(default=False)

objects = CustomSoftDeleteManager()


class StringyDescriptor(object):
"""
Descriptor that returns a string version of the underlying integer value.
"""
def __init__(self, name):
self.name = name

def __get__(self, obj, cls=None):
if obj is None:
return self
if self.name in obj.get_deferred_fields():
# This queries the database, and sets the value on the instance.
if django.VERSION < (2, 1):
DeferredAttribute(field_name=self.name, model=cls).__get__(obj, cls)
else:
DeferredAttribute(field_name=self.name).__get__(obj, cls)
return str(obj.__dict__[self.name])

def __set__(self, obj, value):
obj.__dict__[self.name] = int(value)

def __delete__(self, obj):
del obj.__dict__[self.name]


class CustomDescriptorField(models.IntegerField):
def contribute_to_class(self, cls, name, **kwargs):
super(CustomDescriptorField, self).contribute_to_class(cls, name, **kwargs)
setattr(cls, name, StringyDescriptor(name))


class ModelWithCustomDescriptor(models.Model):
custom_field = CustomDescriptorField()
tracked_custom_field = CustomDescriptorField()
regular_field = models.IntegerField()
tracked_regular_field = models.IntegerField()

tracker = FieldTracker(fields=['tracked_custom_field', 'tracked_regular_field'])
Loading