Skip to content

Commit

Permalink
Use performance fixed branch for aioredis
Browse files Browse the repository at this point in the history
  • Loading branch information
tim.reichard committed Mar 8, 2021
1 parent 3a8a09b commit 3732c15
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 52 deletions.
6 changes: 6 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ History
=======


v0.12.0 (2021-03-08)
-----------------------

* Use aioredis transactions performance fixed branch (sean/aioredis-redis-py-compliance) instead of version 1.3.1.


v0.11.7 (2021-03-01)
-----------------------

Expand Down
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ async def main():
config = {'redis_primary_endpoint': 'your-redis-endpoint'}
redis = Redis(config=config, use_json=True, expire=60, use_hashkey=False)

# since aioredis.create_redis_pool is a coroutine we need to instantiate the object within an async function
if redis.pool_task is not None:
redis.pool = await redis.pool_task

# we can override the global expire and since we are using json the cache_value will be converted to json
await redis.set_one_item(cache_key='aioradio', cache_value={'a': 'alpha', 'number': 123}, expire=2)

Expand Down
88 changes: 47 additions & 41 deletions aioradio/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import aioredis
import orjson
from fakeredis.aioredis import create_redis_pool as fake_redis_pool

HASH_ALGO_MAP = {
'SHA1': hashlib.sha1,
Expand All @@ -34,7 +33,6 @@ class Redis:

config: Dict[str, Any] = dataclass_field(default_factory=dict)
pool: aioredis.Redis = dataclass_field(init=False, repr=False)
pool_task: asyncio.coroutine = None

# Set the redis pool min and max connections size
pool_minsize: int = 5
Expand All @@ -51,24 +49,9 @@ class Redis:
# retrieve value letting this class convert from json set use_json = True.
use_json: bool = True

# used exclusively for pytest
fakeredis: bool = False

def __post_init__(self):
if self.fakeredis:
self.pool = asyncio.get_event_loop().run_until_complete(fake_redis_pool())
else:
primary_endpoint = f'redis://{self.config["redis_primary_endpoint"]}'
loop = asyncio.get_event_loop()
if loop and loop.is_running():
self.pool_task = loop.create_task(
aioredis.create_redis_pool(primary_endpoint, minsize=self.pool_minsize, maxsize=self.pool_maxsize))
else:
self.pool = loop.run_until_complete(
aioredis.create_redis_pool(primary_endpoint, minsize=self.pool_minsize, maxsize=self.pool_maxsize))

def __del__(self):
self.pool.close()
primary_endpoint = self.config["redis_primary_endpoint"]
self.pool = aioredis.Redis(host=primary_endpoint)

async def get(self, key: str, use_json: bool=None, encoding: Union[str, None]='utf-8') -> Any:
"""Check if an item is cached in redis.
Expand All @@ -85,10 +68,13 @@ async def get(self, key: str, use_json: bool=None, encoding: Union[str, None]='u
if use_json is None:
use_json = self.use_json

value = await self.pool.get(key, encoding=encoding)
value = await self.pool.get(key)

if value is not None and use_json:
value = orjson.loads(value)
if value is not None:
if encoding is not None:
value = value.decode(encoding)
if use_json:
value = orjson.loads(value)

return value

Expand All @@ -107,12 +93,18 @@ async def mget(self, items: List[str], use_json: bool=None, encoding: Union[str,
if use_json is None:
use_json = self.use_json

values = await self.pool.mget(*items, encoding=encoding)
values = await self.pool.mget(*items)

if use_json:
values = [orjson.loads(val) if val is not None else None for val in values]
results = []
for val in values:
if val is not None:
if encoding is not None:
val = val.decode(encoding)
if use_json:
val = orjson.loads(val)
results.append(val)

return values
return results

async def set(self, key: str, value: str, expire: int=None, use_json: bool=None) -> int:
"""Set one key-value pair in redis.
Expand All @@ -136,7 +128,7 @@ async def set(self, key: str, value: str, expire: int=None, use_json: bool=None)
if use_json:
value = orjson.dumps(value)

return await self.pool.set(key, value, expire=expire)
return await self.pool.set(key, value, ex=expire)

async def delete(self, key: str) -> int:
"""Delete key from redis.
Expand Down Expand Up @@ -166,10 +158,13 @@ async def hget(self, key: str, field: str, use_json: bool=None, encoding: Union[
if use_json is None:
use_json = self.use_json

value = await self.pool.hget(key, field, encoding=encoding)
value = await self.pool.hget(key, field)

if value is not None and use_json:
value = orjson.loads(value)
if value is not None:
if encoding is not None:
value = value.decode(encoding)
if use_json:
value = orjson.loads(value)

return value

Expand All @@ -190,8 +185,10 @@ async def hmget(self, key: str, fields: List[str], use_json: bool=None, encoding
use_json = self.use_json

items = {}
for index, value in enumerate(await self.pool.hmget(key, *fields, encoding=encoding)):
for index, value in enumerate(await self.pool.hmget(key, *fields)):
if value is not None:
if encoding is not None:
value = value.decode(encoding)
if use_json:
value = orjson.loads(value)
items[fields[index]] = value
Expand All @@ -216,13 +213,15 @@ async def hmget_many(self, keys: List[str], fields: List[str], use_json: bool=No

pipeline = self.pool.pipeline()
for key in keys:
pipeline.hmget(key, *fields, encoding=encoding)
pipeline.hmget(key, *fields)

results = []
for values in await pipeline.execute():
items = {}
for index, value in enumerate(values):
if value is not None:
if encoding is not None:
value = value.decode(encoding)
if use_json:
value = orjson.loads(value)
items[fields[index]] = value
Expand All @@ -246,8 +245,12 @@ async def hgetall(self, key: str, use_json: bool=None, encoding: Union[str, None
use_json = self.use_json

items = {}
for hash_key, value in (await self.pool.hgetall(key, encoding=encoding)).items():
for hash_key, value in (await self.pool.hgetall(key)).items():
if encoding is not None:
hash_key = hash_key.decode(encoding)
if value is not None:
if encoding is not None:
value = value.decode(encoding)
if use_json:
value = orjson.loads(value)
items[hash_key] = value
Expand All @@ -271,16 +274,20 @@ async def hgetall_many(self, keys: List[str], use_json: bool=None, encoding: Uni

pipeline = self.pool.pipeline()
for key in keys:
pipeline.hgetall(key, encoding=encoding)
pipeline.hgetall(key)

results = []
for item in await pipeline.execute():
items = {}
for key, value in item.items():
if encoding is not None:
key = key.decode(encoding)
if value is not None:
if encoding is not None:
value = value.decode(encoding)
if use_json:
value = orjson.loads(value)
items[key] = value
items[key] = value
results.append(items)

return results
Expand Down Expand Up @@ -308,9 +315,9 @@ async def hset(self, key: str, field: str, value: str, use_json: bool=None, expi
if use_json:
value = orjson.dumps(value)

pipeline = self.pool.multi_exec()
pipeline = self.pool.pipeline()
pipeline.hset(key, field, value)
pipeline.expire(key, timeout=expire)
pipeline.expire(key, time=expire)
result, _ = await pipeline.execute()

return result
Expand All @@ -320,7 +327,7 @@ async def hmset(self, key: str, items: Dict[str, Any], use_json: bool=None, expi
Args:
key (str): cache key
items (List[str, Any]): list of redis hash key-value pairs
items (Dict[str, Any]): list of redis hash key-value pairs
use_json (bool, optional): set object to json before writing to cache. Defaults to None.
expire (int, optional): cache expiration. Defaults to None.
Expand All @@ -339,10 +346,9 @@ async def hmset(self, key: str, items: Dict[str, Any], use_json: bool=None, expi
items = modified_items

pipeline = self.pool.pipeline()
pipeline.hmset_dict(key, items)
pipeline.expire(key, timeout=expire)
pipeline.hset(key, mapping=items)
pipeline.expire(key, time=expire)
result, _ = await pipeline.execute()

return result

async def hdel(self, key: str, fields: List[str]) -> int:
Expand Down
3 changes: 1 addition & 2 deletions aioradio/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
aioboto3==8.2.0
aiobotocore==1.1.2
aiojobs==0.3.0
aioredis==1.3.1
ddtrace==0.46.0
fakeredis==1.4.5
flask==1.1.2
git+https://github.com/aio-libs/aioredis@sean/aioredis-redis-py-compliance
httpx==0.17.0
mandrill==1.0.59
moto==1.3.16
Expand Down
11 changes: 9 additions & 2 deletions aioradio/tests/redis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def test_hash_redis_functions(cache):
}

result = await cache.hmset(key='complex_hash', items=items, expire=1)
assert result == 1
assert result == 4

result = await cache.hmget(key='complex_hash', fields=['name', 'team', 'apps', 'fake'])
assert 'fake' not in result
Expand All @@ -73,7 +73,7 @@ async def test_hash_redis_functions(cache):

items = {'state': 'TX', 'city': 'Austin', 'zipcode': '78745', 'addr1': '8103 Shiloh Ct.', 'addr2': ''}
result = await cache.hmset(key='address_hash', items=items, expire=1, use_json=False)
assert result == 1
assert result == 5
result = await cache.hgetall(key='address_hash', use_json=False)
assert result == items

Expand Down Expand Up @@ -102,6 +102,13 @@ async def test_set_one_item(payload, cache):
result = await cache.delete(key)
assert result == 1

await cache.set(key='set_simple_key', value='aioradio is superb', use_json=False)
result = await cache.get('set_simple_key', use_json=False)
assert result == 'aioradio is superb'
result = await cache.mget(items=['set_simple_key'], use_json=False)
assert result == ['aioradio is superb']


async def test_set_one_item_with_hashed_key(payload, cache):
"""Test set_one_item."""

Expand Down
4 changes: 3 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def cache(github_action):
if github_action:
pytest.skip('Skip test_set_one_item when running via Github Action')

cache_object = Redis(fakeredis=True)
cache_object = Redis(config={
'redis_primary_endpoint': 'prod-race2.gbngr1.ng.0001.use1.cache.amazonaws.com'
})
yield cache_object


Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
long_description = fileobj.read()

setup(name='aioradio',
version='0.11.7',
version='0.12.0',
description='Generic asynchronous i/o python utilities for AWS services (SQS, S3, DynamoDB, Secrets Manager), Redis, MSSQL (pyodbc), JIRA and more',
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -26,7 +26,6 @@
'aioredis',
'boto3',
'ddtrace',
'fakeredis',
'httpx',
'mandrill',
'orjson',
Expand Down

0 comments on commit 3732c15

Please sign in to comment.