Commit 104bd93
Stop stager in a clean manner
Signed-off-by: Loic Pottier <[email protected]>
lpottier committed Feb 26, 2025
1 parent 2cf6b9a commit 104bd93
Showing 2 changed files with 50 additions and 13 deletions.
14 changes: 7 additions & 7 deletions src/AMSWorkflow/ams/rmq.py
@@ -567,6 +567,7 @@ def setup_queue(self, queue_name):
         """
         self.logger.debug(f'Declaring queue "{queue_name}"')
         cb = functools.partial(self.on_queue_declareok, userdata=queue_name)
+        # arguments = {"x-consumer-timeout":1800000} # 30 minutes in ms
         self._channel.queue_declare(queue=queue_name, exclusive=False, callback=cb)

     def on_queue_declareok(self, _unused_frame, userdata):
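
The commented-out line above sketches a per-queue consumer timeout. A minimal sketch of what enabling it could look like, assuming pika's standard queue_declare signature (x-consumer-timeout is a RabbitMQ queue argument, in milliseconds):

    # Hypothetical: wiring the commented-out timeout into the declaration.
    arguments = {"x-consumer-timeout": 1800000}  # 30 minutes in ms
    self._channel.queue_declare(
        queue=queue_name,
        exclusive=False,
        arguments=arguments,
        callback=cb,
    )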
@@ -641,24 +642,23 @@ def on_consumer_cancelled(self, method_frame):
         if self._channel:
             self._channel.close()

-    def on_message(self, _unused_channel, basic_deliver, properties, body):
+    def on_message(self, _unused_channel, method_frame, properties, body):
         """Invoked by pika when a message is delivered from RabbitMQ. The
-        channel is passed for your convenience. The basic_deliver object that
+        channel is passed for your convenience. The method_frame object that
         is passed in carries the exchange, routing key, delivery tag and
         a redelivered flag for the message. The properties passed in is an
         instance of BasicProperties with the message properties and the body
         is the message that was sent.
         :param pika.channel.Channel _unused_channel: The channel object
-        :param pika.Spec.Basic.Deliver: basic_deliver method
+        :param pika.Spec.Basic.Deliver: method_frame method
         :param pika.Spec.BasicProperties: properties
         :param bytes body: The message body
         """
-        self.logger.info(f"Received message #{basic_deliver.delivery_tag} from {properties}")
-        if isinstance(self._on_message_cb, Callable):
-            self._on_message_cb(_unused_channel, basic_deliver, properties, body)
-        self.acknowledge_message(basic_deliver.delivery_tag)
+        self.logger.info(f"Received message #{method_frame.delivery_tag} from {properties}")
+        self._on_message_cb(_unused_channel, method_frame, properties, body)
+        self.acknowledge_message(method_frame.delivery_tag)

     def acknowledge_message(self, delivery_tag):
         """Acknowledge the message delivery from RabbitMQ by sending a
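
This change renames basic_deliver to method_frame and drops the isinstance(self._on_message_cb, Callable) guard: the callback is now invoked unconditionally before the message is acknowledged, so a non-callable on_message_cb fails loudly at delivery time instead of being silently skipped. A minimal sketch of a callback matching the new signature (the name handle_message is hypothetical):

    # Sketch of a callback compatible with the new on_message signature.
    def handle_message(channel, method_frame, properties, body):
        # method_frame carries the exchange, routing key, delivery tag and
        # the redelivered flag; body is the raw message bytes.
        print(f"message #{method_frame.delivery_tag} via {method_frame.routing_key}: {len(body)} bytes")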
49 changes: 43 additions & 6 deletions src/AMSWorkflow/ams/stage.py
Expand Up @@ -273,7 +273,7 @@ def __init__(
rmq_queue,
policy,
prefetch_count=1,
signals=[signal.SIGTERM, signal.SIGINT, signal.SIGUSR1],
signals=[signal.SIGINT, signal.SIGUSR1],
):
self.o_queue = o_queue
self.cert = cert
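
Note that SIGTERM is dropped from the task-level default here: per the handler comment in the next hunk, these defaults presumably belong to RMQDomainDataLoaderTask, and SIGTERM is instead claimed by the Pipeline-level handlers added below, which forward SIGINT to the running tasks.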
@@ -586,6 +586,41 @@ def __init__(self, db_dir, store, dest_dir=None, stage_dir=None, db_type="dhdf5"

         self.store = store

+        # For signal handling
+        self.released = False
+
+        self.signals = [signal.SIGINT, signal.SIGTERM, signal.SIGUSR1]
+
+    def signal_wrapper(self, name, pid):
+        def handler(signum, frame):
+            print(f"Received SIGNUM={signum} for {name}[pid={pid}]")
+            # We trigger the underlying signal handlers for all tasks.
+            # This should only trigger RMQDomainDataLoaderTask.
+
+            # TODO: I don't like this system for shutting down the pipeline
+            # on demand. It is extremely easy to mess things up with signals,
+            # and it is not a robust solution (if a task does not handle
+            # SIGINT correctly, the pipeline can explode).
+            for e in self._executors:
+                os.kill(e.pid, signal.SIGINT)
+            self.release_signals()
+        return handler
+
+    def init_signals(self):
+        self.released = False
+        self.original_handlers = {}
+        for sig in self.signals:
+            self.original_handlers[sig] = signal.getsignal(sig)
+            signal.signal(sig, self.signal_wrapper(self.__class__.__name__, os.getpid()))
+
+    def release_signals(self):
+        if not self.released:
+            # Put back all the original signal handlers
+            for sig in self.signals:
+                signal.signal(sig, self.original_handlers[sig])
+
+            self.released = True
+
     def add_user_action(self, obj):
         """
         Adds an action to be performed at the data before storing them in the filesystem
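
The three methods above implement a save/forward/restore pattern: init_signals records the original handlers and installs the wrapper; when a signal arrives, the handler forwards SIGINT to every process held in self._executors (populated by _parallel_execute below) and then release_signals restores the saved handlers so a second signal behaves normally. A self-contained sketch of the same pattern, with hypothetical names and multiprocessing workers standing in for the pipeline tasks:

    import multiprocessing as mp
    import os
    import signal
    import time

    # Illustrative sketch of save/forward/restore; not the AMSWorkflow API.

    def worker():
        try:
            while True:
                time.sleep(0.1)
        except KeyboardInterrupt:  # Python surfaces SIGINT as KeyboardInterrupt
            print(f"worker {os.getpid()} exiting cleanly")

    def main():
        procs = [mp.Process(target=worker) for _ in range(2)]
        for p in procs:
            p.start()

        originals = {}

        def forward(signum, frame):
            for p in procs:
                os.kill(p.pid, signal.SIGINT)       # forward to each worker
            for sig, handler in originals.items():  # then restore the originals
                signal.signal(sig, handler)

        for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGUSR1):
            originals[sig] = signal.getsignal(sig)  # save the original handler
            signal.signal(sig, forward)             # install the wrapper

        for p in procs:
            p.join()

    if __name__ == "__main__":
        main()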
@@ -618,15 +653,15 @@ def _parallel_execute(self, exec_vehicle_cls):
         exec_vehicle_cls: The class to be used to generate entities
         executing actions by reading data from i/o_queue(s).
         """
-        executors = list()
+        self._executors = list()
         for a in self._tasks:
-            executors.append(exec_vehicle_cls(target=a))
+            self._executors.append(exec_vehicle_cls(target=a))

-        for e in executors:
+        for e in self._executors:
             e.start()

-        print(f"{self.__class__.__name__} joining threads")
-        for e in executors:
+        print(f"{self.__class__.__name__} joining {len(self._executors)} threads")
+        for e in self._executors:
             e.join()
         print(f"{self.__class__.__name__} Threads are done")

@@ -692,10 +727,12 @@ def execute(self, policy):
                 f"Pipeline execute does not support policy: {policy}, please select from {Pipeline.supported_policies}"
             )

+        self.init_signals()
         # Create a pipeline of actions and link them with appropriate queues
         self._link_pipeline(policy)
         # Execute them
         self._execute_tasks(policy)
+        self.release_signals()

     @abstractmethod
     def requires_model_update(self):
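
execute() now brackets the run between init_signals() and release_signals(). One caveat: if _link_pipeline or _execute_tasks raises, release_signals is never reached and the wrapper handlers stay installed. A try/finally variant (hypothetical, not what the commit does) would guarantee restoration:

    self.init_signals()
    try:
        self._link_pipeline(policy)
        self._execute_tasks(policy)
    finally:
        self.release_signals()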
