Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add signing of cache values. #60

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
Version 0.4.0
-------------

Unreleased

- Add ``secret_key`` argument to ``FileSystemCache``, ``RedisCache``, and
``UWSGICache``. The serialized data is signed with this key.
Without this key, anyone with write access to the cache location (Redis, file
system, or UWSGI cache) can trick cachelib into remote code execution.
:issue:`1`


Version 0.3.0
-------------

Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ packages = find:
package_dir = = src
include_package_data = true
python_requires = >= 3.6
install_requires =
itsdangerous ~= 2.0

[options.packages.find]
where = src
Expand Down
66 changes: 65 additions & 1 deletion src/cachelib/base.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,84 @@
import base64
import io
import pickle
import typing as _t

import itsdangerous


class _Base85Pickler:
"""
Pickles Base85 safe. To allow multiple reads from the same file, encoded output must
not contain any newlines. Base85 has less overhead than Base64.
"""

@staticmethod
def dumps(obj: _t.Any) -> bytes:
return base64.b85encode(pickle.dumps(obj))

@staticmethod
def loads(data: _t.AnyStr) -> bytes:
return pickle.loads(base64.b85decode(data))


class BaseCache:
"""Baseclass for the cache systems. All the cache systems implement this
API or a superset of it.

:param default_timeout: the default timeout (in seconds) that is used if
no timeout is specified on :meth:`set`. A timeout
of 0 indicates that the cache never expires.

:param secret_key: Key to sign cache entries with.

.. versionadded:: 0.4.0

"""

def __init__(self, default_timeout=300):
def __init__(
self,
default_timeout=300,
*,
secret_key: _t.Optional[_t.Union[_t.AnyStr, _t.Collection[_t.AnyStr]]] = None
):
self.default_timeout = default_timeout
if secret_key is not None:
self.__signed_serializer = itsdangerous.Serializer(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why double-underscore names? That makes subclassing less convenient.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subclasses are only supposed to access _dump/load(s), but sure, I can remove a single underscore.

secret_key, serializer=_Base85Pickler
)
else:
self.__signed_serializer = None

def _normalize_timeout(self, timeout):
    """Return *timeout*, substituting the instance default for ``None``."""
    return self.default_timeout if timeout is None else timeout

def _dumps(self, to_serialize: _t.Any) -> bytes:
    """Serialize *to_serialize* and return the raw bytes.

    Delegates to :meth:`_dump` with an in-memory buffer so that signing
    (when a secret key is configured) behaves exactly as it does for
    file-based serialization.
    """
    with io.BytesIO() as sink:
        self._dump(to_serialize, sink)
        return sink.getvalue()

def _dump(self, to_serialize: _t.Any, file: _t.IO[bytes]) -> None:
    """Write *to_serialize* to *file*, signed when a secret key is set.

    In signed mode a trailing newline terminates the record, so several
    records can later be read back from one file line by line; without a
    secret key a plain pickle stream is written instead.
    """
    if not self.__signed_serializer:
        pickle.dump(to_serialize, file)
        return
    self.__signed_serializer.dump(to_serialize, file)
    file.write(b"\n")

def _loads(self, serialized: _t.Any) -> _t.Any:
    """Deserialize *serialized* bytes produced by :meth:`_dumps`.

    Wraps the data in an in-memory buffer and delegates to
    :meth:`_load` so that signature verification (when enabled) is
    applied exactly as for file-based reads.
    """
    buf = io.BytesIO(serialized)
    # Bug fix: this previously called ``self._unpack``, which is not
    # defined anywhere on this class; the file-based reader paired with
    # ``_dump``/``_dumps`` is ``_load``.
    return self._load(buf)

def _load(self, file: _t.IO[bytes]) -> _t.Any:
    # Deserialize one record from *file* (counterpart of ``_dump``).
    #
    # Signed mode: each record is a single signed line written by
    # ``_dump``; read it, strip the trailing newline, and verify the
    # signature before unpickling.  Tampered or malformed data yields
    # ``None`` instead of raising, so it is treated as a cache miss.
    # NOTE(review): a legitimately cached ``None`` value is therefore
    # indistinguishable from a failed verification — confirm callers
    # accept this.
    if self.__signed_serializer:
        try:
            # ``[:-1]`` drops the newline appended by ``_dump``.
            read = file.readline()[:-1]
            return self.__signed_serializer.loads(read)
        except (itsdangerous.BadSignature, pickle.UnpicklingError):
            return None
    # Unsigned mode: plain pickle stream (unsafe if an attacker has
    # write access to the cache backend — see the secret_key docs).
    return pickle.load(file)

def get(self, key):
"""Look up key in the cache and return the value for it.

Expand Down
40 changes: 30 additions & 10 deletions src/cachelib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import pickle
import tempfile
import typing as _t
from hashlib import md5
from pathlib import Path
from time import time
Expand All @@ -24,15 +25,32 @@ class FileSystemCache(BaseCache):
specified on :meth:`~BaseCache.set`. A timeout of
0 indicates that the cache never expires.
:param mode: the file mode wanted for the cache files, default 0600

:param secret_key: Key to sign cache entries with.

.. warning::
Without a secret key or in case the secret key is not secret anymore,
anyone with write access to the cache directory can trick your program
into executing arbitrary code.

.. versionadded:: 0.4.0
"""

#: used for temporary files by the FileSystemCache
_fs_transaction_suffix = ".__wz_cache"
#: keep amount of files in a cache element
_fs_count_file = "__wz_cache_count"

def __init__(self, cache_dir, threshold=500, default_timeout=300, mode=0o600):
BaseCache.__init__(self, default_timeout)
def __init__(
self,
cache_dir,
threshold=500,
default_timeout=300,
mode=0o600,
*,
secret_key: _t.Optional[_t.Union[_t.AnyStr, _t.Collection[_t.AnyStr]]] = None
):
BaseCache.__init__(self, default_timeout, secret_key=secret_key)
self._path = cache_dir
self._threshold = threshold
self._mode = mode
Expand Down Expand Up @@ -87,7 +105,7 @@ def _remove_expired(self, now):
for fname in self._list_dir():
try:
with open(fname, "rb") as f:
expires = pickle.load(f)
expires = self._load(f)
if expires != 0 and expires < now:
os.remove(fname)
self._update_count(delta=-1)
Expand All @@ -103,7 +121,7 @@ def _remove_older(self):
for fname in self._list_dir():
try:
with open(fname, "rb") as f:
exp_fname_tuples.append((pickle.load(f), fname))
exp_fname_tuples.append((self._load(f), fname))
except (OSError, EOFError):
logging.warning(
"Exception raised while handling cache file '%s'",
Expand Down Expand Up @@ -160,9 +178,11 @@ def get(self, key):
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
pickle_time = pickle.load(f)
if pickle_time == 0 or pickle_time >= time():
return pickle.load(f)
pickle_time = self._load(f)
if pickle_time is not None and (
pickle_time == 0 or pickle_time >= time()
):
return self._load(f)
else:
return None
except (OSError, EOFError, pickle.PickleError):
Expand Down Expand Up @@ -196,8 +216,8 @@ def set(self, key, value, timeout=None, mgmt_element=False):
suffix=self._fs_transaction_suffix, dir=self._path
)
with os.fdopen(fd, "wb") as f:
pickle.dump(timeout, f, 1)
pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
self._dump(timeout, f)
self._dump(value, f)
os.replace(tmp, filename)
os.chmod(filename, self._mode)
fsize = Path(filename).stat().st_size
Expand Down Expand Up @@ -230,7 +250,7 @@ def has(self, key):
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
pickle_time = pickle.load(f)
pickle_time = self._load(f)
if pickle_time == 0 or pickle_time >= time():
return True
else:
Expand Down
20 changes: 18 additions & 2 deletions src/cachelib/redis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
import typing as _t

from cachelib.base import BaseCache

Expand All @@ -21,6 +22,14 @@ class RedisCache(BaseCache):
specified on :meth:`~BaseCache.set`. A timeout of
0 indicates that the cache never expires.
:param key_prefix: A prefix that should be added to all keys.
:param secret_key: Key to sign cache entries with.

.. warning::
Without a secret key or in case the secret key is not secret anymore,
anyone with write access to the redis instance can trick your program
into executing arbitrary code.

.. versionadded:: 0.4.0

Any additional keyword arguments will be passed to ``redis.Redis``.
"""
Expand All @@ -33,9 +42,11 @@ def __init__(
db=0,
default_timeout=300,
key_prefix=None,
**kwargs
*,
secret_key: _t.Optional[_t.Union[_t.AnyStr, _t.Collection[_t.AnyStr]]] = None,
**kwargs,
):
BaseCache.__init__(self, default_timeout)
BaseCache.__init__(self, default_timeout, secret_key=secret_key)
if host is None:
raise ValueError("RedisCache host parameter may not be None")
if isinstance(host, str):
Expand All @@ -51,6 +62,7 @@ def __init__(
else:
self._client = host
self.key_prefix = key_prefix or ""
self._has_secret_key = secret_key is not None

def _normalize_timeout(self, timeout):
timeout = BaseCache._normalize_timeout(self, timeout)
Expand All @@ -62,6 +74,8 @@ def dump_object(self, value):
"""Dumps an object into a string for redis. By default it serializes
integers as regular string and pickle dumps everything else.
"""
if self._has_secret_key:
return self._dumps(value)
t = type(value)
if isinstance(t, int):
return str(value).encode("ascii")
Expand All @@ -71,6 +85,8 @@ def load_object(self, value):
"""The reversal of :meth:`dump_object`. This might be called with
None.
"""
if self._has_secret_key:
return self._loads(value)
if value is None:
return None
if value.startswith(b"!"):
Expand Down
26 changes: 20 additions & 6 deletions src/cachelib/uwsgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pickle
import platform
import typing as _t

from cachelib.base import BaseCache

Expand All @@ -17,10 +17,24 @@ class UWSGICache(BaseCache):
means uWSGI will cache in the local instance. If the cache is in the
same instance as the werkzeug app, you only have to provide the name of
the cache.
:param secret_key: Key to sign cache entries with.

.. warning::
Without a secret key or in case the secret key is not secret anymore,
anyone with write access to the uWSGI cache can trick your program
into executing arbitrary code.

.. versionadded:: 0.4.0
"""

def __init__(self, default_timeout=300, cache=""):
BaseCache.__init__(self, default_timeout)
def __init__(
self,
default_timeout=300,
cache="",
*,
secret_key: _t.Optional[_t.Union[_t.AnyStr, _t.Collection[_t.AnyStr]]] = None,
):
BaseCache.__init__(self, default_timeout, secret_key=secret_key)

if platform.python_implementation() == "PyPy":
raise RuntimeError(
Expand All @@ -43,19 +57,19 @@ def get(self, key):
rv = self._uwsgi.cache_get(key, self.cache)
if rv is None:
return
return pickle.loads(rv)
return self._loads(rv)

def delete(self, key):
    """Remove *key* from the uWSGI cache and return the backend result."""
    removed = self._uwsgi.cache_del(key, self.cache)
    return removed

def set(self, key, value, timeout=None):
    """Serialize *value* and store it under *key*.

    Uses ``cache_update``, so an existing entry for *key* is replaced.
    ``None`` *timeout* falls back to the instance default.
    """
    payload = self._dumps(value)
    expiry = self._normalize_timeout(timeout)
    return self._uwsgi.cache_update(key, payload, expiry, self.cache)

def add(self, key, value, timeout=None):
    """Serialize *value* and store it under *key* via ``cache_set``.

    NOTE(review): unlike :meth:`set`, ``cache_set`` presumably does not
    overwrite an existing key (hence "add") — confirm against the uWSGI
    cache API docs.
    """
    payload = self._dumps(value)
    expiry = self._normalize_timeout(timeout)
    return self._uwsgi.cache_set(key, payload, expiry, self.cache)

def clear(self):
Expand Down
36 changes: 36 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from conftest import TestData
from conftest import under_uwsgi

from cachelib import MemcachedCache
from cachelib import SimpleCache


class CommonTests(TestData):
"""A base set of tests to be run for all cache types"""
Expand Down Expand Up @@ -75,3 +78,36 @@ def test_expiration(self):
for k, v in self.sample_pairs.items():
assert cache.get(f"{k}-t0") == v
assert not cache.get(f"{k}-t1")

def test_signed_set_get(self):
if isinstance(self.cache_factory(), (MemcachedCache, SimpleCache)):
pytest.skip("Simple and MemcachedCache do not support signing.")
cache = self.cache_factory(secret_key="not very secret")

for k, v in self.sample_pairs.items():
assert cache.set(k, v)
assert cache.get(k) == v

def test_signed_does_not_get_unsigned(self):
unsigned_cache = self.cache_factory()
if isinstance(unsigned_cache, (MemcachedCache, SimpleCache)):
pytest.skip("Simple and MemcachedCache do not support signing.")

signed_cache = self.cache_factory(secret_key="not very secret")

for k, v in self.sample_pairs.items():
assert unsigned_cache.set(k, v)
assert signed_cache.get(k) is None

def test_signed_does_not_get_wrong_key(self):
if isinstance(self.cache_factory(), (MemcachedCache, SimpleCache)):
pytest.skip("Simple and MemcachedCache do not support signing.")

signed_cache1 = self.cache_factory(secret_key="not very secret")
signed_cache2 = self.cache_factory(secret_key="another not secret value")

for k, v in self.sample_pairs.items():
assert signed_cache1.set(k, v)
assert signed_cache2.get(k) is None
assert signed_cache2.set(k, v)
assert signed_cache1.get(k) is None