Skip to content

Commit

Permalink
Merge pull request #25 from tumb1er/track_refresh_from_db
Browse files Browse the repository at this point in the history
closes #21 update initial version in refresh_from_db
  • Loading branch information
tumb1er authored Feb 25, 2019
2 parents 8f5cca5 + 9f166fd commit b6baa00
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 8 deletions.
39 changes: 33 additions & 6 deletions denormalized/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from django.db import models
from django.db.models.fields import related_descriptors
from django.db.models.signals import post_save, post_delete, post_init, class_prepared
from django.db.models import signals
from django.utils.functional import cached_property

from denormalized.tracker import PREVIOUS_VERSION_FIELD, DenormalizedTracker
Expand Down Expand Up @@ -73,18 +73,21 @@ def update_object(obj, **updates):
def contribute_to_class(self, cls, name, private_only=False, **kwargs):
super().contribute_to_class(cls, name, private_only, **kwargs)
suffix = f':{cls.__name__}:{name}'
post_init.connect(
signals.post_init.connect(
self._track_previous_version, sender=cls,
dispatch_uid=f'denormalized_track_previous:{suffix}')
post_save.connect(
signals.post_save.connect(
self._track_changes, sender=cls,
dispatch_uid=f'denormalized_update_value_on_save:{suffix}')
post_delete.connect(
signals.post_delete.connect(
self._track_changes, sender=cls,
dispatch_uid=f'denormalized_update_value_on_delete:{suffix}')
class_prepared.connect(
signals.class_prepared.connect(
self._wrap_save, sender=cls,
dispatch_uid=f'denormalized_wrap_save:{suffix}')
signals.class_prepared.connect(
self._wrap_refresh_from_db, sender=cls,
dispatch_uid=f'denormalized_wrap_refresh_from_db:{suffix}')
for tracker in self.trackers:
tracker.foreign_key = self.name

Expand All @@ -106,6 +109,30 @@ def wrapped(instance, *args, **kw):

sender.save = wrapped

# noinspection PyUnusedLocal
def _wrap_refresh_from_db(self, sender, **kwargs):
""" Wraps model refresh_from_db with initial state invalidation."""

if hasattr(sender.refresh_from_db, 'denormalized_wrapper'):
return
refresh_from_db = sender.refresh_from_db

@functools.wraps(refresh_from_db)
def wrapped(instance, *args, **kw):
""" Reset cached initial state after refresh_from_db call."""
refresh_from_db(instance, *args, **kw)
if 'fields' not in kw:
self.store_initial_state(instance)
else:
initial = getattr(instance, PREVIOUS_VERSION_FIELD)
for field in kw['fields']:
value = getattr(instance, field)
setattr(initial, field, value)

wrapped.denormalized_wrapper = True

sender.refresh_from_db = wrapped

# noinspection PyUnusedLocal
def _track_previous_version(self, sender=None, instance=None, **kwargs):
if self.__in_init:
Expand All @@ -119,7 +146,7 @@ def _track_previous_version(self, sender=None, instance=None, **kwargs):
# noinspection PyUnusedLocal
def _track_changes(self, sender=None, instance=None, signal=None,
created=None, **kwargs):
deleted = signal is post_delete
deleted = signal is signals.post_delete

changed: Dict[models.Model, IncrementalUpdates] = defaultdict(dict)

Expand Down
4 changes: 2 additions & 2 deletions denormalized/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ def _get_value_from_instance(self, instance: models.Model) -> Any:
arg = self.aggregate.source_expressions[0]
value = getattr(instance, arg.name)
if isinstance(value, expressions.Expression):
instance.refresh_from_db(fields=(arg.name,))
value = getattr(instance, arg.name)
value = type(instance).objects.filter(pk=instance.pk).values_list(
arg.name, flat=True).get()
return value

def _get_full_aggregate(self,
Expand Down
33 changes: 33 additions & 0 deletions testproject/testapp/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db.models import F, Sum, Min, QuerySet, Aggregate, Q, Count, Max
from django.test import TestCase

from denormalized import tracker
from testproject.testapp import models


Expand Down Expand Up @@ -52,6 +53,11 @@ def assertPointsSum(self, obj):
Sum('points'))['points__sum']
self.assertEqual(obj.points_sum, value)

def assertInitialState(self, **fields):
initial = getattr(self.member, tracker.PREVIOUS_VERSION_FIELD)
for k, v in fields.items():
self.assertEqual(getattr(initial, k), v)

def test_track_multiple_foreign_keys(self):
""" Multiple foreign keys tracked correctly."""
team = models.Team.objects.create()
Expand Down Expand Up @@ -82,6 +88,33 @@ def test_not_tracking_non_suitable(self):

delta_mock.assert_not_called()

def test_refresh_from_db_fields(self):
""" after refresh_from_db initial field value is updated."""
models.Member.objects.filter(pk=self.member.pk).update(active=False)
self.member.refresh_from_db(fields=('group',))

# refreshing unrelated fields does not affect initial values
self.assertInitialState(active=True)

# initial value remains same after in-memory change
self.member.active = False
self.assertInitialState(active=True)

# initial value is updated after refreshing field from db
self.member.refresh_from_db(fields=('active',))

self.assertInitialState(active=False)

def test_refresh_from_db(self):
""" after refresh_from_db all initial field values updated."""
models.Member.objects.filter(pk=self.member.pk).update(active=False)
self.assertInitialState(active=True)

self.member.refresh_from_db()

self.assertInitialState(active=False)



class CountTestCase(DenormalizedTrackerTestCaseBase):
field_name = 'members_count'
Expand Down

0 comments on commit b6baa00

Please sign in to comment.