From b210fcd4ac13b723e5eadcaab7c4897d01d05af1 Mon Sep 17 00:00:00 2001 From: tasgon Date: Fri, 5 Jan 2024 19:27:47 -0500 Subject: [PATCH] add subscribers and use pipelines --- asyncio_redis_queues.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/asyncio_redis_queues.py b/asyncio_redis_queues.py index c96d48e..11e3d86 100644 --- a/asyncio_redis_queues.py +++ b/asyncio_redis_queues.py @@ -6,7 +6,7 @@ import dotenv import asyncio import redis.asyncio as redis -from typing import Union, Dict +from typing import Union, Dict, Callable, List dotenv.load_dotenv() @@ -32,13 +32,14 @@ async def get_status(self): return await self.queue.redis_client.hget(JOB_STATUS_NAME, self.id) async def notify(self, payload, status): - return await asyncio.gather( - self.queue.redis_client.hset(JOB_STATUS_NAME, self.id, json.dumps(status)), - self.queue.redis_client.publish(CHANNEL_NAME, json.dumps({ - "job_id": self.id, - "payload": payload - })), - ) + pipe = self.queue.redis_client.pipeline() + pipe.hset(JOB_STATUS_NAME, self.id, json.dumps(status)) + pipe.publish(CHANNEL_NAME, json.dumps({ + "job_id": self.id, + "payload": payload + })) + await pipe.execute() + def __repr__(self) -> str: return f"Job(id: {self.id}, queue: {self.queue})" @@ -55,16 +56,19 @@ def __init__(self, redis_client: redis.Redis): self.redis_client = redis_client self.listeners: Dict[str, asyncio.Future] = {} self.loop_task = asyncio.get_event_loop().create_task(self.loop()) + self.callbacks: List[Callable[[object]]] = [] async def loop(self): self.pubsub = self.redis_client.pubsub() await self.pubsub.subscribe(CHANNEL_NAME) - while True: + while asyncio.get_event_loop().is_running(): try: message = await self.pubsub.get_message(ignore_subscribe_messages=True) if message is None: continue data = message["data"].decode("utf-8") data = json.loads(data) + for callback in self.callbacks: + callback(data["job_id"], data["payload"]) if data["job_id"] in self.listeners: job_id = data["job_id"] self.listeners[job_id].set_result(data["payload"])