diff --git a/limpyd/database.py b/limpyd/database.py index 9ea8f1d..c341925 100644 --- a/limpyd/database.py +++ b/limpyd/database.py @@ -15,6 +15,46 @@ ) +class Command(object): + """ + Object to pass the command through middlewares + """ + __slots__ = ('name', 'args', 'kwargs',) + + def __init__(self, name, *args, **kwargs): + self.name = name + self.args = args + self.kwargs = kwargs + + def __unicode__(self): + return u"Command(name='%s', args=%s, kwargs=%s)" % (self.name, self.args, self.kwargs) + + def __str__(self): + return unicode(self).encode('utf-8') + + def __repr__(self): + return str(self) + + +class Result(object): + """ + Object to pass the command's result through middlewares + """ + __slots__ = ('value',) + + def __init__(self, value): + self.value = value + + def __unicode__(self): + return u"Result(value=%s)" % self.value + + def __str__(self): + return unicode(self).encode('utf-8') + + def __repr__(self): + return str(self) + + class RedisDatabase(object): """ A RedisDatabase regroups some models and handles the connection to Redis for @@ -27,12 +67,18 @@ class RedisDatabase(object): """ _connections = {} # class level cache discard_cache = False + middlewares = [] - def __init__(self, **connection_settings): + def __init__(self, middlewares=None, **connection_settings): self._connection = None # Instance level cache self.reset(**(connection_settings or DEFAULT_CONNECTION_SETTINGS)) + # _models keep an entry for each defined model on this database self._models = dict() + + if middlewares is not None: + self.middlewares = middlewares + super(RedisDatabase, self).__init__() def connect(self, **settings): @@ -126,3 +172,64 @@ def has_scripting(self): except: self._has_scripting = False return self._has_scripting + + @property + def prepared_middlewares(self): + """ + Load, cache and return the list of usable middlewares, as a dict with + an entry for each usable method. + { + 'pre_command': [list, of, middlewares], + 'post_command': [list, of, middlewares], + } + Middlewares must be defined while declaring the database: + database = RedisDatabase(middlewares=[ + AMiddleware(), + AnoterMiddleware(some, parameter) + ], **connection_settings) + """ + + if not hasattr(self, '_prepared_middlewares'): + + self._prepared_middlewares = { + 'pre_command': [], + 'post_command': [], + } + + for middleware in self.middlewares: + middleware.database = self + + for middleware_type in self._prepared_middlewares: + if hasattr(middleware, middleware_type): + self._prepared_middlewares[middleware_type].append(middleware) + + self._prepared_middlewares['post_command'] = self._prepared_middlewares['post_command'][::-1] + + return self._prepared_middlewares + + def run_command(self, command, context=None): + """ + Run a redis command, passing it through all defined middlewares. + The command must be a Command namedtuple + """ + if context is None: + context = {} + + result = None + + for middleware in self.prepared_middlewares['pre_command']: + result = middleware.pre_command(command, context) + if result: + break + + if result is None: + method = getattr(self.connection, "%s" % command.name) + result = method(*command.args, **command.kwargs) + + if not isinstance(result, Result): + result = Result(result) + + for middleware in self.prepared_middlewares['post_command']: + result = middleware.post_command(command, result, context) + + return result.value diff --git a/limpyd/fields.py b/limpyd/fields.py index f7f69fe..4564205 100644 --- a/limpyd/fields.py +++ b/limpyd/fields.py @@ -7,6 +7,7 @@ from redis.client import Lock from limpyd.utils import make_key, memoize_command +from limpyd.database import Command from limpyd.exceptions import * log = getLogger(__name__) @@ -125,10 +126,11 @@ def _traverse_command(self, name, *args, **kwargs): # TODO: implement instance level cache if not name in self.available_commands: raise AttributeError("%s is not an available command for %s" % (name, self.__class__.__name__)) - attr = getattr(self.connection, "%s" % name) - key = self.key - log.debug(u"Requesting %s with key %s and args %s" % (name, key, args)) - result = attr(key, *args, **kwargs) + + command = Command(name, self.key, *args, **kwargs) + context = {'sender': self, } + result = self.database.run_command(command, context) + result = self.post_command( sender=self, name=name, @@ -230,12 +232,18 @@ def __init__(self, *args, **kwargs): self._creation_order = RedisField._creation_order RedisField._creation_order += 1 - def proxy_get(self): + def proxy_get(self, _direct=False): """ - A helper to easily call the proxy_getter of the field + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis """ - getter = getattr(self, self.proxy_getter) - return getter() + if _direct: + getter = getattr(self.connection, self.proxy_getter) + return getter(self.key) + else: + getter = getattr(self, self.proxy_getter) + return getter() def proxy_set(self, value): """ @@ -464,7 +472,7 @@ def index(self, value=None): """ if self.indexable: if value is None: - value = self.proxy_get() + value = self.proxy_get(_direct=True) key = self.index_key(value) self.add_index(key) @@ -478,7 +486,7 @@ def deindex(self, value=None): """ if self.indexable: if value is None: - value = self.proxy_get() + value = self.proxy_get(_direct=True) key = self.index_key(value) self.remove_index(key) @@ -546,7 +554,7 @@ def _call_set(self, command, value, *args, **kwargs): Helper for commands that only set a value to the field. """ if self.indexable: - current = self.proxy_get() + current = self.proxy_get(_direct=True) if current != value: self.deindex(current) self.index(value) @@ -630,7 +638,7 @@ def index(self, values=None): """ if self.indexable: if values is None: - values = self.proxy_get() + values = self.proxy_get(_direct=True) for value in values: key = self.index_key(value) self.add_index(key) @@ -641,7 +649,7 @@ def deindex(self, values=None): """ if self.indexable: if not values: - values = self.proxy_get() + values = self.proxy_get(_direct=True) for value in values: key = self.index_key(value) self.remove_index(key) @@ -675,6 +683,17 @@ def zmembers(self): """ return self.zrange(0, -1) + def proxy_get(self, _direct=False): + """ + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis + """ + if _direct: + return self.connection.zrange(self.key, 0, -1) + else: + return super(SortedSetField, self).proxy_get() + def _call_zadd(self, command, *args, **kwargs): """ We do the same computation of the zadd method of StrictRedis to keep keys @@ -788,6 +807,17 @@ def lmembers(self): """ return self.lrange(0, -1) + def proxy_get(self, _direct=False): + """ + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis + """ + if _direct: + return self.connection.lrange(self.key, 0, -1) + else: + return super(ListField, self).proxy_get() + def _pushx(self, command, *args, **kwargs): """ Helper for lpushx and rpushx, that only index the new values if the list @@ -845,7 +875,7 @@ class HashField(MultiValuesField): def _call_hmset(self, command, *args, **kwargs): if self.indexable: - current = self.proxy_get() + current = self.proxy_get(_direct=True) _to_deindex = dict((k, current[k]) for k in kwargs.iterkeys() if k in current) self.deindex(_to_deindex) self.index(kwargs) @@ -853,7 +883,7 @@ def _call_hmset(self, command, *args, **kwargs): def _call_hset(self, command, key, value): if self.indexable: - current = self.proxy_get() + current = self.proxy_get(_direct=True) if value != current.get(key, None): if key in current: self.deindex({key: current[key]}) @@ -862,7 +892,7 @@ def _call_hset(self, command, key, value): def _call_hincrby(self, command, key, amount): if self.indexable: - current = self.proxy_get() + current = self.proxy_get(_direct=True) if key in current: self.deindex({key: current[key]}) result = self._traverse_command(command, key, amount) @@ -872,7 +902,7 @@ def _call_hincrby(self, command, key, amount): def _call_hdel(self, command, *args): if self.indexable: - current = self.proxy_get() + current = self.proxy_get(_direct=True) self.deindex(dict((k, current[k]) for k in args if k in current)) return self._traverse_command(command, *args) @@ -908,7 +938,7 @@ def index(self, values=None): """ if self.indexable: if values is None: - values = self.proxy_get() + values = self.proxy_get(_direct=True) for field_name, value in values.iteritems(): key = self.index_key(value, field_name) self.add_index(key) @@ -919,7 +949,7 @@ def deindex(self, values=None): """ if self.indexable: if values is None: - values = self.proxy_get() + values = self.proxy_get(_direct=True) for field_name, value in values.iteritems(): key = self.index_key(value, field_name) self.remove_index(key) @@ -938,6 +968,17 @@ class InstanceHashField(SingleValueField): _call_hset = SingleValueField._call_set _call_hdel = RedisField._del + def proxy_get(self, _direct=False): + """ + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis + """ + if _direct: + return self.connection.hget(self.key, self.name) + else: + return super(InstanceHashField, self).proxy_get() + @property def key(self): return self._instance.key diff --git a/limpyd/middlewares.py b/limpyd/middlewares.py new file mode 100644 index 0000000..9acf1e5 --- /dev/null +++ b/limpyd/middlewares.py @@ -0,0 +1,66 @@ +# -*- coding:utf-8 -*- + +from time import time + +from limpyd.exceptions import ImplementationError + + +class BaseMiddleware(object): + @property + def database(self): + return self._database + + @database.setter + def database(self, value): + if hasattr(self, '_database'): + raise ImplementationError("Cannot change the database of a middleware") + self._database = value + + # minimal pre_command method: do nothing and return None + # def pre_command(self, command, context): + # pass + + # minimal post_command method: return the given result + # def post_command(self, command, result, context): + # return result + + +class LoggingMiddleware(BaseMiddleware): + """ + Middleware that takes a logger, and log commands and their result (and time + to run the command). + """ + def __init__(self, logger, log_results=True, log_time=True): + """ + The logger must be a defined and correctly initialized one (via logging) + The log_results flag indicates if only the commands or also their result + (with duration, if log_time is True) are logged. + """ + self.logger = logger + self.log_results = log_results + self.log_time = log_time + super(LoggingMiddleware, self).__init__() + + @BaseMiddleware.database.setter + def database(self, value): + BaseMiddleware.database.fset(self, value) # super + self.database._command_logger_counter = 0 + + def pre_command(self, command, context): + self.database._command_logger_counter += 1 + context['_command_number'] = self.database._command_logger_counter + if self.log_time: + context['_start_time'] = time() + self.logger.info(u'[#%s] %s' % (context['_command_number'], str(command))) + + def post_command(self, command, result, context): + if self.log_results: + log_str = u'[#%s] %s' + log_params = [context['_command_number'], str(result)] + if self.log_time: + log_str = u'[#%s, in %0.0fµs] %s' + duration = (time() - context['_start_time']) * 1000000 + log_params.insert(1, duration) + + self.logger.info(log_str % tuple(log_params)) + return result diff --git a/run_tests.py b/run_tests.py index a861308..f6abe81 100644 --- a/run_tests.py +++ b/run_tests.py @@ -4,7 +4,7 @@ import argparse # FIXME: move tests in limpyd module, to prevent a relative import? -from tests import base, model, utils, collection, lock, fields +from tests import base, model, utils, collection, lock, fields, middlewares from tests.contrib import database, related, collection as contrib_collection @@ -37,7 +37,7 @@ else: # Run all the tests suites = [] - default_mods = [base, model, utils, collection, lock, fields, ] + default_mods = [base, model, utils, collection, lock, fields, middlewares, ] contrib_mods = [database, related, contrib_collection] for mod in default_mods + contrib_mods: suite = unittest.TestLoader().loadTestsFromModule(mod) diff --git a/tests/middlewares.py b/tests/middlewares.py new file mode 100644 index 0000000..aff060b --- /dev/null +++ b/tests/middlewares.py @@ -0,0 +1,149 @@ +# -*- coding:utf-8 -*- + +import unittest +import logging +from StringIO import StringIO + +from limpyd.middlewares import BaseMiddleware, LoggingMiddleware +from limpyd.database import RedisDatabase +from limpyd import model +from limpyd import fields + +from base import LimpydBaseTest, TEST_CONNECTION_SETTINGS + + +class ForceSetterMiddleware(BaseMiddleware): + """ + A test middleware that always save the same given value for all "set" calls + """ + def __init__(self, value): + super(ForceSetterMiddleware, self).__init__() + self.value = value + + def pre_command(self, command, context): + if command.name == 'hset': + command.kwargs = {} + command.args = (command.args[0], command.args[1], self.value) + + +class ForceGetterMiddleware(BaseMiddleware): + """ + A test middleware that always returns the same given value for all "get" calls + """ + def __init__(self, value): + super(ForceGetterMiddleware, self).__init__() + self.value = value + + def post_command(self, command, result, context): + if command.name == 'hget': + result.value = self.value + return result + + +class BaseTestModel(model.RedisModel): + abstract = True + cacheable = False + foo = fields.InstanceHashField() + + +class MiddlewareTest(LimpydBaseTest): + def test_middleware_pre_command_method_should_be_called(self): + test_database = RedisDatabase(middlewares=[ + ForceSetterMiddleware(value='BAZ'), + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_middleware_pre_command_method_should_be_called' + + instance = TestModel(foo='bar') + self.assertEqual(instance.foo.hget(), 'BAZ') + + def test_middleware_post_command_method_should_be_called(self): + test_database = RedisDatabase(middlewares=[ + ForceGetterMiddleware(value='QUX'), + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_middleware_post_command_method_should_be_called' + + instance = TestModel(foo='bar') + + # the middleware will send "QUX" + self.assertEqual(instance.foo.hget(), 'QUX') + + # but for untouched command, we got the real values + self.assertEqual(instance.hmget('foo'), ['bar']) + + def test_database_can_accept_many_middlewares(self): + test_database = RedisDatabase(middlewares=[ + ForceSetterMiddleware(value='BAZ'), + ForceGetterMiddleware(value='QUX'), + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_database_can_accept_many_middlewares' + + instance = TestModel(foo='bar') + + # the getter middleware will send "QUX" + self.assertEqual(instance.foo.hget(), 'QUX') + + # but for untouched command, we got the value set by the setter middleware + self.assertEqual(instance.hmget('foo'), ['BAZ']) + + def test_logging_middleware(self): + + logger = logging.getLogger('limpyd.tests.middlewares.test_logging_middleware') + stream = StringIO() + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler(stream)) + + test_database = RedisDatabase(middlewares=[ + LoggingMiddleware(logger, log_time=False) + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_logging_middleware' + + instance = TestModel(foo='bar') + self.assertEqual(instance.foo.hget(), 'bar') + + log_lines = [line for line in stream.getvalue().split('\n') if line] + self.assertEqual(len(log_lines), 4) + self.assertEqual(log_lines[0], u"[#1] Command(name='hset', args=(u'test_logging_middleware:testmodel:1:hash', 'foo', 'bar'), kwargs={})") + self.assertEqual(log_lines[1], u"[#1] Result(value=1)") + self.assertEqual(log_lines[2], u"[#2] Command(name='hget', args=(u'test_logging_middleware:testmodel:1:hash', 'foo'), kwargs={})") + self.assertEqual(log_lines[3], u"[#2] Result(value=bar)") + + def test_logging_middleware_with_another(self): + + logger = logging.getLogger('limpyd.tests.middlewares.test_logging_middleware_with_another') + stream = StringIO() + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler(stream)) + + test_database = RedisDatabase(middlewares=[ + ForceSetterMiddleware(value='BAZ'), + LoggingMiddleware(logger, log_time=False) + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_logging_middleware_with_another' + + instance = TestModel(foo='bar') + self.assertEqual(instance.foo.hget(), 'BAZ') + + log_lines = [line for line in stream.getvalue().split('\n') if line] + self.assertEqual(len(log_lines), 4) + self.assertEqual(log_lines[0], u"[#1] Command(name='hset', args=(u'test_logging_middleware_with_another:testmodel:1:hash', 'foo', 'BAZ'), kwargs={})") + self.assertEqual(log_lines[1], u"[#1] Result(value=1)") + self.assertEqual(log_lines[2], u"[#2] Command(name='hget', args=(u'test_logging_middleware_with_another:testmodel:1:hash', 'foo'), kwargs={})") + self.assertEqual(log_lines[3], u"[#2] Result(value=BAZ)") + +if __name__ == '__main__': + unittest.main()