Skip to content

Commit

Permalink
🔨 Adding a job counter to address Semaphore issues (#408)
Browse files Browse the repository at this point in the history
* 🔨 Adding a job counter to address Semaphore issues

* 🧪 Test function for semaphore blocker
  • Loading branch information
rm-21 authored Oct 30, 2023
1 parent 9109c2e commit ab2dda2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
25 changes: 21 additions & 4 deletions arq/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,11 @@ def __init__(
self.on_job_start = on_job_start
self.on_job_end = on_job_end
self.after_job_end = after_job_end
self.sem = asyncio.BoundedSemaphore(max_jobs)

self.max_jobs = max_jobs
self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
self.job_counter: int = 0

self.job_timeout_s = to_seconds(job_timeout)
self.keep_result_s = to_seconds(keep_result)
self.keep_result_forever = keep_result_forever
Expand Down Expand Up @@ -374,13 +378,13 @@ async def _poll_iteration(self) -> None:
return
count = min(burst_jobs_remaining, count)
if self.allow_pick_jobs:
async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs
if self.job_counter < self.max_jobs:
now = timestamp_ms()
job_ids = await self.pool.zrangebyscore(
self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
)

await self.start_jobs(job_ids)
await self.start_jobs(job_ids)

if self.allow_abort_jobs:
await self._cancel_aborted_jobs()
Expand Down Expand Up @@ -419,12 +423,23 @@ async def _cancel_aborted_jobs(self) -> None:
self.aborting_tasks.update(aborted)
await self.pool.zrem(abort_jobs_ss, *aborted)

def _release_sem_dec_counter_on_complete(self) -> None:
self.job_counter = self.job_counter - 1
self.sem.release()

async def start_jobs(self, job_ids: List[bytes]) -> None:
"""
For each job id, get the job definition, check it's not running and start it in a task
"""
for job_id_b in job_ids:
await self.sem.acquire()

if self.job_counter >= self.max_jobs:
self.sem.release()
return None

self.job_counter = self.job_counter + 1

job_id = job_id_b.decode()
in_progress_key = in_progress_key_prefix + job_id
async with self.pool.pipeline(transaction=True) as pipe:
Expand All @@ -433,6 +448,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
score = await pipe.zscore(self.queue_name, job_id)
if ongoing_exists or not score:
# job already started elsewhere, or already finished and removed from queue
self.job_counter = self.job_counter - 1
self.sem.release()
logger.debug('job %s already running elsewhere', job_id)
continue
Expand All @@ -445,11 +461,12 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
await pipe.execute()
except (ResponseError, WatchError):
# job already started elsewhere since we got 'existing'
self.job_counter = self.job_counter - 1
self.sem.release()
logger.debug('multi-exec error, job %s already started elsewhere', job_id)
else:
t = self.loop.create_task(self.run_job(job_id, int(score)))
t.add_done_callback(lambda _: self.sem.release())
t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete())
self.tasks[job_id] = t

async def run_job(self, job_id: str, score: int) -> None: # noqa: C901
Expand Down
30 changes: 30 additions & 0 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,36 @@ async def test(ctx):
assert result['called'] == 4


async def test_job_cancel_on_max_jobs(arq_redis: ArqRedis, worker, caplog):
async def longfunc(ctx):
await asyncio.sleep(3600)

async def wait_and_abort(job, delay=0.1):
await asyncio.sleep(delay)
assert await job.abort() is True

caplog.set_level(logging.INFO)
await arq_redis.zadd(abort_jobs_ss, {b'foobar': int(1e9)})
job = await arq_redis.enqueue_job('longfunc', _job_id='testing')

worker: Worker = worker(
functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1, max_jobs=1
)
assert worker.jobs_complete == 0
assert worker.jobs_failed == 0
assert worker.jobs_retried == 0
await asyncio.gather(wait_and_abort(job), worker.main())
await worker.main()
assert worker.jobs_complete == 0
assert worker.jobs_failed == 1
assert worker.jobs_retried == 0
log = re.sub(r'\d+.\d\ds', 'X.XXs', '\n'.join(r.message for r in caplog.records))
assert 'X.XXs → testing:longfunc()\n X.XXs ⊘ testing:longfunc aborted' in log
assert worker.aborting_tasks == set()
assert worker.tasks == {}
assert worker.job_tasks == {}


async def test_worker_timezone_defaults_to_system_timezone(worker):
worker = worker(functions=[func(foobar)])
assert worker.timezone is not None
Expand Down

0 comments on commit ab2dda2

Please sign in to comment.