From dea1b75d2dad73f077a5934d05b221a88479d13e Mon Sep 17 00:00:00 2001 From: Sylvain Leclerc Date: Fri, 6 Dec 2024 17:35:46 +0100 Subject: [PATCH] fix(tasks): frozen tasks with pgpool (#2264) This work aims at fixing issues encountered with load balanced pgpool, where a commit is not always instantly visible in another session (when read from replica). Fixes are: - restore a workaround which consists in re-trying the read until success - ensure exceptions occurring in the task execution thread are ALL caught Also removing a technical debt: - remove optional session from task job repo which was just used in tests. It did not work with code which uses both the repo and also the session from the singleton db.session. Signed-off-by: Sylvain Leclerc --- antarest/core/tasks/repository.py | 15 +---- antarest/core/tasks/service.py | 105 +++++++++++++++++------------- antarest/core/utils/utils.py | 2 +- tests/core/test_tasks.py | 46 ++++++------- 4 files changed, 85 insertions(+), 83 deletions(-) diff --git a/antarest/core/tasks/repository.py b/antarest/core/tasks/repository.py index 0a2028db33..74c9e84b78 100644 --- a/antarest/core/tasks/repository.py +++ b/antarest/core/tasks/repository.py @@ -27,15 +27,6 @@ class TaskJobRepository: Database connector to manage Tasks/Jobs entities. """ - def __init__(self, session: t.Optional[Session] = None): - """ - Initialize the repository. - - Args: - session: Optional SQLAlchemy session to be used. - """ - self._session = session - @property def session(self) -> Session: """ @@ -44,11 +35,7 @@ def session(self) -> Session: Returns: SQLAlchemy session. """ - if self._session is None: - # Get or create the session from a context variable (thread local variable) - return db.session - # Get the user-defined session - return self._session + return db.session def save(self, task: TaskJob) -> TaskJob: session = self.session diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index f992e227c6..a97a2fedf5 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -38,6 +38,7 @@ ) from antarest.core.tasks.repository import TaskJobRepository from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.core.utils.utils import retry from antarest.worker.worker import WorkerTaskCommand, WorkerTaskResult logger = logging.getLogger(__name__) @@ -390,35 +391,41 @@ def _run_task( task_id: str, custom_event_messages: t.Optional[CustomTaskEventMessages] = None, ) -> None: - # attention: this function is executed in a thread, not in the main process - with db(): - task = db.session.query(TaskJob).get(task_id) - task_type = task.type - study_id = task.ref_id + # We need to catch all exceptions so that the calling thread is guaranteed + # to not die + try: + # attention: this function is executed in a thread, not in the main process + with db(): + # Important to keep this retry for now, + # in case commit is not visible (read from replica ...) + task = retry(lambda: self.repo.get_or_raise(task_id)) + task_type = task.type + study_id = task.ref_id - self.event_bus.push( - Event( - type=EventType.TASK_RUNNING, - payload=TaskEventPayload( - id=task_id, - message=custom_event_messages.running - if custom_event_messages is not None - else f"Task {task_id} is running", - type=task_type, - study_id=study_id, - ).model_dump(), - permissions=PermissionInfo(public_mode=PublicMode.READ), - channel=EventChannelDirectory.TASK + task_id, + self.event_bus.push( + Event( + type=EventType.TASK_RUNNING, + payload=TaskEventPayload( + id=task_id, + message=custom_event_messages.running + if custom_event_messages is not None + else f"Task {task_id} is running", + type=task_type, + study_id=study_id, + ).model_dump(), + permissions=PermissionInfo(public_mode=PublicMode.READ), + channel=EventChannelDirectory.TASK + task_id, + ) ) - ) - logger.info(f"Starting task {task_id}") - with db(): - db.session.query(TaskJob).filter(TaskJob.id == task_id).update({TaskJob.status: TaskStatus.RUNNING.value}) - db.session.commit() - logger.info(f"Task {task_id} set to RUNNING") + logger.info(f"Starting task {task_id}") + with db(): + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + {TaskJob.status: TaskStatus.RUNNING.value} + ) + db.session.commit() + logger.info(f"Task {task_id} set to RUNNING") - try: with db(): # We must use the DB session attached to the current thread result = callback(TaskLogAndProgressRecorder(task_id, db.session, self.event_bus)) @@ -463,29 +470,35 @@ def _run_task( err_msg = f"Task {task_id} failed: Unhandled exception {exc}" logger.error(err_msg, exc_info=exc) - with db(): - result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback." - db.session.query(TaskJob).filter(TaskJob.id == task_id).update( - { - TaskJob.status: TaskStatus.FAILED.value, - TaskJob.result_msg: result_msg, - TaskJob.result_status: False, - TaskJob.completion_date: datetime.datetime.utcnow(), - } + try: + with db(): + result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback." + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: TaskStatus.FAILED.value, + TaskJob.result_msg: result_msg, + TaskJob.result_status: False, + TaskJob.completion_date: datetime.datetime.utcnow(), + } + ) + db.session.commit() + + message = err_msg if custom_event_messages is None else custom_event_messages.end + self.event_bus.push( + Event( + type=EventType.TASK_FAILED, + payload=TaskEventPayload( + id=task_id, message=message, type=task_type, study_id=study_id + ).model_dump(), + permissions=PermissionInfo(public_mode=PublicMode.READ), + channel=EventChannelDirectory.TASK + task_id, + ) ) - db.session.commit() - - message = err_msg if custom_event_messages is None else custom_event_messages.end - self.event_bus.push( - Event( - type=EventType.TASK_FAILED, - payload=TaskEventPayload( - id=task_id, message=message, type=task_type, study_id=study_id - ).model_dump(), - permissions=PermissionInfo(public_mode=PublicMode.READ), - channel=EventChannelDirectory.TASK + task_id, + except Exception as inner_exc: + logger.error( + f"An exception occurred while handling execution error of task {task_id}: {inner_exc}", + exc_info=inner_exc, ) - ) def get_task_progress(self, task_id: str, params: RequestParameters) -> t.Optional[int]: task = self.repo.get_or_raise(task_id) diff --git a/antarest/core/utils/utils.py b/antarest/core/utils/utils.py index c748420549..2940db582a 100644 --- a/antarest/core/utils/utils.py +++ b/antarest/core/utils/utils.py @@ -105,7 +105,7 @@ def retry(func: t.Callable[[], T], attempts: int = 10, interval: float = 0.5) -> attempt += 1 return func() except Exception as e: - logger.info(f"💤 Sleeping {interval} second(s)...") + logger.info(f"💤 Sleeping {interval} second(s) before retry...", exc_info=e) time.sleep(interval) caught_exception = e raise caught_exception or ShouldNotHappenException() diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 139db56691..5a8b490047 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -204,21 +204,22 @@ def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: return TaskResult(success=True, message="") -def test_repository(db_session: Session) -> None: +@with_db_context +def test_repository() -> None: # Prepare two users in the database user1_id = 9 - db_session.add(User(id=user1_id, name="John")) + db.session.add(User(id=user1_id, name="John")) user2_id = 10 - db_session.add(User(id=user2_id, name="Jane")) - db_session.commit() + db.session.add(User(id=user2_id, name="Jane")) + db.session.commit() # Create a RawStudy in the database study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" - db_session.add(RawStudy(id=study_id, name="foo", version="860")) - db_session.commit() + db.session.add(RawStudy(id=study_id, name="foo", version="860")) + db.session.commit() # Create a TaskJobService - task_job_repo = TaskJobRepository(db_session) + task_job_repo = TaskJobRepository() new_task = TaskJob(name="foo", owner_id=user1_id, type=TaskType.COPY) @@ -282,10 +283,10 @@ def test_repository(db_session: Session) -> None: assert len(new_task.logs) == 2 assert new_task.logs[0].message == "hello" - assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 + assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 task_job_repo.delete(new_task.id) - assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0 + assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0 assert task_job_repo.get(new_task.id) is None @@ -390,21 +391,22 @@ def test_cancel_orphan_tasks( assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds -def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None: +@with_db_context +def test_get_progress(admin_user: JWTUser, core_config: Config, event_bus: IEventBus) -> None: # Prepare two users in the database user1_id = 9 - db_session.add(User(id=user1_id, name="John")) + db.session.add(User(id=user1_id, name="John")) user2_id = 10 - db_session.add(User(id=user2_id, name="Jane")) - db_session.commit() + db.session.add(User(id=user2_id, name="Jane")) + db.session.commit() # Create a RawStudy in the database study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" - db_session.add(RawStudy(id=study_id, name="foo", version="860")) - db_session.commit() + db.session.add(RawStudy(id=study_id, name="foo", version="860")) + db.session.commit() # Create a TaskJobService - task_job_repo = TaskJobRepository(db_session) + task_job_repo = TaskJobRepository() # User 1 launches a ts generation first_task = TaskJob( @@ -451,12 +453,12 @@ def test_get_progress(db_session: Session, admin_user: JWTUser, core_config: Con service.get_task_progress(wrong_id, RequestParameters(user)) +@with_db_context def test_ts_generation_task( tmp_path: Path, core_config: Config, admin_user: JWTUser, raw_study_service: RawStudyService, - db_session: Session, ) -> None: # ======================= # SET UP @@ -465,7 +467,7 @@ def test_ts_generation_task( event_bus = DummyEventBusService() # Create a TaskJobService and add tasks - task_job_repo = TaskJobRepository(db_session) + task_job_repo = TaskJobRepository() # Create a TaskJobService task_job_service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus) @@ -474,8 +476,8 @@ def test_ts_generation_task( raw_study_path = tmp_path / "study" regular_user = User(id=99, name="regular") - db_session.add(regular_user) - db_session.commit() + db.session.add(regular_user) + db.session.commit() raw_study = RawStudy( id="my_raw_study", @@ -490,8 +492,8 @@ def test_ts_generation_task( path=str(raw_study_path), ) study_metadata_repository = StudyMetadataRepository(Mock(), None) - db_session.add(raw_study) - db_session.commit() + db.session.add(raw_study) + db.session.commit() # Set up the Raw Study raw_study_service.create(raw_study)