
Adding redis and sqs as queue to LongRunningJobs
tim.reichard committed Apr 13, 2021
1 parent 6ec49fb commit 959ca56
Showing 7 changed files with 160 additions and 54 deletions.
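As context for the diffs below, here is a minimal usage sketch of the reworked API introduced by this commit. The class, methods, and parameters are taken from aioradio/long_running_jobs.py as changed here; the Redis endpoint, SQS queue name, and the job function itself are illustrative placeholders rather than values from the repository.

```python
import asyncio

from aioradio.long_running_jobs import LongRunningJobs


async def uppercase_job(params):
    """Toy long-running job; any async function accepting the params dict works."""
    await asyncio.sleep(1)
    return params['text'].upper()


async def main():
    # One LongRunningJobs instance per job; queue_service selects the backing queue.
    lrj = LongRunningJobs(
        name='uppercase',
        redis_host='localhost',       # placeholder; Redis caches results for both queue types
        cache_expiration=3600,
        queue_service='sqs',          # or 'redis'
        sqs_queue='my-queue.fifo',    # placeholder; only needed when queue_service is 'sqs'
        sqs_region='us-east-1',
        job_timeout=30
    )

    asyncio.create_task(lrj.start_worker(job=uppercase_job, job_timeout=30))

    sent = await lrj.send_message(params={'text': 'hello'})
    await asyncio.sleep(2)
    status = await lrj.check_job_status(sent['uuid'])
    print(status)  # expected to include 'job_done' and 'results' once the job finishes

    await lrj.stop_worker()


asyncio.run(main())
```

The updated test and conftest fixtures further down exercise both backends the same way: lrj1 relies on the default queue_service of 'sqs', while lrj2 is configured with queue_service='redis'.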
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,5 +1,5 @@
default_language_version:
python: python3.8
python: python3.9
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
@@ -15,7 +15,7 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/PyCQA/isort
rev: 5.7.0
rev: 5.8.0
hooks:
- id: isort
- repo: https://github.com/myint/docformatter
6 changes: 6 additions & 0 deletions HISTORY.rst
@@ -3,6 +3,12 @@ History
=======


v0.13.1 (2021-04-13)
-----------------------

* Updating LongRunningJobs to use either 'sqs' or 'redis' as the queue mechanism.


v0.13.0 (2021-04-12)
-----------------------

147 changes: 113 additions & 34 deletions aioradio/long_running_jobs.py
@@ -3,33 +3,56 @@
# pylint: disable=broad-except
# pylint: disable=c-extension-no-member
# pylint: disable=logging-fstring-interpolation
# pylint: disable=too-many-instance-attributes

import asyncio
import traceback
from dataclasses import dataclass, field
from dataclasses import dataclass
from time import time
from typing import Any, Dict
from uuid import uuid4

import orjson

from aioradio.aws import sqs
from aioradio.redis import Redis


@dataclass
class LongRunningJobs:
"""Worker that continually pulls from Redis list (implemented like queue),
running a job using request parameters conveyed in the message.
"""Worker that continually pulls from queue (either SQS or Redis list
implemented like queue), running a job using request parameters conveyed in
the message.
Also has a pre-processing function to send messages to the Redis
list.
Also has functions to send messages to the queue, and to check if
job is complete.
"""

name: str # Name of the long running job used to identify between multiple jobs running within app.
redis_host: str
cache_expiration: int = 3600
worker_active: Dict[str, Any] = field(default_factory=dict)
worker_active: bool = False

# choose between sqs or redis
queue_service: str = 'sqs'
# if using sqs then define the queue name and aws region
sqs_queue: str = None
sqs_region: str = None

# job_timeout should be a factor of 2x or 3x above the max time a job takes to finish. It corresponds
# to the visibility_timeout when queue_service = 'sqs', and to message re-entry into the
# <self.name>-not-started queue when queue_service = 'redis'. Setting job_timeout is optional when
# instantiating the class as it can also be defined when issuing the start_worker method.
job_timeout: float = 30

def __post_init__(self):

self.queue_service = self.queue_service.lower()
if self.queue_service not in ['sqs', 'redis']:
raise ValueError("queue_service must be either 'sqs' or 'redis'.")

self.name_to_job = {self.name: None}

self.cache = Redis(
config={
'redis_primary_endpoint': self.redis_host,
@@ -59,11 +82,10 @@ async def check_job_status(self, uuid: str) -> Dict[str, Any]:

return result

async def send_message(self, job_name: str, params: Dict[str, Any], cache_key: str=None) -> Dict[str, str]:
"""Send message to Redis list.
async def send_message(self, params: Dict[str, Any], cache_key: str=None) -> Dict[str, str]:
"""Send message to queue.
Args:
job_name (str): Name of the long running job. Used to identify between multiple jobs running within app.
params (Dict[str, Any]): Request parameters needed for job
cache_key (str, optional): Results cache key. Defaults to None.
@@ -82,62 +104,118 @@ async def send_message(self, job_name: str, params: Dict[str, Any], cache_key: s

result = {}
try:
self.cache.pool.rpush(f'{job_name}-not-started', orjson.dumps(items).decode())
msg = orjson.dumps(items).decode()
if self.queue_service == 'sqs':
entries = [{'Id': str(uuid4()), 'MessageBody': msg, 'MessageGroupId': self.name}]
await sqs.send_messages(queue=self.sqs_queue, region=self.sqs_region, entries=entries)
else:
self.cache.pool.rpush(f'{self.name}-not-started', msg)

await self.cache.hmset(key=identifier, items=items)
result['uuid'] = identifier
except Exception as err:
result['error'] = str(err)

return result

async def start_worker(self, job_name: str, job: Any, job_timeout: float=30):
"""Continually run the worker."""
async def start_worker(self, job: Any, job_timeout: float=30):
"""Continually run the worker.
Args:
job (Any): Long running job as an async function
job_timeout (float): Job should finish before given amount of time in seconds
"""

if self.name_to_job[self.name] is not None and self.name_to_job[self.name] != job:
raise TypeError('LongRunningJob class can only be assigned to process one job!')

self.worker_active[job_name] = True
self.job_timeout = job_timeout
self.worker_active = True
while True:
while self.worker_active[job_name]:

while self.worker_active:
try:
# run job the majority of the time pulling up to 10 messages to process
for _ in range(10):
await self.__pull_messages_and_run_jobs__(job_name, job)
if self.queue_service == 'sqs':
await self.__sqs_pull_messages_and_run_jobs__(job)
else:
await self.__redis_pull_messages_and_run_jobs__(job)

# verify processing only a fraction of the time
await self.__verify_processing__(job_name, job_timeout)
if self.queue_service == 'redis':
await self.__verify_processing__()
except asyncio.CancelledError:
print(traceback.format_exc())
break
except Exception:
print(traceback.format_exc())
await asyncio.sleep(30)

await asyncio.sleep(1)

async def stop_worker(self, job_name: str):
"""Stop worker associated with job_name.
async def stop_worker(self):
"""Stop worker."""

self.worker_active = False

async def __sqs_pull_messages_and_run_jobs__(self, job: Any):
"""Pull messages one at a time and run job.
Args:
job_name (str): Name of the long running job. Used to identify between multiple jobs running within app.
job (Any): Long running job as an async function
Raises:
IOError: Redis access failed
"""

self.worker_active[job_name] = False
msg = await sqs.get_messages(
queue=self.sqs_queue,
region=self.sqs_region,
wait_time=1,
visibility_timeout=self.job_timeout
)
if not msg:
await asyncio.sleep(0.1)
else:
body = orjson.loads(msg[0]['Body'])
key = body['cache_key']

data = None if key is None else await self.cache.get(key)
if data is None:
# No results found in cache so run the job
data = await job(body['params'])

# Set the cached parameter based key with results
if key is not None and not await self.cache.set(key, data):
raise IOError(f"Setting cache string failed for cache_key: {key}")

# Update the hashed UUID with processing results
await self.cache.hmset(key=body['uuid'], items={**body, **{'results': data, 'job_done': True}})
entries = [{'Id': str(uuid4()), 'ReceiptHandle': msg[0]['ReceiptHandle']}]
await sqs.delete_messages(queue=self.sqs_queue, region=self.sqs_region, entries=entries)

async def __pull_messages_and_run_jobs__(self, job_name: str, job: Any):
async def __redis_pull_messages_and_run_jobs__(self, job: Any):
"""Pull messages one at a time and run job.
Args:
job (Any): Long running job as an async function
Raises:
IOError: Redis access failed
"""

# in the future convert lpop to lmove and also look into integrating with async aioredis
msg = self.cache.pool.lpop(f'{job_name}-not-started')
msg = self.cache.pool.lpop(f'{self.name}-not-started')
if not msg:
await asyncio.sleep(0.1)
else:
body = orjson.loads(msg)
key = body['cache_key']

# Add start time and push msg to <job_name>-in-process
# Add start time and push msg to <self.name>-in-process
body['start_time'] = time()
self.cache.pool.rpush(f'{job_name}-in-process', orjson.dumps(body).decode())
self.cache.pool.rpush(f'{self.name}-in-process', orjson.dumps(body).decode())

data = None if key is None else await self.cache.get(key)
if data is None:
@@ -151,12 +229,13 @@ async def __pull_messages_and_run_jobs__(self, job_name: str, job: Any):
# Update the hashed UUID with processing results
await self.cache.hmset(key=body['uuid'], items={**body, **{'results': data, 'job_done': True}})

async def __verify_processing__(self, job_name: str, job_timeout: float):
async def __verify_processing__(self):
"""Verify processing completed fixing issues related to app crashing or
scaling down servers."""
scaling down servers when using queue_service = 'redis'.
"""

for _ in range(10):
msg = self.cache.pool.lpop(f'{job_name}-in-process')
msg = self.cache.pool.lpop(f'{self.name}-in-process')
if not msg:
break

@@ -165,19 +244,19 @@ async def __verify_processing__(self, job_name: str, job_timeout: float):
if job_done is None:
pass # if the cache is expired then we can typically ignore doing anything
elif not job_done:
if (time() - body['start_time']) > job_timeout:
print(f'Failed processing uuid: {body["uuid"]} in {job_timeout} seconds. Pushing msg back to {job_name}-not-started.')
self.cache.pool.rpush(f'{job_name}-not-started', msg)
if (time() - body['start_time']) > self.job_timeout:
print(f'Failed processing uuid: {body["uuid"]} in {self.job_timeout} seconds. \
Pushing msg back to {self.name}-not-started.')
self.cache.pool.rpush(f'{self.name}-not-started', msg)
else:
self.cache.pool.rpush(f'{job_name}-in-process', msg)
self.cache.pool.rpush(f'{self.name}-in-process', msg)

@staticmethod
async def build_cache_key(params: Dict[str, Any], separator='|') -> str:
"""build a cache key from a dictionary object.
Concatenate and
"""Build a cache key from a dictionary object. Concatenate and
normalize key-values from an unnested dict, taking care of sorting the
keys and each of their values (if a list).
Args:
params (Dict[str, Any]): dict object to use to build cache key
separator (str, optional): character to use as a separator in the cache key. Defaults to '|'.
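The build_cache_key docstring above describes concatenating and normalizing key-values from an unnested dict, with the keys (and any list values) sorted. The method body is collapsed in this view, so the snippet below is only a rough sketch of that idea under assumed formatting, not the library's exact key format.

```python
from typing import Any, Dict


def build_cache_key_sketch(params: Dict[str, Any], separator: str = '|') -> str:
    """Approximation of the documented behavior: sort the keys, sort any list
    values, then join everything with the separator."""
    parts = []
    for key in sorted(params):
        value = params[key]
        if isinstance(value, list):
            value = sorted(value)      # normalize list ordering so equivalent params map to the same key
        parts.append(f'{key}={value}')
    return separator.join(parts)


# build_cache_key_sketch({'delay': 1, 'ids': [3, 1, 2]}) == 'delay=1|ids=[1, 2, 3]'
```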
2 changes: 1 addition & 1 deletion aioradio/requirements.txt
@@ -1,7 +1,7 @@
aioboto3==8.3.0
aiobotocore==1.3.0
aiojobs==0.3.0
ddtrace==0.47.0
ddtrace==0.48.0
flask==1.1.2
httpx==0.17.1
mandrill==1.0.60
28 changes: 14 additions & 14 deletions aioradio/tests/long_running_jobs_test.py
@@ -10,7 +10,7 @@
pytestmark = pytest.mark.asyncio


async def test_lrj_worker(github_action, lrj):
async def test_lrj_worker(github_action, lrj1, lrj2):
"""Test test_lrj_worker."""

if github_action:
@@ -33,36 +33,36 @@ def delay(delay, result):
return delay(**params)


worker1 = lrj.start_worker(job_name='job1', job=job1, job_timeout=3)
worker1 = lrj1.start_worker(job=job1, job_timeout=3)
create_task(worker1)
worker2 = lrj.start_worker(job_name='job2', job=job2, job_timeout=3)
worker2 = lrj2.start_worker(job=job2, job_timeout=3)
create_task(worker2)

params = {'delay': 1, 'result': randint(0, 100)}
result1 = await lrj.send_message(job_name='job1', params=params)
result1 = await lrj1.send_message(params=params)
assert 'uuid' in result1 and 'error' not in result1

cache_key = await lrj.build_cache_key(params=params)
result2 = await lrj.send_message(job_name='job2', params=params, cache_key=cache_key)
cache_key = await lrj2.build_cache_key(params=params)
result2 = await lrj2.send_message(params=params, cache_key=cache_key)
assert 'uuid' in result2 and 'error' not in result2

await sleep(2)
await sleep(1.5)

result = await lrj.check_job_status(result1['uuid'])
result = await lrj1.check_job_status(result1['uuid'])
assert result['job_done'] and result['results'] == params['result']

result = await lrj.check_job_status(result2['uuid'])
result = await lrj2.check_job_status(result2['uuid'])
assert result['job_done'] and result['results'] == params['result']

result3 = await lrj.send_message(job_name='job1', params=params, cache_key=cache_key)
result3 = await lrj1.send_message(params=params, cache_key=cache_key)
await sleep(0.333)
assert 'uuid' in result3 and 'error' not in result3
result = await lrj.check_job_status(result3['uuid'])
result = await lrj1.check_job_status(result3['uuid'])
assert result['job_done'] and result['results'] == params['result']

await sleep(5)
result = await lrj.check_job_status(result1['uuid'])
result = await lrj2.check_job_status(result1['uuid'])
assert not result['job_done'] and 'error' in result

await lrj.stop_worker(job_name='job1')
await lrj.stop_worker(job_name='job2')
await lrj1.stop_worker()
await lrj2.stop_worker()
25 changes: 23 additions & 2 deletions conftest.py
@@ -66,16 +66,37 @@ def cache(github_action):


@pytest.fixture(scope='module')
def lrj(github_action):
def lrj1(github_action):
"""LongRunningProcess class object."""

if github_action:
pytest.skip('Skip tests using LongRunningJobs when running via Github Action')

lrj = LongRunningJobs(
name='lrj1',
redis_host='prod-race2.gbngr1.ng.0001.use1.cache.amazonaws.com',
cache_expiration=5
cache_expiration=5,
sqs_queue='NARWHAL_QUEUE_SANDBOX.fifo',
sqs_region='us-east-1'
)

yield lrj


@pytest.fixture(scope='module')
def lrj2(github_action):
"""LongRunningProcess class object."""

if github_action:
pytest.skip('Skip tests using LongRunningJobs when running via Github Action')

lrj = LongRunningJobs(
name='lrj2',
redis_host='prod-race2.gbngr1.ng.0001.use1.cache.amazonaws.com',
cache_expiration=5,
queue_service='redis'
)

yield lrj

