Skip to content

Commit

Permalink
Do not override custom descriptors when present.
Browse files Browse the repository at this point in the history
This commit adds a collection of wrapper classes for tracking fields
while still using custom descriptors that may be present. This fixes
a bug where deferring a model field with a custom descriptor meant
that the descriptor was overridden in all subsequent queries.
  • Loading branch information
lucaswiman committed Apr 3, 2018
1 parent 0693ec4 commit 54cc150
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 8 deletions.
77 changes: 74 additions & 3 deletions model_utils/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,73 @@ def _get_field_name(self):
return self.field_name


class DescriptorWrapper(object):

def __init__(self, field_name, descriptor, tracker_attname):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname

def __get__(self, instance, owner):
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
if self.descriptor:
value = self.descriptor.__get__(instance, owner)
else:
value = instance.__dict__[self.field_name]
if was_deferred:
tracker_instance = getattr(instance, self.tracker_attname)
tracker_instance.saved_data[self.field_name] = deepcopy(value)
return value

@staticmethod
def cls_for_descriptor(descriptor):
has_set = hasattr(descriptor, '__set__')
has_del = hasattr(descriptor, '__delete__')
if has_set and has_del:
return FullDescriptorWrapper
elif has_set:
return SettableDescriptorWrapper
elif has_del:
return DeleteableDescriptorWrapper
else:
return DescriptorWrapper


class SettableDescriptorWrapper(DescriptorWrapper):
"""
Descriptor wrapper for descriptors with a __delete__ method.
This should not be used for descriptors
"""
def __set__(self, instance, value):
return self.descriptor.__set__(instance, value)


class DeleteableDescriptorWrapper(DescriptorWrapper):
"""
Descriptor wrapper for descriptors with a __delete__ method.
This should not be used for descriptors
"""
def __delete__(self, instance):
self.descriptor.__delete__(instance)


class FullDescriptorWrapper(SettableDescriptorWrapper, DeleteableDescriptorWrapper):
"""
Wrapper for descriptors with all three descriptor methods.
"""


class FieldInstanceTracker(object):
def __init__(self, instance, fields, field_map):
self.instance = instance
self.fields = fields
self.field_map = field_map
self.init_deferred_fields()
if django.VERSION < (1, 10):
self.init_deferred_fields()

def get_field_value(self, field):
return getattr(self.instance, self.field_map[field])
Expand All @@ -54,10 +115,11 @@ def set_saved_fields(self, fields=None):
def current(self, fields=None):
"""Returns dict of current values for all tracked fields"""
if fields is None:
if self.instance._deferred_fields:
deferred_fields = self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields()
if deferred_fields:
fields = [
field for field in self.fields
if field not in self.instance._deferred_fields
if field not in deferred_fields
]
else:
fields = self.fields
Expand Down Expand Up @@ -135,6 +197,15 @@ def finalize_class(self, sender, **kwargs):
if self.fields is None:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
if django.VERSION >= (1, 10):
for field_name in self.fields:
if django.VERSION >= (1, 10):
descriptor = getattr(sender, field_name)
else:
descriptor = sender.__dict__.get(field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
self.field_map = self.get_field_map(sender)
models.signals.post_init.connect(self.initialize_tracker)
self.model_class = sender
Expand Down
15 changes: 12 additions & 3 deletions tests/test_fields/test_field_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,22 @@ def test_with_deferred(self):
self.instance.number = 1
self.instance.save()
item = list(self.tracked_class.objects.only('name').all())[0]
self.assertTrue(item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue(item.get_deferred_fields())
else:
self.assertTrue(item._deferred_fields)

self.assertEqual(item.tracker.previous('number'), None)
self.assertTrue('number' in item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue('number' in item.get_deferred_fields())
else:
self.assertTrue('number' in item._deferred_fields)

self.assertEqual(item.number, 1)
self.assertTrue('number' not in item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue('number' not in item.get_deferred_fields())
else:
self.assertTrue('number' not in item._deferred_fields)
self.assertEqual(item.tracker.previous('number'), 1)
self.assertFalse(item.tracker.has_changed('number'))

Expand Down
11 changes: 9 additions & 2 deletions tests/test_models/test_deferred_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import unicode_literals

import django
from django.test import TestCase

from tests.models import ModelWithCustomDescriptor
Expand Down Expand Up @@ -30,9 +31,15 @@ def test_custom_descriptor_works(self):
def test_deferred(self):
instance = ModelWithCustomDescriptor.objects.only('id').get(
pk=self.instance.pk)
self.assertIn('custom_field', instance.get_deferred_fields())
if django.VERSION >= (1, 10):
self.assertIn('custom_field', instance.get_deferred_fields())
else:
self.assertIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.custom_field, '1')
self.assertNotIn('custom_field', instance.get_deferred_fields())
if django.VERSION >= (1, 10):
self.assertNotIn('custom_field', instance.get_deferred_fields())
else:
self.assertNotIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.regular_field, 1)
self.assertEqual(instance.tracked_custom_field, '1')
self.assertEqual(instance.tracked_regular_field, 1)
Expand Down

0 comments on commit 54cc150

Please sign in to comment.