Skip to content

Commit

Permalink
add new Celery signature to mark a run as complete
Browse files Browse the repository at this point in the history
  • Loading branch information
bgunnar5 committed Feb 17, 2025
1 parent e48cf05 commit 8ee2aff
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
29 changes: 29 additions & 0 deletions merlin/common/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from merlin.common.sample_index import uniform_directories
from merlin.common.sample_index_factory import create_hierarchy
from merlin.config.utils import Priority, get_priority
from merlin.db_scripts.db_interaction import MerlinDatabase
from merlin.exceptions import HardFailException, InvalidChainException, RestartException, RetryException
from merlin.router import stop_workers
from merlin.spec.expansion import parameter_substitutions_for_cmd, parameter_substitutions_for_sample
Expand Down Expand Up @@ -739,6 +740,29 @@ def chordfinisher(*args, **kwargs): # pylint: disable=W0613
return "SYNC"


@shared_task(
autoretry_for=retry_exceptions,
retry_backoff=True,
name="merlin:mark_run_as_complete",
priority=get_priority(Priority.LOW),
)
def mark_run_as_complete(study_workspace: str) -> str:
"""
Mark this run as complete and save that to the database.
Args:
study_workspace: The output workspace for this run.
Returns:
A string denoting that this run has completed.
"""
merlin_db = MerlinDatabase()
db_run = merlin_db.get_run_from_workspace(study_workspace)
db_run.run_complete = True
db_run.save()
return "Run Completed"


@shared_task(
autoretry_for=retry_exceptions,
retry_backoff=True,
Expand Down Expand Up @@ -777,5 +801,10 @@ def queue_merlin_study(study, adapter):
)
for chain_group in groups_of_chains[1:]
)

# Append the final task that marks the run as complete
final_task = mark_run_as_complete.si(study.workspace).set(queue=egraph.step(groups_of_chains[1][0][0]).get_task_queue())
celery_dag = celery_dag | final_task

LOG.info("Launching tasks.")
return celery_dag.delay(None)
17 changes: 16 additions & 1 deletion merlin/db_scripts/db_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def create_run(self, study_name: str, workspace: str, queues: List[str], *args,

return db_study.create_run(workspace=workspace, queues=queues, *args, **kwargs)

def get_run(self, run_id: str):
def get_run(self, run_id: str) -> DatabaseRun:
"""
Given a run id, retrieve the associated run from the database.
Expand All @@ -169,6 +169,21 @@ def get_run(self, run_id: str):
"""
return DatabaseRun.load(run_id, self.backend)

def get_run_from_workspace(self, workspace: str) -> DatabaseRun:
"""
Given an output workspace for a run, find the run metadata file and load
the run from it.
Args:
workspace: The output workspace for a run.
Returns:
A [`DatabaseRun`][merlin.db_scripts.db_run.DatabaseRun] instance representing
the run that was queried.
"""
run_metadata_filepath = DatabaseRun.get_metadata_filepath(workspace)
return DatabaseRun.load_from_metadata_file(run_metadata_filepath, self.backend)

def get_all_runs(self) -> List[DatabaseRun]:
"""
Get every run that's currently in the database.
Expand Down
1 change: 0 additions & 1 deletion merlin/db_scripts/db_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def run_complete(self, value: bool):
value: The completion status of the run.
"""
self.run_info.run_complete = value
self.save()

def get_metadata_file(self) -> str:
"""
Expand Down
10 changes: 10 additions & 0 deletions merlin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

from merlin import VERSION, router
from merlin.ascii_art import banner_small
from merlin.db_scripts.db_interaction import MerlinDatabase
from merlin.examples.generator import list_examples, setup_example
from merlin.log_formatter import setup_logging
from merlin.server.server_commands import config_server, init_server, restart_server, start_server, status_server, stop_server
Expand Down Expand Up @@ -161,6 +162,15 @@ def process_run(args: Namespace) -> None:
pgen_file=args.pgen_file,
pargs=args.pargs,
)

merlin_db = MerlinDatabase()
db_run = merlin_db.create_run(
study_name=study.expanded_spec.name,
workspace=study.workspace,
queues=study.expanded_spec.get_queue_list(["all"]),
workers=study.expanded_spec.get_worker_names(),
)

router.run_task_server(study, args.run_mode)


Expand Down

0 comments on commit 8ee2aff

Please sign in to comment.