Skip to content

Commit

Permalink
NEW: Make cached decorator time_to_live also accept a callable
Browse files Browse the repository at this point in the history
The callable needs to be of the same `color` (sync/async) as the cached decorator.
This was done to be consistent with the cache instance, but could be changed
together when addressing issue #16

closes #30
  • Loading branch information
Sergio Castillo authored and eigenein committed Jun 14, 2023
1 parent d19e4f3 commit 8931153
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 6 deletions.
10 changes: 7 additions & 3 deletions cachetory/decorators/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def cached(
cache: Union[Cache[ValueT, WireT], Callable[..., Awaitable[Cache[ValueT, WireT]]]], # no way to use `P` here
*,
make_key: Callable[..., str] = shared.make_default_key, # no way to use `P` here
time_to_live: Optional[timedelta] = None,
time_to_live: Optional[Union[timedelta, Callable[..., Awaitable[timedelta]]]] = None,
if_not_exists: bool = False,
) -> Callable[[Callable[P, Awaitable[ValueT]]], Callable[P, Awaitable[ValueT]]]:
"""
Expand All @@ -32,18 +32,22 @@ def cached(
and the rest of the arguments next to it.
make_key: callable to generate a custom cache key per each call.
if_not_exists: controls concurrent sets: if `True` – avoids overwriting a cached value.
time_to_live: cached value expiration time.
time_to_live:
cached value expiration time or async callable that returns the expiration time.
The callable needs to accept keyword arguments, and it is given the cache key to
compute the expiration time.
"""

def wrap(callable_: Callable[P, Awaitable[ValueT]]) -> Callable[P, Awaitable[ValueT]]:
@wraps(callable_)
async def cached_callable(*args: P.args, **kwargs: P.kwargs) -> ValueT:
cache_ = await cache(callable_, *args, **kwargs) if callable(cache) else cache
key_ = make_key(callable_, *args, **kwargs)
time_to_live_ = await time_to_live(key=key_) if callable(time_to_live) else time_to_live
value = await cache_.get(key_)
if value is None:
value = await callable_(*args, **kwargs)
await cache_.set(key_, value, time_to_live=time_to_live, if_not_exists=if_not_exists)
await cache_.set(key_, value, time_to_live=time_to_live_, if_not_exists=if_not_exists)
return value

return cached_callable
Expand Down
10 changes: 7 additions & 3 deletions cachetory/decorators/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def cached(
cache: Union[Cache[ValueT, WireT], Callable[..., Cache[ValueT, WireT]]], # no way to use `P` here
*,
make_key: Callable[..., str] = shared.make_default_key, # no way to use `P` here
time_to_live: Optional[timedelta] = None,
time_to_live: Optional[Union[timedelta, Callable[..., timedelta]]] = None,
if_not_exists: bool = False,
) -> Callable[[Callable[P, ValueT]], Callable[P, ValueT]]:
"""
Expand All @@ -32,19 +32,23 @@ def cached(
and the rest of the arguments next to it.
make_key: callable to generate a custom cache key per each call.
if_not_exists: controls concurrent sets: if `True` – avoids overwriting a cached value.
time_to_live: cached value expiration time.
time_to_live:
cached value expiration time or callable that returns the expiration time.
The callable needs to accept keyword arguments, and it is given the cache key to
compute the expiration time.
"""

def wrap(callable_: Callable[P, ValueT]) -> Callable[P, ValueT]:
@wraps(callable_)
def cached_callable(*args: P.args, **kwargs: P.kwargs) -> ValueT:
cache_ = cache(callable_, *args, **kwargs) if callable(cache) else cache
key_ = make_key(callable_, *args, **kwargs)
time_to_live_ = time_to_live(key=key_) if callable(time_to_live) else time_to_live
try:
value = cache_[key_]
except KeyError:
value = callable_(*args, **kwargs)
cache_.set(key_, value, time_to_live=time_to_live, if_not_exists=if_not_exists)
cache_.set(key_, value, time_to_live=time_to_live_, if_not_exists=if_not_exists)
return value

return cached_callable
Expand Down
37 changes: 37 additions & 0 deletions tests/decorators/test_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import timedelta
from typing import Any
from unittest import mock

from pytest import fixture, mark

Expand Down Expand Up @@ -57,3 +59,38 @@ async def expensive_function(_: int) -> int:

assert cache._backend.size == 1 # type: ignore
assert cache_2._backend.size == 1 # type: ignore


async def test_time_to_live_accepts_callable(cache: Cache[int, int]):
expected_time_to_live = timedelta(seconds=42)

async def ttl(*args: Any, **kwargs: Any) -> timedelta:
return expected_time_to_live

@cached(cache, time_to_live=ttl)
async def expensive_function() -> int:
return 1

with mock.patch.object(cache, "set", wraps=cache.set) as m_set:
assert await expensive_function() == 1

# time_to_live is correctly forwarded to cache
m_set.assert_called_with(mock.ANY, mock.ANY, time_to_live=expected_time_to_live, if_not_exists=mock.ANY)


async def test_time_to_live_callable_depending_on_key(cache: Cache[int, int]):
"""time_to_live accepts the key as a keyword argument, allowing for different expirations."""

async def ttl(key: str) -> timedelta:
if "a=a" in key:
return timedelta(seconds=42)
return timedelta(seconds=1)

@cached(cache, time_to_live=ttl)
async def expensive_function(**kwargs: Any) -> int:
return 1

with mock.patch.object(cache, "set", wraps=cache.set) as m_set:
assert await expensive_function(a="a") == 1

m_set.assert_called_with(mock.ANY, mock.ANY, time_to_live=timedelta(seconds=42), if_not_exists=mock.ANY)
39 changes: 39 additions & 0 deletions tests/decorators/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from datetime import timedelta
from typing import Any
from unittest import mock

from pytest import fixture

from cachetory.backends.sync import MemoryBackend
Expand Down Expand Up @@ -50,3 +54,38 @@ def expensive_function(_: int) -> int:

assert cache._backend.size == 1 # type: ignore
assert cache_2._backend.size == 1 # type: ignore


def test_time_to_live_accepts_callable(cache: Cache[int, int]):
expected_time_to_live = timedelta(seconds=42)

def ttl(*args: Any, **kwargs: Any) -> timedelta:
return expected_time_to_live

@cached(cache, time_to_live=ttl)
def expensive_function() -> int:
return 1

with mock.patch.object(cache, "set", wraps=cache.set) as m_set:
assert expensive_function() == 1

# time_to_live is correctly forwarded to cache
m_set.assert_called_with(mock.ANY, mock.ANY, time_to_live=expected_time_to_live, if_not_exists=mock.ANY)


def test_time_to_live_callable_depending_on_key(cache: Cache[int, int]):
"""time_to_live accepts the key as a keyword argument, allowing for different expirations."""

def ttl(key: str) -> timedelta:
if "a=a" in key:
return timedelta(seconds=42)
return timedelta(seconds=1)

@cached(cache, time_to_live=ttl)
def expensive_function(**kwargs: Any) -> int:
return 1

with mock.patch.object(cache, "set", wraps=cache.set) as m_set:
assert expensive_function(a="a") == 1

m_set.assert_called_with(mock.ANY, mock.ANY, time_to_live=timedelta(seconds=42), if_not_exists=mock.ANY)

0 comments on commit 8931153

Please sign in to comment.