From 5c28625f506f8c0f06533d8cba8fc8a879fad606 Mon Sep 17 00:00:00 2001 From: Cameron Hurst Date: Sun, 26 Apr 2020 17:55:25 -0400 Subject: [PATCH] fix: getting versioned info works properly not with __table_args__ --- postgresql_audit/base.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index 1976bac..fdb47eb 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,4 +1,5 @@ import warnings +from collections import Sequence from contextlib import contextmanager from functools import partial from weakref import WeakSet @@ -207,9 +208,10 @@ def modified_columns(self, obj): def is_modified(self, obj_or_session): if hasattr(obj_or_session, '__mapper__'): - if not (hasattr(obj_or_session, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None)): + version_info = self.__get_versioned_info(obj_or_session) + if not version_info: raise ClassNotVersioned(obj_or_session.__class__.__name__) - excluded = getattr(obj_or_session, "__versioned__", obj_or_session.__table_args__["versioned"]).get('exclude', []) + excluded = version_info.get('exclude', []) return bool( set([ column.name @@ -220,9 +222,22 @@ def is_modified(self, obj_or_session): return any( self.is_modified(entity) or entity in obj_or_session.deleted for entity in obj_or_session - if hasattr(entity, '__versioned__') or getattr(obj_or_session, '__table_args__', {}).get("versioned", None) + if self.__get_versioned_info(entity) ) + def __get_versioned_info(self, entity): + v_args = getattr(entity, '__versioned__', None) + if v_args: + return v_args + table_args = getattr(entity, '__table_args__', None) + if not table_args: + return None + if isinstance(table_args, Sequence): + table_args = next((x for x in iter(table_args) if isinstance(x, dict)), None) + if not table_args: + return None + return table_args.get("info", {}).get("versioned", None) + def before_flush(self, session, flush_context, instances): if session.transaction in self._marked_transactions: return