Commit 9d8db47

Make `S3Storage` check staleness of all cache files at a set interval. (#182)

* Make `S3Storage` check staleness of all cache files at a set interval.

* Change description for `test_s3storage_key_error`

* Run checks.

* Remove unnecessary try/except

* Update CHANGELOG.md

* Increase `ttl` for `test_s3storage_expired`

* Update storages.md

Add `check_ttl_every` example for `S3Storage`

---------

Co-authored-by: Kar Petrosyan <[email protected]>
umnovI and karpetrosyan authored Feb 12, 2024
1 parent 6fcc25b commit 9d8db47
Showing 7 changed files with 149 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

## Unreleased

- Make `S3Storage` check staleness of all cache files at a set interval. (#182)
- Fix an issue where an empty file in `FileCache` could cause a parsing error. (#181)
- Support caching for `POST` and other HTTP methods. (#183)

10 changes: 10 additions & 0 deletions docs/advanced/storages.md
@@ -284,5 +284,15 @@ storage = hishel.S3Storage(ttl=3600)
If you do this, `Hishel` will delete any stored responses whose ttl has expired.
In this example, the stored responses were limited to 1 hour.

#### Check ttl every

To avoid excessive storage use, `Hishel` must periodically clean out old responses: entries whose `ttl` has expired and should be deleted from the cache.
By default the cache is checked every minute, but you can change the interval with the `check_ttl_every` argument.

Example:

```python
import hishel

storage = hishel.S3Storage(check_ttl_every=600) # check every 600s (10m)
```
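
Since `check_ttl_every` takes effect only when `ttl` is set, the two arguments are typically combined. A minimal sketch:

```python
import hishel

# Responses expire after one hour; staleness of all cache files is
# checked at most once every 600 seconds (10 minutes).
storage = hishel.S3Storage(ttl=3600, check_ttl_every=600)
```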
19 changes: 14 additions & 5 deletions hishel/_async/_storages.py
@@ -436,6 +436,9 @@ class AsyncS3Storage(AsyncBaseStorage):
:type serializer: tp.Optional[BaseSerializer], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
:param check_ttl_every: How often, in seconds, to check the staleness of **all** cache files.
Takes effect only when `ttl` is set; defaults to 60
:type check_ttl_every: tp.Union[int, float]
:param client: A client for S3, defaults to None
:type client: tp.Optional[tp.Any], optional
"""
@@ -445,6 +448,7 @@ def __init__(
bucket_name: str,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
check_ttl_every: tp.Union[int, float] = 60,
client: tp.Optional[tp.Any] = None,
) -> None:
super().__init__(serializer, ttl)
@@ -460,7 +464,12 @@ def __init__(

self._bucket_name = bucket_name
client = client or boto3.client("s3")
self._s3_manager = AsyncS3Manager(client=client, bucket_name=bucket_name, is_binary=self._serializer.is_binary)
self._s3_manager = AsyncS3Manager(
client=client,
bucket_name=bucket_name,
is_binary=self._serializer.is_binary,
check_ttl_every=check_ttl_every,
)
self._lock = AsyncLock()

async def store(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
@@ -481,7 +490,7 @@ async def store(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
await self._s3_manager.write_to(path=key, data=serialized)

await self._remove_expired_caches()
await self._remove_expired_caches(key)

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
@@ -493,7 +502,7 @@ async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
:rtype: tp.Optional[StoredResponse]
"""

await self._remove_expired_caches()
await self._remove_expired_caches(key)
async with self._lock:
try:
return self._serializer.loads(await self._s3_manager.read_from(path=key))
@@ -503,10 +512,10 @@ async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def aclose(self) -> None: # pragma: no cover
return

async def _remove_expired_caches(self) -> None:
async def _remove_expired_caches(self, key: str) -> None:
if self._ttl is None:
return

async with self._lock:
converted_ttl = float_seconds_to_int_milliseconds(self._ttl)
await self._s3_manager.remove_expired(ttl=converted_ttl)
await self._s3_manager.remove_expired(ttl=converted_ttl, key=key)
34 changes: 28 additions & 6 deletions hishel/_s3.py
@@ -1,14 +1,20 @@
import time
import typing as tp
from datetime import datetime, timedelta, timezone

from anyio import to_thread
from botocore.exceptions import ClientError


class S3Manager:
def __init__(self, client: tp.Any, bucket_name: str, is_binary: bool = False):
def __init__(
self, client: tp.Any, bucket_name: str, check_ttl_every: tp.Union[int, float], is_binary: bool = False
):
self._client = client
self._bucket_name = bucket_name
self._is_binary = is_binary
self._last_cleaned = time.monotonic()
self._check_ttl_every = check_ttl_every

def write_to(self, path: str, data: tp.Union[bytes, str]) -> None:
path = "hishel-" + path
@@ -31,7 +37,21 @@ def read_from(self, path: str) -> tp.Union[bytes, str]:

return tp.cast(str, content.decode("utf-8"))

def remove_expired(self, ttl: int) -> None:
def remove_expired(self, ttl: int, key: str) -> None:
path = "hishel-" + key

if time.monotonic() - self._last_cleaned < self._check_ttl_every:
try:
response = self._client.get_object(Bucket=self._bucket_name, Key=path)
if datetime.now(timezone.utc) - response["LastModified"] > timedelta(milliseconds=ttl):
self._client.delete_object(Bucket=self._bucket_name, Key=path)
return
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
return
raise e

self._last_cleaned = time.monotonic()
for obj in self._client.list_objects(Bucket=self._bucket_name).get("Contents", []):
if not obj["Key"].startswith("hishel-"): # pragma: no cover
continue
@@ -41,14 +61,16 @@ def remove_expired(self, ttl: int) -> None:


class AsyncS3Manager:
def __init__(self, client: tp.Any, bucket_name: str, is_binary: bool = False):
self._sync_manager = S3Manager(client, bucket_name, is_binary)
def __init__(
self, client: tp.Any, bucket_name: str, check_ttl_every: tp.Union[int, float], is_binary: bool = False
):
self._sync_manager = S3Manager(client, bucket_name, check_ttl_every, is_binary)

async def write_to(self, path: str, data: tp.Union[bytes, str]) -> None:
return await to_thread.run_sync(self._sync_manager.write_to, path, data)

async def read_from(self, path: str) -> tp.Union[bytes, str]:
return await to_thread.run_sync(self._sync_manager.read_from, path)

async def remove_expired(self, ttl: int) -> None:
return await to_thread.run_sync(self._sync_manager.remove_expired, ttl)
async def remove_expired(self, ttl: int, key: str) -> None:
return await to_thread.run_sync(self._sync_manager.remove_expired, ttl, key)
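
The throttling above works in two modes: between sweeps, `remove_expired` checks only the single key currently being stored or retrieved; once `check_ttl_every` seconds have passed since the last sweep, it scans every `hishel-`-prefixed object in the bucket. Below is a minimal, self-contained sketch of the same pattern, using an in-memory dict in place of S3; the class and helper names are illustrative, not part of `hishel`'s API.

```python
import time
import typing as tp


class ThrottledTTLCache:
    """Sketch of the per-key check / periodic full-sweep pattern."""

    def __init__(self, ttl: float, check_ttl_every: float = 60.0) -> None:
        self._ttl = ttl  # seconds an entry may live
        self._check_ttl_every = check_ttl_every  # seconds between full sweeps
        self._last_cleaned = time.monotonic()
        self._entries: tp.Dict[str, tp.Tuple[float, bytes]] = {}

    def _is_stale(self, stored_at: float) -> bool:
        return time.monotonic() - stored_at > self._ttl

    def remove_expired(self, key: str) -> None:
        if time.monotonic() - self._last_cleaned < self._check_ttl_every:
            # Cheap path: only the key being touched is checked.
            entry = self._entries.get(key)
            if entry is not None and self._is_stale(entry[0]):
                del self._entries[key]
            return
        # Interval elapsed: reset the timer and sweep every entry.
        self._last_cleaned = time.monotonic()
        for k in [k for k, (t, _) in self._entries.items() if self._is_stale(t)]:
            del self._entries[k]
```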
19 changes: 14 additions & 5 deletions hishel/_sync/_storages.py
@@ -436,6 +436,9 @@ class S3Storage(BaseStorage):
:type serializer: tp.Optional[BaseSerializer], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
:param check_ttl_every: How often, in seconds, to check the staleness of **all** cache files.
Takes effect only when `ttl` is set; defaults to 60
:type check_ttl_every: tp.Union[int, float]
:param client: A client for S3, defaults to None
:type client: tp.Optional[tp.Any], optional
"""
@@ -445,6 +448,7 @@ def __init__(
bucket_name: str,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
check_ttl_every: tp.Union[int, float] = 60,
client: tp.Optional[tp.Any] = None,
) -> None:
super().__init__(serializer, ttl)
@@ -460,7 +464,12 @@ def __init__(

self._bucket_name = bucket_name
client = client or boto3.client("s3")
self._s3_manager = S3Manager(client=client, bucket_name=bucket_name, is_binary=self._serializer.is_binary)
self._s3_manager = S3Manager(
client=client,
bucket_name=bucket_name,
is_binary=self._serializer.is_binary,
check_ttl_every=check_ttl_every,
)
self._lock = Lock()

def store(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
@@ -481,7 +490,7 @@ def store(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
self._s3_manager.write_to(path=key, data=serialized)

self._remove_expired_caches()
self._remove_expired_caches(key)

def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
@@ -493,7 +502,7 @@ def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
:rtype: tp.Optional[StoredResponse]
"""

self._remove_expired_caches()
self._remove_expired_caches(key)
with self._lock:
try:
return self._serializer.loads(self._s3_manager.read_from(path=key))
@@ -503,10 +512,10 @@ def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
def close(self) -> None: # pragma: no cover
return

def _remove_expired_caches(self) -> None:
def _remove_expired_caches(self, key: str) -> None:
if self._ttl is None:
return

with self._lock:
converted_ttl = float_seconds_to_int_milliseconds(self._ttl)
self._s3_manager.remove_expired(ttl=converted_ttl)
self._s3_manager.remove_expired(ttl=converted_ttl, key=key)
43 changes: 41 additions & 2 deletions tests/_async/test_storages.py
@@ -184,7 +184,7 @@ async def test_filestorage_expired(use_temp_dir):
@pytest.mark.asyncio
async def test_s3storage_expired(use_temp_dir, s3):
boto3.client("s3").create_bucket(Bucket="testBucket")
storage = AsyncS3Storage(bucket_name="testBucket", ttl=1)
storage = AsyncS3Storage(bucket_name="testBucket", ttl=3)

first_request = Request(b"GET", "https://example.com")
second_request = Request(b"GET", "https://anotherexample.com")
@@ -198,7 +198,7 @@ async def test_s3storage_expired(use_temp_dir, s3):
await storage.store(first_key, response=response, request=first_request, metadata=dummy_metadata)
assert await storage.retrieve(first_key) is not None

await asleep(1)
await asleep(3)
await storage.store(second_key, response=response, request=second_request, metadata=dummy_metadata)

assert await storage.retrieve(first_key) is None
@@ -316,3 +316,42 @@ async def test_filestorage_empty_file_exception(use_temp_dir):
file.truncate(0)
assert os.path.getsize(filedir) == 0
assert await storage.retrieve(key) is None


@pytest.mark.anyio
async def test_s3storage_timer(use_temp_dir, s3):
boto3.client("s3").create_bucket(Bucket="testBucket")
storage = AsyncS3Storage(bucket_name="testBucket", ttl=5, check_ttl_every=5)

first_request = Request(b"GET", "https://example.com")
second_request = Request(b"GET", "https://anotherexample.com")

first_key = generate_key(first_request)
second_key = generate_key(second_request)

response = Response(200, headers=[], content=b"test")
await response.aread()

await storage.store(first_key, response=response, request=first_request, metadata=dummy_metadata)
assert await storage.retrieve(first_key) is not None
await asleep(3)
assert await storage.retrieve(first_key) is not None
await storage.store(second_key, response=response, request=second_request, metadata=dummy_metadata)
assert await storage.retrieve(second_key) is not None
await asleep(2)
assert await storage.retrieve(first_key) is None
assert await storage.retrieve(second_key) is not None
await asleep(3)
assert await storage.retrieve(second_key) is None


@pytest.mark.anyio
async def test_s3storage_key_error(use_temp_dir, s3):
"""Triggers `NoSuchKey` error."""

boto3.client("s3").create_bucket(Bucket="testBucket")
storage = AsyncS3Storage(bucket_name="testBucket", ttl=60)
first_request = Request(b"GET", "https://example.com")
first_key = generate_key(first_request)

assert await storage.retrieve(first_key) is None
43 changes: 41 additions & 2 deletions tests/_sync/test_storages.py
@@ -184,7 +184,7 @@ def test_filestorage_expired(use_temp_dir):

def test_s3storage_expired(use_temp_dir, s3):
boto3.client("s3").create_bucket(Bucket="testBucket")
storage = S3Storage(bucket_name="testBucket", ttl=1)
storage = S3Storage(bucket_name="testBucket", ttl=3)

first_request = Request(b"GET", "https://example.com")
second_request = Request(b"GET", "https://anotherexample.com")
@@ -198,7 +198,7 @@ def test_s3storage_expired(use_temp_dir, s3):
storage.store(first_key, response=response, request=first_request, metadata=dummy_metadata)
assert storage.retrieve(first_key) is not None

sleep(1)
sleep(3)
storage.store(second_key, response=response, request=second_request, metadata=dummy_metadata)

assert storage.retrieve(first_key) is None
@@ -316,3 +316,42 @@ def test_filestorage_empty_file_exception(use_temp_dir):
file.truncate(0)
assert os.path.getsize(filedir) == 0
assert storage.retrieve(key) is None


def test_s3storage_timer(use_temp_dir, s3):
boto3.client("s3").create_bucket(Bucket="testBucket")
storage = S3Storage(bucket_name="testBucket", ttl=5, check_ttl_every=5)

first_request = Request(b"GET", "https://example.com")
second_request = Request(b"GET", "https://anotherexample.com")

first_key = generate_key(first_request)
second_key = generate_key(second_request)

response = Response(200, headers=[], content=b"test")
response.read()

storage.store(first_key, response=response, request=first_request, metadata=dummy_metadata)
assert storage.retrieve(first_key) is not None
sleep(3)
assert storage.retrieve(first_key) is not None
storage.store(second_key, response=response, request=second_request, metadata=dummy_metadata)
assert storage.retrieve(second_key) is not None
sleep(2)
assert storage.retrieve(first_key) is None
assert storage.retrieve(second_key) is not None
sleep(3)
assert storage.retrieve(second_key) is None


def test_s3storage_key_error(use_temp_dir, s3):
"""Triggers `NoSuchKey` error."""

boto3.client("s3").create_bucket(Bucket="testBucket")
storage = S3Storage(bucket_name="testBucket", ttl=60)
first_request = Request(b"GET", "https://example.com")
first_key = generate_key(first_request)

assert storage.retrieve(first_key) is None
