diff --git a/airflow/api_connexion/schemas/trigger_schema.py b/airflow/api_connexion/schemas/trigger_schema.py index 15d180a5732ff..ad527f2110c92 100644 --- a/airflow/api_connexion/schemas/trigger_schema.py +++ b/airflow/api_connexion/schemas/trigger_schema.py @@ -23,7 +23,7 @@ class TriggerSchema(SQLAlchemySchema): - """Sla Miss Schema.""" + """Trigger Schema.""" class Meta: """Meta.""" @@ -35,3 +35,4 @@ class Meta: kwargs = auto_field(dump_only=True) created_date = auto_field(dump_only=True) triggerer_id = auto_field(dump_only=True) + queue = auto_field(dump_only=True) diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py index e5a57f4552b3d..106c9e1925f4e 100644 --- a/airflow/cli/cli_config.py +++ b/airflow/cli/cli_config.py @@ -978,6 +978,11 @@ def string_lower_type(val): type=positive_int(allow_zero=False), help="The maximum number of triggers that a Triggerer will run at one time.", ) +ARG_QUEUES = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to serve", + default=conf.get("triggerer", "default_queue"), +) # reserialize ARG_CLEAR_ONLY = Arg( @@ -1966,6 +1971,7 @@ class GroupCommand(NamedTuple): ARG_CAPACITY, ARG_VERBOSE, ARG_SKIP_SERVE_LOGS, + ARG_QUEUES, ), ), ActionCommand( diff --git a/airflow/cli/commands/triggerer_command.py b/airflow/cli/commands/triggerer_command.py index 3479480dbf8ac..92b6873e0357c 100644 --- a/airflow/cli/commands/triggerer_command.py +++ b/airflow/cli/commands/triggerer_command.py @@ -47,9 +47,13 @@ def _serve_logs(skip_serve_logs: bool = False) -> Generator[None, None, None]: sub_proc.terminate() -def triggerer_run(skip_serve_logs: bool, capacity: int, triggerer_heartrate: float): +def triggerer_run(skip_serve_logs: bool, capacity: int, queues: str, triggerer_heartrate: float): with _serve_logs(skip_serve_logs): - triggerer_job_runner = TriggererJobRunner(job=Job(heartrate=triggerer_heartrate), capacity=capacity) + triggerer_job_runner = TriggererJobRunner( + job=Job(heartrate=triggerer_heartrate), + capacity=capacity, + queues=queues, + ) run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute) @@ -64,6 +68,6 @@ def triggerer(args): run_command_with_daemon_option( args=args, process_name="triggerer", - callback=lambda: triggerer_run(args.skip_serve_logs, args.capacity, triggerer_heartrate), + callback=lambda: triggerer_run(args.skip_serve_logs, args.capacity, args.queues, triggerer_heartrate), should_setup_logging=True, ) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 0930fada76714..79011e4295af7 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -2389,6 +2389,13 @@ triggerer: type: float example: ~ default: "30" + default_queue: + description: | + Default queue that trigger get assigned to and that triggerer listen on. + version_added: 2.9.0 + type: string + example: ~ + default: "default" kerberos: description: ~ options: diff --git a/airflow/jobs/triggerer_job_runner.py b/airflow/jobs/triggerer_job_runner.py index bb151b32cc87e..356d952c85422 100644 --- a/airflow/jobs/triggerer_job_runner.py +++ b/airflow/jobs/triggerer_job_runner.py @@ -251,6 +251,7 @@ def __init__( self, job: Job, capacity=None, + queues=None, ): super().__init__(job) if capacity is None: @@ -260,6 +261,16 @@ def __init__( else: raise ValueError(f"Capacity number {capacity} is invalid") + err = f"Comma delimited list of queues {queues} is invalid" + if queues is None: + self.queues = {conf.get("triggerer", "default_queue")} + elif isinstance(queues, str) and len(queues) > 0: + self.queues = {s.strip() for s in queues.split(",")} + if not self.queues: + raise ValueError(err) + else: + raise ValueError(err) + self.health_check_threshold = conf.getint("triggerer", "triggerer_health_check_threshold") should_queue = True @@ -372,8 +383,8 @@ def _run_trigger_loop(self) -> None: def load_triggers(self): """Query the database for the triggers we're supposed to be running and update the runner.""" - Trigger.assign_unassigned(self.job.id, self.capacity, self.health_check_threshold) - ids = Trigger.ids_for_triggerer(self.job.id) + Trigger.assign_unassigned(self.job.id, self.capacity, self.queues, self.health_check_threshold) + ids = Trigger.ids_for_triggerer(self.job.id, self.queues) self.trigger_runner.update_triggers(set(ids)) def handle_events(self): diff --git a/airflow/migrations/versions/0133_2_9_0_add_the_queue_field_to_the_trigger_table.py b/airflow/migrations/versions/0133_2_9_0_add_the_queue_field_to_the_trigger_table.py new file mode 100644 index 0000000000000..fe3baf8b85e42 --- /dev/null +++ b/airflow/migrations/versions/0133_2_9_0_add_the_queue_field_to_the_trigger_table.py @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add the queue field to the trigger table. + +Revision ID: e8b1fd3c7ccf +Revises: 10b52ebd31f7 +Create Date: 2023-12-20 21:06:00.768844 + +""" + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision = 'e8b1fd3c7ccf' +down_revision = '10b52ebd31f7' +branch_labels = None +depends_on = None +airflow_version = "2.9.0" + + +def upgrade(): + """Apply Add the queue field to the trigger table.""" + with op.batch_alter_table("trigger") as batch_op: + try: + from airflow.configuration import conf + + default_queue = conf.get("triggerer", "default_queue") + except: # noqa + default_queue = "default" + + batch_op.add_column( + sa.Column( + "queue", + sa.String(256), + nullable=True, + server_default=default_queue, + ) + ) + batch_op.alter_column("queue", server_default=None) + + +def downgrade(): + """Unapply Add the queue field to the trigger table.""" + with op.batch_alter_table("trigger") as batch_op: + batch_op.drop_column("queue", mssql_drop_default=True) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ae6b1e35c1276..228194d96792b 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2522,7 +2522,7 @@ def _defer_task(self, session: Session, defer: TaskDeferred) -> None: from airflow.models.trigger import Trigger # First, make the trigger entry - trigger_row = Trigger.from_object(defer.trigger) + trigger_row = Trigger.from_object(defer.trigger, defer.trigger.queue) session.add(trigger_row) session.flush() diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index 4ad42b17b8fc7..932b3080a3f35 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -65,6 +65,7 @@ class Trigger(Base): kwargs = Column(ExtendedJSON, nullable=False) created_date = Column(UtcDateTime, nullable=False) triggerer_id = Column(Integer, nullable=True) + queue = Column(String(256), nullable=True) triggerer_job = relationship( "Job", @@ -79,19 +80,21 @@ def __init__( self, classpath: str, kwargs: dict[str, Any], + queue: str, created_date: datetime.datetime | None = None, ) -> None: super().__init__() self.classpath = classpath self.kwargs = kwargs + self.queue = queue self.created_date = created_date or timezone.utcnow() @classmethod @internal_api_call - def from_object(cls, trigger: BaseTrigger) -> Trigger: + def from_object(cls, trigger: BaseTrigger, queue: str) -> Trigger: """Alternative constructor that creates a trigger row based directly off of a Trigger object.""" classpath, kwargs = trigger.serialize() - return cls(classpath=classpath, kwargs=kwargs) + return cls(classpath=classpath, kwargs=kwargs, queue=queue) @classmethod @internal_api_call @@ -197,15 +200,17 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> @classmethod @internal_api_call @provide_session - def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]: + def ids_for_triggerer(cls, triggerer_id, queues: set, session: Session = NEW_SESSION) -> list[int]: """Retrieve a list of triggerer_ids.""" - return session.scalars(select(cls.id).where(cls.triggerer_id == triggerer_id)).all() + return session.scalars( + select(cls.id).where(cls.triggerer_id == triggerer_id, cls.queue.in_(queues)) + ).all() @classmethod @internal_api_call @provide_session def assign_unassigned( - cls, triggerer_id, capacity, health_check_threshold, session: Session = NEW_SESSION + cls, triggerer_id, capacity, queues: set, health_check_threshold, session: Session = NEW_SESSION ) -> None: """ Assign unassigned triggers based on a number of conditions. @@ -233,7 +238,7 @@ def assign_unassigned( # Find triggers who do NOT have an alive triggerer_id, and then assign # up to `capacity` of those to us. trigger_ids_query = cls.get_sorted_triggers( - capacity=capacity, alive_triggerer_ids=alive_triggerer_ids, session=session + capacity=capacity, queues=queues, alive_triggerer_ids=alive_triggerer_ids, session=session ) if trigger_ids_query: session.execute( @@ -246,11 +251,14 @@ def assign_unassigned( session.commit() @classmethod - def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session): + def get_sorted_triggers(cls, capacity, queues, alive_triggerer_ids, session): query = with_row_locks( select(cls.id) .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=False) - .where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids))) + .where( + cls.queue.in_(queues), + or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)), + ) .order_by(coalesce(TaskInstance.priority_weight, 0).desc(), cls.created_date) .limit(capacity), session, diff --git a/airflow/sensors/time_sensor.py b/airflow/sensors/time_sensor.py index cc07323ca1da9..90e426cd95706 100644 --- a/airflow/sensors/time_sensor.py +++ b/airflow/sensors/time_sensor.py @@ -62,9 +62,10 @@ class TimeSensorAsync(BaseSensorOperator): :ref:`howto/operator:TimeSensorAsync` """ - def __init__(self, *, target_time, **kwargs): + def __init__(self, *, target_time, trigger_queue=None, **kwargs): super().__init__(**kwargs) self.target_time = target_time + self.trigger_queue = trigger_queue aware_time = timezone.coerce_datetime( datetime.datetime.combine(datetime.datetime.today(), self.target_time, self.dag.timezone) @@ -73,7 +74,7 @@ def __init__(self, *, target_time, **kwargs): self.target_datetime = timezone.convert_to_utc(aware_time) def execute(self, context: Context): - trigger = DateTimeTrigger(moment=self.target_datetime) + trigger = DateTimeTrigger(moment=self.target_datetime, queue=self.trigger_queue) self.defer( trigger=trigger, method_name="execute_complete", diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index 0d239af0cafd4..5457f549e6421 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -19,6 +19,7 @@ import abc from typing import Any, AsyncIterator +from airflow.configuration import conf from airflow.utils.log.logging_mixin import LoggingMixin @@ -36,7 +37,14 @@ class BaseTrigger(abc.ABC, LoggingMixin): let them be re-instantiated elsewhere. """ - def __init__(self, **kwargs): + def __init__(self, queue, **kwargs): + if queue is None: + self.queue = conf.get("triggerer", "default_queue") + elif isinstance(queue, str) and len(queue) > 0 and "," not in queue: + self.queue = queue + else: + raise ValueError(f"The trigger queue {queue} is invalid") + # these values are set by triggerer when preparing to run the instance # when run, they are injected into logger record. self.task_instance = None diff --git a/airflow/triggers/temporal.py b/airflow/triggers/temporal.py index 18bdd80bff385..296b6aeedc610 100644 --- a/airflow/triggers/temporal.py +++ b/airflow/triggers/temporal.py @@ -34,8 +34,8 @@ class DateTimeTrigger(BaseTrigger): The provided datetime MUST be in UTC. """ - def __init__(self, moment: datetime.datetime): - super().__init__() + def __init__(self, moment: datetime.datetime, queue: str | None = None): + super().__init__(queue=queue) if not isinstance(moment, datetime.datetime): raise TypeError(f"Expected datetime.datetime type for moment. Got {type(moment)}") # Make sure it's in UTC @@ -84,5 +84,5 @@ class TimeDeltaTrigger(DateTimeTrigger): DateTimeTrigger class, since they're operationally the same. """ - def __init__(self, delta: datetime.timedelta): - super().__init__(moment=timezone.utcnow() + delta) + def __init__(self, delta: datetime.timedelta, queue: str | None = None): + super().__init__(moment=timezone.utcnow() + delta, queue=queue) diff --git a/airflow/www/static/js/dag/details/taskInstance/Details.tsx b/airflow/www/static/js/dag/details/taskInstance/Details.tsx index 5cf1c5cfb99b5..a1d24eca01038 100644 --- a/airflow/www/static/js/dag/details/taskInstance/Details.tsx +++ b/airflow/www/static/js/dag/details/taskInstance/Details.tsx @@ -99,6 +99,10 @@ const Details = ({ instance, group, dagId }: Props) => {