Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Users to override tracker class #4

Merged
merged 4 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Out[1]: ["I", "am", "your", "father"]
```
DTM handles deferred fields well.
```python
# from django.db.models.query_utils import DeferredAttribute
In [1]: e = Example.objects.only("array").first()
In [2]: e.text = "I am not your father"
In [3]: e.tracker.changed
Expand All @@ -84,6 +85,17 @@ class Example(models.Model):
first = models.TextField()
second = models.TextField()
```
You can also implement your own Tracker class:
```python
from tracking_model import Tracker

class SuperTracker(Tracker):
def has_changed(self, field):
return field in self.changed

class Example(models.Model):
TRACKER_CLASS = SuperTracker
```

## Requirements
* Python >= 2.7, <= 3.11
Expand Down
27 changes: 26 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.contrib.postgres.fields import ArrayField
from django.db import models

from tracking_model import TrackingModelMixin
from tracking_model import TrackingModelMixin, Tracker


class ModelB(TrackingModelMixin, models.Model):
Expand Down Expand Up @@ -36,3 +36,28 @@ class NarrowTrackedModel(TrackingModelMixin, models.Model):
TRACKED_FIELDS = ["first"]
first = models.TextField(null=True)
second = models.TextField(null=True)


class CustomTracker(Tracker):
def has_changed(self, field):
if field not in self.tracked_fields:
raise ValueError("%s is not tracked" % field)
return field in self.changed


class WithCustomTrackerModel(TrackingModelMixin, models.Model):
TRACKER_CLASS = CustomTracker
TRACKED_FIELDS = ["first"]
first = models.TextField(null=True)
second = models.TextField(null=True)


class InvalidTracker:
pass


class WithInvalidTrackerModel(TrackingModelMixin, models.Model):
TRACKER_CLASS = InvalidTracker
TRACKED_FIELDS = ["first"]
first = models.TextField(null=True)
second = models.TextField(null=True)
34 changes: 33 additions & 1 deletion tests/test_tracking_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
from django.db.models.query_utils import DeferredAttribute
from django.test import TestCase

from .models import ModelA, ModelB, SignalModel, MutableModel, NarrowTrackedModel
from .models import (
ModelA,
ModelB,
SignalModel,
MutableModel,
NarrowTrackedModel,
WithCustomTrackerModel,
WithInvalidTrackerModel,
)
from .signals import *


Expand Down Expand Up @@ -211,3 +219,27 @@ def test_only_track_first(self):
self.obj.first = "Ciao ciao"
self.obj.second = "Italiano"
self.assertDictEqual(self.obj.tracker.changed, {"first": "Ciao"})


class OverrideTrackerTests(TestCase):
def test_tracking_mixin_raises_error_if_tracker_class_is_invalid(self):
with self.assertRaises(TypeError) as e:
WithInvalidTrackerModel(first="Joh", second="Doe").tracker

self.assertEqual(
str(e.exception),
"TRACKER_CLASS must be a subclass of Tracker.",
)

def test_instance_can_use_new_methods_of_tracker_class(self):
instance = WithCustomTrackerModel(first="John", second="Doe")
instance.first = "Mary"
instance.second = "Jane"
self.assertEqual(instance.tracker.has_changed("first"), True)

with self.assertRaises(ValueError) as e:
instance.tracker.has_changed("second")
self.assertEqual(
str(e.exception),
"second is not tracked",
)
2 changes: 1 addition & 1 deletion tracking_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .mixins import TrackingModelMixin
from .mixins import TrackingModelMixin, Tracker
35 changes: 20 additions & 15 deletions tracking_model/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def __init__(self, instance):


class TrackingModelMixin(object):

TRACKED_FIELDS = None
TRACKER_CLASS = Tracker

def __init__(self, *args, **kwargs):
super(TrackingModelMixin, self).__init__(*args, **kwargs)
Expand All @@ -22,12 +22,18 @@ def tracker(self):
if hasattr(self._state, "_tracker"):
tracker = self._state._tracker
else:
# validate possibility of changing tracker class
if not issubclass(self.TRACKER_CLASS, Tracker):
raise TypeError("TRACKER_CLASS must be a subclass of Tracker.")

# populate tracked fields for the first time
# by default all fields
if not self.TRACKED_FIELDS:
instance_class = type(self)
instance_class.TRACKED_FIELDS = {f.attname for f in instance_class._meta.concrete_fields}
tracker = self._state._tracker = Tracker(self)
instance_class.TRACKED_FIELDS = {
f.attname for f in instance_class._meta.concrete_fields
}
tracker = self._state._tracker = self.TRACKER_CLASS(self)
return tracker

def save(
Expand All @@ -45,17 +51,16 @@ def save(
self.tracker.changed = {}

def __setattr__(self, name, value):
if hasattr(self, "_initialized"):
if name in self.tracker.tracked_fields:
if name not in self.tracker.changed:
if name in self.__dict__:
old_value = getattr(self, name)
if value != old_value:
self.tracker.changed[name] = old_value
else:
self.tracker.changed[name] = DeferredAttribute
else:
if value == self.tracker.changed[name]:
self.tracker.changed.pop(name)
if hasattr(self, "_initialized") and name in self.tracker.tracked_fields:
if name in self.tracker.changed:
if value == self.tracker.changed[name]:
self.tracker.changed.pop(name)

elif name in self.__dict__:
old_value = getattr(self, name)
if value != old_value:
self.tracker.changed[name] = old_value
else:
self.tracker.changed[name] = DeferredAttribute

super(TrackingModelMixin, self).__setattr__(name, value)
Loading