Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Allow Redis fixture to use decode_responses parameter #215

Merged
merged 9 commits into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
Expand All @@ -18,15 +17,15 @@
extensions = [
"m2r2",
"sphinx.ext.autodoc",
'sphinx_autodoc_typehints',
"sphinx_autodoc_typehints",
"sphinx.ext.autosectionlabel",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
]

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
source_suffix = ['.rst', '.md']
source_suffix = [".rst", ".md"]

html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]
Expand Down
4 changes: 3 additions & 1 deletion src/pytest_mock_resources/container/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def detect_driver(drivername: Optional[str] = None, async_: bool = False) -> str
if any(Distribution.discover(name="asyncpg")):
return "postgresql+asyncpg"
else:
if any(Distribution.discover(name="psycopg2")) or any(Distribution.discover(name="psycopg2-binary")):
if any(Distribution.discover(name="psycopg2")) or any(
Distribution.discover(name="psycopg2-binary")
):
return "postgresql+psycopg2"

raise ValueError( # pragma: no cover
Expand Down
7 changes: 6 additions & 1 deletion src/pytest_mock_resources/container/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ class RedisConfig(DockerContainerConfig):

name = "redis"

_fields: ClassVar[Iterable] = {"image", "host", "port", "ci_port"}
_fields: ClassVar[Iterable] = {
"image",
"host",
"port",
"ci_port",
}
_fields_defaults: ClassVar[dict] = {
"image": "redis:5.0.7",
"port": 6380,
Expand Down
10 changes: 8 additions & 2 deletions src/pytest_mock_resources/fixture/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def pmr_redis_container(pytestconfig, pmr_redis_config):
yield from get_container(pytestconfig, pmr_redis_config)


def create_redis_fixture(scope="function"):
def create_redis_fixture(scope="function", decode_responses: bool = False):
"""Produce a Redis fixture.

Any number of fixture functions can be created. Under the hood they will all share the same
Expand All @@ -44,6 +44,7 @@ def create_redis_fixture(scope="function"):

Args:
scope (str): The scope of the fixture can be specified by the user, defaults to "function".
decode_responses (bool): Whether to decode the responses from redis.

Raises:
KeyError: If any additional arguments are provided to the function than what is necessary.
Expand All @@ -62,7 +63,12 @@ def _(request, pmr_redis_container, pmr_redis_config):
"The redis fixture currently only supports up to 16 parallel executions"
)

db = redis.Redis(host=pmr_redis_config.host, port=pmr_redis_config.port, db=database_number)
db = redis.Redis(
host=pmr_redis_config.host,
port=pmr_redis_config.port,
db=database_number,
decode_responses=decode_responses,
)
db.flushdb()

Credentials.assign_from_credentials(
Expand Down
191 changes: 144 additions & 47 deletions tests/fixture/test_redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pytest

from pytest_mock_resources import create_redis_fixture
from pytest_mock_resources.compat import redis

redis_client = create_redis_fixture()
redis_client_decode = create_redis_fixture(decode_responses=True)

client_parameters = [("redis_client", False), ("redis_client_decode", True)]


def _sets_setup(redis_client):
Expand Down Expand Up @@ -42,83 +47,151 @@ def test_custom_connection_url(self, redis_client):
assert value == "bar"


@pytest.mark.parametrize("redis,is_decoded", client_parameters)
class TestStrings:
def test_set(self, redis_client):
def test_set(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

redis_client.set("foo", "bar")
value = redis_client.get("foo").decode("utf-8")
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert value == "bar"

def test_append(self, redis_client):
def test_append(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

redis_client.set("foo", "bar")
redis_client.append("foo", "baz")
value = redis_client.get("foo").decode("utf-8")
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert value == "barbaz"

def test_int_operations(self, redis_client):
def test_int_operations(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

redis_client.set("foo", 1)
redis_client.incr("foo")
value = int(redis_client.get("foo").decode("utf-8"))
assert value == 2
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert int(value) == 2

redis_client.decr("foo")
value = int(redis_client.get("foo").decode("utf-8"))
assert value == 1
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert int(value) == 1

redis_client.incrby("foo", 4)
value = int(redis_client.get("foo").decode("utf-8"))
assert value == 5
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert int(value) == 5

redis_client.decrby("foo", 3)
value = int(redis_client.get("foo").decode("utf-8"))
assert value == 2
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")

assert int(value) == 2

def test_float_operations(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

def test_float_operations(self, redis_client):
redis_client.set("foo", 1.2)
value = float(redis_client.get("foo").decode("utf-8"))
assert value == 1.2
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert float(value) == 1.2

redis_client.incrbyfloat("foo", 4.1)
value = float(redis_client.get("foo").decode("utf-8"))
assert value == 5.3
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert float(value) == 5.3

redis_client.incrbyfloat("foo", -3.1)
value = float(redis_client.get("foo").decode("utf-8"))
assert value == 2.2
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert float(value) == 2.2

def test_multiple_keys(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

def test_multiple_keys(self, redis_client):
test_mapping = {"foo": "bar", "baz": 1, "flo": 1.2}
redis_client.mset(test_mapping)
assert redis_client.get("foo").decode("utf-8") == "bar"
assert int(redis_client.get("baz").decode("utf-8")) == 1
assert float(redis_client.get("flo").decode("utf-8")) == 1.2

def test_querries(self, redis_client):
value = redis_client.get("foo")
if not is_decoded:
value = value.decode("utf-8")
assert value == "bar"

value = redis_client.get("baz")
if not is_decoded:
value = value.decode("utf-8")
assert int(value) == 1

value = redis_client.get("flo")
if not is_decoded:
value = value.decode("utf-8")
assert float(value) == 1.2

def test_querries(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

test_mapping = {"foo1": "bar1", "foo2": "bar2", "flo": "flo"}
redis_client.mset(test_mapping)
foo_keys = redis_client.keys("foo*")
assert b"foo1" in foo_keys
assert b"foo2" in foo_keys
assert b"flo" not in foo_keys
if is_decoded:
assert "foo1" in foo_keys
assert "foo2" in foo_keys
assert "flo" not in foo_keys
else:
assert b"foo1" in foo_keys
assert b"foo2" in foo_keys
assert b"flo" not in foo_keys


@pytest.mark.parametrize("redis,is_decoded", client_parameters)
class TestSets:
def test_sadd(self, redis_client):
def test_sadd(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

_sets_setup(redis_client)
friends_leto = redis_client.smembers("friends:leto")
friends_paul = redis_client.smembers("friends:paul")
assert friends_leto == {b"duncan", b"ghanima"}
assert friends_paul == {b"gurney", b"duncan"}

def test_set_operations(self, redis_client):
if is_decoded:
assert friends_leto == {"duncan", "ghanima"}
assert friends_paul == {"gurney", "duncan"}
else:
assert friends_leto == {b"duncan", b"ghanima"}
assert friends_paul == {b"gurney", b"duncan"}

def test_set_operations(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

_sets_setup(redis_client)
inter = redis_client.sinter("friends:leto", "friends:paul")
assert inter == {b"duncan"}
if is_decoded:
assert inter == {"duncan"}
else:
assert inter == {b"duncan"}

union = redis_client.sunion("friends:leto", "friends:paul")
assert union == {b"ghanima", b"duncan", b"gurney"}
if is_decoded:
assert union == {"ghanima", "duncan", "gurney"}
else:
assert union == {b"ghanima", b"duncan", b"gurney"}

diff = redis_client.sdiff("friends:leto", "friends:paul")
assert diff == {b"ghanima"}
if is_decoded:
assert diff == {"ghanima"}
else:
assert diff == {b"ghanima"}

cardinality_leto = redis_client.scard("friends:leto")
assert cardinality_leto == 2
Expand All @@ -130,42 +203,66 @@ def test_set_operations(self, redis_client):
assert redis_client.sismember("friends:paul", "ghanima")


@pytest.mark.parametrize("redis,is_decoded", client_parameters)
class TestHashes:
def test_hset(self, redis_client):
def test_hset(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

_hash_setup(redis_client)
user = redis_client.hgetall("user")
assert user == {b"name": b"foo", b"age": b"30"}
if is_decoded:
assert user == {"name": "foo", "age": "30"}
else:
assert user == {b"name": b"foo", b"age": b"30"}

def test_hash_operations(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

def test_hash_operations(self, redis_client):
_hash_setup(redis_client)
assert redis_client.hexists("user", "name")
assert redis_client.hexists("user", "age")

keys = redis_client.hkeys("user")
assert keys == [b"name", b"age"]
if is_decoded:
assert keys == ["name", "age"]
else:
assert keys == [b"name", b"age"]

len = redis_client.hlen("user")
assert len == 2

vals = redis_client.hvals("user")
assert vals == [b"foo", b"30"]
if is_decoded:
assert vals == ["foo", "30"]
else:
assert vals == [b"foo", b"30"]


@pytest.mark.parametrize("redis,is_decoded", client_parameters)
class TestLists:
def test_lset(self, redis_client):
def test_lset(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

_list_setup(redis_client)
items = redis_client.lrange("dbs", 0, 4)
assert items == [b"mysql_lite", b"mysql", b"postgres", b"redis", b"mongo"]
if is_decoded:
assert items == ["mysql_lite", "mysql", "postgres", "redis", "mongo"]
else:
assert items == [b"mysql_lite", b"mysql", b"postgres", b"redis", b"mongo"]

def test_list_operations(self, redis, is_decoded, request):
redis_client = request.getfixturevalue(redis)

def test_list_operations(self, redis_client):
_list_setup(redis_client)
assert redis_client.llen("dbs") == 5
assert redis_client.lindex("dbs", 1) == b"mysql"
assert redis_client.lpop("dbs") == b"mysql_lite"
assert redis_client.lindex("dbs", 1) == ("mysql" if is_decoded else b"mysql")
assert redis_client.lpop("dbs") == ("mysql_lite" if is_decoded else b"mysql_lite")

redis_client.rpush("dbs", "RabbitMQ")
assert redis_client.rpop("dbs") == b"RabbitMQ"
assert redis_client.rpop("dbs") == ("RabbitMQ" if is_decoded else b"RabbitMQ")

redis_client.ltrim("dbs", 1, -1)
rest = redis_client.lrange("dbs", 0, -1)
assert rest == [b"postgres", b"redis", b"mongo"]
assert rest == (
["postgres", "redis", "mongo"] if is_decoded else [b"postgres", b"redis", b"mongo"]
)
Loading