diff --git a/HISTORY.rst b/HISTORY.rst index 23904ad..7117ee4 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,12 @@ History ======= +v0.12.0 (2021-03-08) +----------------------- + +* Use aioredis transactions performance fixed branch (sean/aioredis-redis-py-compliance) instead of version 1.3.1. + + v0.11.7 (2021-03-01) ----------------------- diff --git a/README.md b/README.md index 02f927b..83b5147 100644 --- a/README.md +++ b/README.md @@ -14,10 +14,6 @@ async def main(): config = {'redis_primary_endpoint': 'your-redis-endpoint'} redis = Redis(config=config, use_json=True, expire=60, use_hashkey=False) - # since aioredis.create_redis_pool is a coroutine we need to instantiate the object within an async function - if redis.pool_task is not None: - redis.pool = await redis.pool_task - # we can override the global expire and since we are using json the cache_value will be converted to json await redis.set_one_item(cache_key='aioradio', cache_value={'a': 'alpha', 'number': 123}, expire=2) diff --git a/aioradio/redis.py b/aioradio/redis.py index 4eea6c3..15ff234 100644 --- a/aioradio/redis.py +++ b/aioradio/redis.py @@ -13,7 +13,6 @@ import aioredis import orjson -from fakeredis.aioredis import create_redis_pool as fake_redis_pool HASH_ALGO_MAP = { 'SHA1': hashlib.sha1, @@ -34,7 +33,6 @@ class Redis: config: Dict[str, Any] = dataclass_field(default_factory=dict) pool: aioredis.Redis = dataclass_field(init=False, repr=False) - pool_task: asyncio.coroutine = None # Set the redis pool min and max connections size pool_minsize: int = 5 @@ -51,24 +49,9 @@ class Redis: # retrieve value letting this class convert from json set use_json = True. 
use_json: bool = True - # used exclusively for pytest - fakeredis: bool = False - def __post_init__(self): - if self.fakeredis: - self.pool = asyncio.get_event_loop().run_until_complete(fake_redis_pool()) - else: - primary_endpoint = f'redis://{self.config["redis_primary_endpoint"]}' - loop = asyncio.get_event_loop() - if loop and loop.is_running(): - self.pool_task = loop.create_task( - aioredis.create_redis_pool(primary_endpoint, minsize=self.pool_minsize, maxsize=self.pool_maxsize)) - else: - self.pool = loop.run_until_complete( - aioredis.create_redis_pool(primary_endpoint, minsize=self.pool_minsize, maxsize=self.pool_maxsize)) - - def __del__(self): - self.pool.close() + primary_endpoint = self.config["redis_primary_endpoint"] + self.pool = aioredis.Redis(host=primary_endpoint) async def get(self, key: str, use_json: bool=None, encoding: Union[str, None]='utf-8') -> Any: """Check if an item is cached in redis. @@ -85,10 +68,13 @@ async def get(self, key: str, use_json: bool=None, encoding: Union[str, None]='u if use_json is None: use_json = self.use_json - value = await self.pool.get(key, encoding=encoding) + value = await self.pool.get(key) - if value is not None and use_json: - value = orjson.loads(value) + if value is not None: + if encoding is not None: + value = value.decode(encoding) + if use_json: + value = orjson.loads(value) return value @@ -107,12 +93,18 @@ async def mget(self, items: List[str], use_json: bool=None, encoding: Union[str, if use_json is None: use_json = self.use_json - values = await self.pool.mget(*items, encoding=encoding) + values = await self.pool.mget(*items) - if use_json: - values = [orjson.loads(val) if val is not None else None for val in values] + results = [] + for val in values: + if val is not None: + if encoding is not None: + val = val.decode(encoding) + if use_json: + val = orjson.loads(val) + results.append(val) - return values + return results async def set(self, key: str, value: str, expire: int=None, use_json: 
bool=None) -> int: """Set one key-value pair in redis. @@ -136,7 +128,7 @@ async def set(self, key: str, value: str, expire: int=None, use_json: bool=None) if use_json: value = orjson.dumps(value) - return await self.pool.set(key, value, expire=expire) + return await self.pool.set(key, value, ex=expire) async def delete(self, key: str) -> int: """Delete key from redis. @@ -166,10 +158,13 @@ async def hget(self, key: str, field: str, use_json: bool=None, encoding: Union[ if use_json is None: use_json = self.use_json - value = await self.pool.hget(key, field, encoding=encoding) + value = await self.pool.hget(key, field) - if value is not None and use_json: - value = orjson.loads(value) + if value is not None: + if encoding is not None: + value = value.decode(encoding) + if use_json: + value = orjson.loads(value) return value @@ -190,8 +185,10 @@ async def hmget(self, key: str, fields: List[str], use_json: bool=None, encoding use_json = self.use_json items = {} - for index, value in enumerate(await self.pool.hmget(key, *fields, encoding=encoding)): + for index, value in enumerate(await self.pool.hmget(key, *fields)): if value is not None: + if encoding is not None: + value = value.decode(encoding) if use_json: value = orjson.loads(value) items[fields[index]] = value @@ -216,13 +213,15 @@ async def hmget_many(self, keys: List[str], fields: List[str], use_json: bool=No pipeline = self.pool.pipeline() for key in keys: - pipeline.hmget(key, *fields, encoding=encoding) + pipeline.hmget(key, *fields) results = [] for values in await pipeline.execute(): items = {} for index, value in enumerate(values): if value is not None: + if encoding is not None: + value = value.decode(encoding) if use_json: value = orjson.loads(value) items[fields[index]] = value @@ -246,8 +245,12 @@ async def hgetall(self, key: str, use_json: bool=None, encoding: Union[str, None use_json = self.use_json items = {} - for hash_key, value in (await self.pool.hgetall(key, encoding=encoding)).items(): + for 
hash_key, value in (await self.pool.hgetall(key)).items(): + if encoding is not None: + hash_key = hash_key.decode(encoding) if value is not None: + if encoding is not None: + value = value.decode(encoding) if use_json: value = orjson.loads(value) items[hash_key] = value @@ -271,16 +274,20 @@ async def hgetall_many(self, keys: List[str], use_json: bool=None, encoding: Uni pipeline = self.pool.pipeline() for key in keys: - pipeline.hgetall(key, encoding=encoding) + pipeline.hgetall(key) results = [] for item in await pipeline.execute(): items = {} for key, value in item.items(): + if encoding is not None: + key = key.decode(encoding) if value is not None: + if encoding is not None: + value = value.decode(encoding) if use_json: value = orjson.loads(value) - items[key] = value + items[key] = value results.append(items) return results @@ -308,9 +315,9 @@ async def hset(self, key: str, field: str, value: str, use_json: bool=None, expi if use_json: value = orjson.dumps(value) - pipeline = self.pool.multi_exec() + pipeline = self.pool.pipeline() pipeline.hset(key, field, value) - pipeline.expire(key, timeout=expire) + pipeline.expire(key, time=expire) result, _ = await pipeline.execute() return result @@ -320,7 +327,7 @@ async def hmset(self, key: str, items: Dict[str, Any], use_json: bool=None, expi Args: key (str): cache key - items (List[str, Any]): list of redis hash key-value pairs + items (Dict[str, Any]): list of redis hash key-value pairs use_json (bool, optional): set object to json before writing to cache. Defaults to None. expire (int, optional): cache expiration. Defaults to None. 
@@ -339,10 +346,10 @@ async def hmset(self, key: str, items: Dict[str, Any], use_json: bool=None, expi items = modified_items pipeline = self.pool.pipeline() - pipeline.hmset_dict(key, items) - pipeline.expire(key, timeout=expire) + pipeline.hset(key, mapping=items) + pipeline.expire(key, time=expire) result, _ = await pipeline.execute() return result async def hdel(self, key: str, fields: List[str]) -> int: diff --git a/aioradio/requirements.txt b/aioradio/requirements.txt index c4abaec..4a5c71d 100644 --- a/aioradio/requirements.txt +++ b/aioradio/requirements.txt @@ -1,10 +1,9 @@ aioboto3==8.2.0 aiobotocore==1.1.2 aiojobs==0.3.0 -aioredis==1.3.1 ddtrace==0.46.0 -fakeredis==1.4.5 flask==1.1.2 +git+https://github.com/aio-libs/aioredis@sean/aioredis-redis-py-compliance httpx==0.17.0 mandrill==1.0.59 moto==1.3.16 diff --git a/aioradio/tests/redis_test.py b/aioradio/tests/redis_test.py index e304090..c42b51a 100644 --- a/aioradio/tests/redis_test.py +++ b/aioradio/tests/redis_test.py @@ -53,7 +53,7 @@ async def test_hash_redis_functions(cache): } result = await cache.hmset(key='complex_hash', items=items, expire=1) - assert result == 1 + assert result == 4 result = await cache.hmget(key='complex_hash', fields=['name', 'team', 'apps', 'fake']) assert 'fake' not in result @@ -73,7 +73,7 @@ async def test_hash_redis_functions(cache): items = {'state': 'TX', 'city': 'Austin', 'zipcode': '78745', 'addr1': '8103 Shiloh Ct.', 'addr2': ''} result = await cache.hmset(key='address_hash', items=items, expire=1, use_json=False) - assert result == 1 + assert result == 5 result = await cache.hgetall(key='address_hash', use_json=False) assert result == items @@ -102,6 +102,13 @@ async def test_set_one_item(payload, cache): result = await cache.delete(key) assert result == 1 + await cache.set(key='set_simple_key', value='aioradio is superb', use_json=False) + result = await cache.get('set_simple_key', use_json=False) + assert result == 'aioradio is superb' + result = await 
cache.mget(items=['set_simple_key'], use_json=False) + assert result == ['aioradio is superb'] + + async def test_set_one_item_with_hashed_key(payload, cache): """Test set_one_item.""" diff --git a/conftest.py b/conftest.py index cf687af..ef14330 100644 --- a/conftest.py +++ b/conftest.py @@ -56,7 +56,9 @@ def cache(github_action): if github_action: pytest.skip('Skip test_set_one_item when running via Github Action') - cache_object = Redis(fakeredis=True) + cache_object = Redis(config={ + 'redis_primary_endpoint': 'prod-race2.gbngr1.ng.0001.use1.cache.amazonaws.com' + }) yield cache_object diff --git a/setup.py b/setup.py index 0aa03d8..c299768 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ long_description = fileobj.read() setup(name='aioradio', - version='0.11.7', + version='0.12.0', description='Generic asynchronous i/o python utilities for AWS services (SQS, S3, DynamoDB, Secrets Manager), Redis, MSSQL (pyodbc), JIRA and more', long_description=long_description, long_description_content_type="text/markdown", @@ -26,7 +26,6 @@ 'aioredis', 'boto3', 'ddtrace', - 'fakeredis', 'httpx', 'mandrill', 'orjson',