Start getting callbacks working again
ashb committed Dec 9, 2024
1 parent 49bf21d commit 7940e39
Showing 7 changed files with 520 additions and 787 deletions.
102 changes: 23 additions & 79 deletions airflow/callbacks/callback_requests.py
@@ -16,49 +16,37 @@
 # under the License.
 from __future__ import annotations
 
-import json
 from typing import TYPE_CHECKING
 
+from pydantic import BaseModel
+
+from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel  # noqa: TC001
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
-    from airflow.models.taskinstance import SimpleTaskInstance
+    from airflow.typing_compat import Self
 
 
-class CallbackRequest:
+class CallbackRequest(BaseModel):
     """
     Base Class with information about the callback to be executed.
-    :param full_filepath: File Path to use to run the callback
-    :param msg: Additional Message that can be used for logging
-    :param processor_subdir: Directory used by Dag Processor when parsed the dag.
     """
 
-    def __init__(
-        self,
-        full_filepath: str,
-        processor_subdir: str | None = None,
-        msg: str | None = None,
-    ):
-        self.full_filepath = full_filepath
-        self.processor_subdir = processor_subdir
-        self.msg = msg
-
-    def __eq__(self, other):
-        if isinstance(other, self.__class__):
-            return self.__dict__ == other.__dict__
-        return NotImplemented
+    full_filepath: str
+    """File Path to use to run the callback"""
+    processor_subdir: str | None = None
+    """Directory used by Dag Processor when parsed the dag"""
+    msg: str | None = None
+    """Additional Message that can be used for logging to determine failure/zombie"""
 
-    def __repr__(self):
-        return str(self.__dict__)
-
-    def to_json(self) -> str:
-        return json.dumps(self.__dict__)
+    to_json = BaseModel.model_dump_json
 
     @classmethod
-    def from_json(cls, json_str: str):
-        json_object = json.loads(json_str)
-        return cls(**json_object)
+    def from_json(cls, data: str | bytes | bytearray) -> Self:
+        return cls.model_validate_json(data)
 
 
 class TaskCallbackRequest(CallbackRequest):
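The hunk above replaces the hand-rolled __init__/__eq__/to_json/from_json machinery with a Pydantic BaseModel. As a minimal sketch of the round trip this enables (not part of the commit; the file path and message values are invented for illustration):

from airflow.callbacks.callback_requests import CallbackRequest

request = CallbackRequest(full_filepath="/files/dags/example_dag.py", msg="task externally killed")
payload = request.to_json()                    # delegates to BaseModel.model_dump_json
restored = CallbackRequest.from_json(payload)  # delegates to BaseModel.model_validate_json
assert restored == request                     # Pydantic models compare field-by-field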
@@ -67,25 +55,12 @@ class TaskCallbackRequest(CallbackRequest):
     A Class with information about the success/failure TI callback to be executed. Currently, only failure
     callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess.
-    :param full_filepath: File Path to use to run the callback
-    :param simple_task_instance: Simplified Task Instance representation
-    :param msg: Additional Message that can be used for logging to determine failure/zombie
-    :param processor_subdir: Directory used by Dag Processor when parsed the dag.
-    :param task_callback_type: e.g. whether on success, on failure, on retry.
     """
 
-    def __init__(
-        self,
-        full_filepath: str,
-        simple_task_instance: SimpleTaskInstance,
-        processor_subdir: str | None = None,
-        msg: str | None = None,
-        task_callback_type: TaskInstanceState | None = None,
-    ):
-        super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
-        self.simple_task_instance = simple_task_instance
-        self.task_callback_type = task_callback_type
+    ti: ti_datamodel.TaskInstance
+    """Simplified Task Instance representation"""
+    task_callback_type: TaskInstanceState | None = None
+    """Whether on success, on failure, on retry"""
 
     @property
     def is_failure_callback(self) -> bool:
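A hedged sketch of how the new task_callback_type field is expected to drive the existing is_failure_callback property (its body is only partly visible in the next hunk, and is assumed to treat requests without a callback type as failure/zombie callbacks); model_construct is used here only to skip building a full ti TaskInstance payload for the illustration:

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.utils.state import TaskInstanceState

# model_construct skips validation, so the required `ti` field can be stubbed out.
request = TaskCallbackRequest.model_construct(
    full_filepath="/files/dags/example_dag.py",
    ti=None,                  # stand-in; a real request carries a TaskInstance datamodel
    task_callback_type=None,  # zombie / externally killed task
)
assert request.is_failure_callback

request.task_callback_type = TaskInstanceState.SUCCESS  # success callbacks are not failures
assert not request.is_failure_callback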
@@ -98,42 +73,11 @@ def is_failure_callback(self) -> bool:
             TaskInstanceState.UPSTREAM_FAILED,
         }
 
-    def to_json(self) -> str:
-        from airflow.serialization.serialized_objects import BaseSerialization
-
-        val = BaseSerialization.serialize(self.__dict__, strict=True)
-        return json.dumps(val)
-
-    @classmethod
-    def from_json(cls, json_str: str):
-        from airflow.serialization.serialized_objects import BaseSerialization
-
-        val = json.loads(json_str)
-        return cls(**BaseSerialization.deserialize(val))
 
 
 class DagCallbackRequest(CallbackRequest):
-    """
-    A Class with information about the success/failure DAG callback to be executed.
-    :param full_filepath: File Path to use to run the callback
-    :param dag_id: DAG ID
-    :param run_id: Run ID for the DagRun
-    :param processor_subdir: Directory used by Dag Processor when parsed the dag.
-    :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
-    :param msg: Additional Message that can be used for logging
-    """
+    """A Class with information about the success/failure DAG callback to be executed."""
 
-    def __init__(
-        self,
-        full_filepath: str,
-        dag_id: str,
-        run_id: str,
-        processor_subdir: str | None,
-        is_failure_callback: bool | None = True,
-        msg: str | None = None,
-    ):
-        super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
-        self.dag_id = dag_id
-        self.run_id = run_id
-        self.is_failure_callback = is_failure_callback
+    dag_id: str
+    run_id: str
+    is_failure_callback: bool | None = True
+    """Flag to determine whether it is a Failure Callback or Success Callback"""