Add queues for triggerers.
avkirilishin committed Dec 20, 2023
1 parent e2393ee commit 55f9475
Showing 16 changed files with 565 additions and 446 deletions.
3 changes: 2 additions & 1 deletion airflow/api_connexion/schemas/trigger_schema.py
@@ -23,7 +23,7 @@


class TriggerSchema(SQLAlchemySchema):
"""Sla Miss Schema."""
"""Trigger Schema."""

class Meta:
"""Meta."""
@@ -35,3 +35,4 @@ class Meta:
kwargs = auto_field(dump_only=True)
created_date = auto_field(dump_only=True)
triggerer_id = auto_field(dump_only=True)
queue = auto_field(dump_only=True)
6 changes: 6 additions & 0 deletions airflow/cli/cli_config.py
@@ -978,6 +978,11 @@ def string_lower_type(val):
type=positive_int(allow_zero=False),
help="The maximum number of triggers that a Triggerer will run at one time.",
)
ARG_QUEUES = Arg(
("-q", "--queues"),
help="Comma delimited list of queues to serve",
default=conf.get("triggerer", "default_queue"),
)

# reserialize
ARG_CLEAR_ONLY = Arg(
@@ -1966,6 +1971,7 @@ class GroupCommand(NamedTuple):
ARG_CAPACITY,
ARG_VERBOSE,
ARG_SKIP_SERVE_LOGS,
ARG_QUEUES,
),
),
ActionCommand(
10 changes: 7 additions & 3 deletions airflow/cli/commands/triggerer_command.py
@@ -47,9 +47,13 @@ def _serve_logs(skip_serve_logs: bool = False) -> Generator[None, None, None]:
sub_proc.terminate()


def triggerer_run(skip_serve_logs: bool, capacity: int, triggerer_heartrate: float):
def triggerer_run(skip_serve_logs: bool, capacity: int, queues: str, triggerer_heartrate: float):
with _serve_logs(skip_serve_logs):
triggerer_job_runner = TriggererJobRunner(job=Job(heartrate=triggerer_heartrate), capacity=capacity)
triggerer_job_runner = TriggererJobRunner(
job=Job(heartrate=triggerer_heartrate),
capacity=capacity,
queues=queues,
)
run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)


@@ -64,6 +68,6 @@ def triggerer(args):
run_command_with_daemon_option(
args=args,
process_name="triggerer",
callback=lambda: triggerer_run(args.skip_serve_logs, args.capacity, triggerer_heartrate),
callback=lambda: triggerer_run(args.skip_serve_logs, args.capacity, args.queues, triggerer_heartrate),
should_setup_logging=True,
)
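
With this change a triggerer process can be limited to specific queues from the command line, for example: airflow triggerer --queues high_priority,backfill (the -q/--queues flag added above, defaulting to the [triggerer] default_queue setting). A minimal sketch of what the daemon callback ends up doing; the helper name, heartrate, and capacity values below are illustrative, not part of the commit:

# Sketch only: mirrors the call chain in triggerer_command.triggerer_run above.
from __future__ import annotations

from airflow.jobs.job import Job, run_job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner


def run_dedicated_triggerer(queues: str | None = "high_priority,backfill") -> None:
    runner = TriggererJobRunner(
        job=Job(heartrate=30),  # example heartrate
        capacity=1000,          # example capacity
        queues=queues,          # comma-delimited string; None -> [triggerer] default_queue
    )
    run_job(job=runner.job, execute_callable=runner._execute)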
7 changes: 7 additions & 0 deletions airflow/config_templates/config.yml
@@ -2389,6 +2389,13 @@ triggerer:
type: float
example: ~
default: "30"
default_queue:
description: |
Default queue that triggers get assigned to and that the triggerer listens on.
version_added: 2.9.0
type: string
example: ~
default: "default"
kerberos:
description: ~
options:
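
The new [triggerer] default_queue option (shipped default: "default") is the fallback in two places in this commit: the --queues CLI argument and BaseTrigger when a trigger is created without an explicit queue. A minimal sketch of how it is read with the standard config API:

# Sketch: returns "default" unless overridden in airflow.cfg or via the
# AIRFLOW__TRIGGERER__DEFAULT_QUEUE environment variable.
from airflow.configuration import conf

default_queue = conf.get("triggerer", "default_queue")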
15 changes: 13 additions & 2 deletions airflow/jobs/triggerer_job_runner.py
@@ -251,6 +251,7 @@ def __init__(
self,
job: Job,
capacity=None,
queues=None,
):
super().__init__(job)
if capacity is None:
@@ -260,6 +261,16 @@
else:
raise ValueError(f"Capacity number {capacity} is invalid")

err = f"Comma delimited list of queues {queues} is invalid"
if queues is None:
self.queues = {conf.get("triggerer", "default_queue")}
elif isinstance(queues, str) and len(queues) > 0:
self.queues = {s.strip() for s in queues.split(",")}
if not self.queues:
raise ValueError(err)
else:
raise ValueError(err)

self.health_check_threshold = conf.getint("triggerer", "triggerer_health_check_threshold")

should_queue = True
@@ -372,8 +383,8 @@ def _run_trigger_loop(self) -> None:

def load_triggers(self):
"""Query the database for the triggers we're supposed to be running and update the runner."""
Trigger.assign_unassigned(self.job.id, self.capacity, self.health_check_threshold)
ids = Trigger.ids_for_triggerer(self.job.id)
Trigger.assign_unassigned(self.job.id, self.capacity, self.queues, self.health_check_threshold)
ids = Trigger.ids_for_triggerer(self.job.id, self.queues)
self.trigger_runner.update_triggers(set(ids))

def handle_events(self):
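
The runner normalizes its queues argument into a set of queue names: None means only the configured default queue, a non-empty comma-delimited string is split and stripped, and anything else raises ValueError. A standalone re-implementation of that normalization, for illustration only (parse_queues is not a function in the commit):

from __future__ import annotations


def parse_queues(queues: str | None, default_queue: str = "default") -> set[str]:
    """Mirror of the validation in TriggererJobRunner.__init__ above."""
    err = f"Comma delimited list of queues {queues} is invalid"
    if queues is None:
        return {default_queue}
    if isinstance(queues, str) and len(queues) > 0:
        parsed = {s.strip() for s in queues.split(",")}
        if not parsed:
            raise ValueError(err)
        return parsed
    raise ValueError(err)


assert parse_queues(None) == {"default"}
assert parse_queues("high_priority, backfill") == {"high_priority", "backfill"}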
63 changes: 63 additions & 0 deletions (new file: database migration adding the trigger queue column)
@@ -0,0 +1,63 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Add the queue field to the trigger table.
Revision ID: e8b1fd3c7ccf
Revises: 10b52ebd31f7
Create Date: 2023-12-20 21:06:00.768844
"""

import sqlalchemy as sa
from alembic import op


# revision identifiers, used by Alembic.
revision = 'e8b1fd3c7ccf'
down_revision = '10b52ebd31f7'
branch_labels = None
depends_on = None
airflow_version = "2.9.0"


def upgrade():
"""Apply Add the queue field to the trigger table."""
with op.batch_alter_table("trigger") as batch_op:
try:
from airflow.configuration import conf

default_queue = conf.get("triggerer", "default_queue")
except: # noqa
default_queue = "default"

batch_op.add_column(
sa.Column(
"queue",
sa.String(256),
nullable=True,
server_default=default_queue,
)
)
batch_op.alter_column("queue", server_default=None)


def downgrade():
"""Unapply Add the queue field to the trigger table."""
with op.batch_alter_table("trigger") as batch_op:
batch_op.drop_column("queue", mssql_drop_default=True)
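
The upgrade adds the column with a temporary server_default so that existing trigger rows are backfilled with the configured default queue, then immediately clears the default so future inserts take their value from the application rather than the database. A hedged sketch of the same pattern on a hypothetical table named "example":

import sqlalchemy as sa
from alembic import op


def upgrade():
    with op.batch_alter_table("example") as batch_op:
        # Existing rows are backfilled with "default" by the database ...
        batch_op.add_column(sa.Column("queue", sa.String(256), nullable=True, server_default="default"))
        # ... while new rows rely on the application to supply a value (or NULL).
        batch_op.alter_column("queue", server_default=None)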
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
@@ -2522,7 +2522,7 @@ def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
from airflow.models.trigger import Trigger

# First, make the trigger entry
trigger_row = Trigger.from_object(defer.trigger)
trigger_row = Trigger.from_object(defer.trigger, defer.trigger.queue)
session.add(trigger_row)
session.flush()

24 changes: 16 additions & 8 deletions airflow/models/trigger.py
@@ -65,6 +65,7 @@ class Trigger(Base):
kwargs = Column(ExtendedJSON, nullable=False)
created_date = Column(UtcDateTime, nullable=False)
triggerer_id = Column(Integer, nullable=True)
queue = Column(String(256), nullable=True)

triggerer_job = relationship(
"Job",
@@ -79,19 +80,21 @@ def __init__(
self,
classpath: str,
kwargs: dict[str, Any],
queue: str,
created_date: datetime.datetime | None = None,
) -> None:
super().__init__()
self.classpath = classpath
self.kwargs = kwargs
self.queue = queue
self.created_date = created_date or timezone.utcnow()

@classmethod
@internal_api_call
def from_object(cls, trigger: BaseTrigger) -> Trigger:
def from_object(cls, trigger: BaseTrigger, queue: str) -> Trigger:
"""Alternative constructor that creates a trigger row based directly off of a Trigger object."""
classpath, kwargs = trigger.serialize()
return cls(classpath=classpath, kwargs=kwargs)
return cls(classpath=classpath, kwargs=kwargs, queue=queue)

@classmethod
@internal_api_call
@@ -197,15 +200,17 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) ->
@classmethod
@internal_api_call
@provide_session
def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]:
def ids_for_triggerer(cls, triggerer_id, queues: set, session: Session = NEW_SESSION) -> list[int]:
"""Retrieve a list of triggerer_ids."""
return session.scalars(select(cls.id).where(cls.triggerer_id == triggerer_id)).all()
return session.scalars(
select(cls.id).where(cls.triggerer_id == triggerer_id, cls.queue.in_(queues))
).all()

@classmethod
@internal_api_call
@provide_session
def assign_unassigned(
cls, triggerer_id, capacity, health_check_threshold, session: Session = NEW_SESSION
cls, triggerer_id, capacity, queues: set, health_check_threshold, session: Session = NEW_SESSION
) -> None:
"""
Assign unassigned triggers based on a number of conditions.
@@ -233,7 +238,7 @@ def assign_unassigned(
# Find triggers who do NOT have an alive triggerer_id, and then assign
# up to `capacity` of those to us.
trigger_ids_query = cls.get_sorted_triggers(
capacity=capacity, alive_triggerer_ids=alive_triggerer_ids, session=session
capacity=capacity, queues=queues, alive_triggerer_ids=alive_triggerer_ids, session=session
)
if trigger_ids_query:
session.execute(
@@ -246,11 +251,14 @@ def assign_unassigned(
session.commit()

@classmethod
def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session):
def get_sorted_triggers(cls, capacity, queues, alive_triggerer_ids, session):
query = with_row_locks(
select(cls.id)
.join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=False)
.where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)))
.where(
cls.queue.in_(queues),
or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)),
)
.order_by(coalesce(TaskInstance.priority_weight, 0).desc(), cls.created_date)
.limit(capacity),
session,
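
With the column in place, assignment is queue-aware in both directions: assign_unassigned only claims triggers whose queue is in the triggerer's set, and ids_for_triggerer only returns rows matching both the triggerer id and its queues. A sketch of that lookup (trigger_ids_for is an illustrative helper, not part of the commit):

from __future__ import annotations

from sqlalchemy import select

from airflow.models.trigger import Trigger


def trigger_ids_for(session, triggerer_id: int, queues: set[str]) -> list[int]:
    return session.scalars(
        select(Trigger.id).where(
            Trigger.triggerer_id == triggerer_id,
            Trigger.queue.in_(queues),
        )
    ).all()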
5 changes: 3 additions & 2 deletions airflow/sensors/time_sensor.py
@@ -62,9 +62,10 @@ class TimeSensorAsync(BaseSensorOperator):
:ref:`howto/operator:TimeSensorAsync`
"""

def __init__(self, *, target_time, **kwargs):
def __init__(self, *, target_time, trigger_queue=None, **kwargs):
super().__init__(**kwargs)
self.target_time = target_time
self.trigger_queue = trigger_queue

aware_time = timezone.coerce_datetime(
datetime.datetime.combine(datetime.datetime.today(), self.target_time, self.dag.timezone)
@@ -73,7 +74,7 @@ def __init__(self, *, target_time, **kwargs):
self.target_datetime = timezone.convert_to_utc(aware_time)

def execute(self, context: Context):
trigger = DateTimeTrigger(moment=self.target_datetime)
trigger = DateTimeTrigger(moment=self.target_datetime, queue=self.trigger_queue)
self.defer(
trigger=trigger,
method_name="execute_complete",
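
A hypothetical DAG snippet showing the new trigger_queue parameter in use; the queue name "time_sensors" and the DAG id are made up and assume a triggerer was started with --queues time_sensors:

import datetime

import pendulum

from airflow import DAG
from airflow.sensors.time_sensor import TimeSensorAsync

with DAG(
    dag_id="example_trigger_queue",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
):
    TimeSensorAsync(
        task_id="wait_until_noon",
        target_time=datetime.time(12, 0),
        trigger_queue="time_sensors",  # None falls back to [triggerer] default_queue
    )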
10 changes: 9 additions & 1 deletion airflow/triggers/base.py
@@ -19,6 +19,7 @@
import abc
from typing import Any, AsyncIterator

from airflow.configuration import conf
from airflow.utils.log.logging_mixin import LoggingMixin


@@ -36,7 +37,14 @@ class BaseTrigger(abc.ABC, LoggingMixin):
let them be re-instantiated elsewhere.
"""

def __init__(self, **kwargs):
def __init__(self, queue, **kwargs):
if queue is None:
self.queue = conf.get("triggerer", "default_queue")
elif isinstance(queue, str) and len(queue) > 0 and "," not in queue:
self.queue = queue
else:
raise ValueError(f"The trigger queue {queue} is invalid")

# these values are set by triggerer when preparing to run the instance
# when run, they are injected into logger record.
self.task_instance = None
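
BaseTrigger now takes a queue argument: None falls back to [triggerer] default_queue, a non-empty string without commas is accepted, and anything else raises ValueError. A hypothetical custom trigger showing how a subclass would thread queue through to the base class (the class name and classpath below are made up):

from __future__ import annotations

from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent


class WaitOnceTrigger(BaseTrigger):
    def __init__(self, payload: str, queue: str | None = None):
        super().__init__(queue=queue)  # None -> [triggerer] default_queue
        self.payload = payload

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # The queue is persisted on the Trigger row (see models/trigger.py above),
        # so it does not need to be part of the serialized kwargs.
        return ("my_plugin.triggers.WaitOnceTrigger", {"payload": self.payload})

    async def run(self) -> AsyncIterator[TriggerEvent]:
        yield TriggerEvent(self.payload)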
8 changes: 4 additions & 4 deletions airflow/triggers/temporal.py
@@ -34,8 +34,8 @@ class DateTimeTrigger(BaseTrigger):
The provided datetime MUST be in UTC.
"""

def __init__(self, moment: datetime.datetime):
super().__init__()
def __init__(self, moment: datetime.datetime, queue: str | None = None):
super().__init__(queue=queue)
if not isinstance(moment, datetime.datetime):
raise TypeError(f"Expected datetime.datetime type for moment. Got {type(moment)}")
# Make sure it's in UTC
@@ -84,5 +84,5 @@ class TimeDeltaTrigger(DateTimeTrigger):
DateTimeTrigger class, since they're operationally the same.
"""

def __init__(self, delta: datetime.timedelta):
super().__init__(moment=timezone.utcnow() + delta)
def __init__(self, delta: datetime.timedelta, queue: str | None = None):
super().__init__(moment=timezone.utcnow() + delta, queue=queue)
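
Both temporal triggers now accept an optional queue and forward it to BaseTrigger, so a deferrable operator can route its deferral to a specific triggerer pool. A hypothetical example (the operator name and queue name are illustrative):

import datetime

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class WaitFiveMinutesOperator(BaseOperator):
    def execute(self, context):
        self.defer(
            trigger=TimeDeltaTrigger(datetime.timedelta(minutes=5), queue="slow_lane"),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        return None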
4 changes: 4 additions & 0 deletions airflow/www/static/js/dag/details/taskInstance/Details.tsx
@@ -99,6 +99,10 @@ const Details = ({ instance, group, dagId }: Props) => {
<Td>Trigger creation time</Td>
<Td>{`${apiTI?.trigger?.createdDate}`}</Td>
</Tr>
<Tr>
<Td>Trigger queue</Td>
<Td>{`${apiTI?.trigger?.queue}`}</Td>
</Tr>
<Tr>
<Td>Assigned triggerer</Td>
<Td>{`${apiTI?.triggererJob?.hostname}`}</Td>
1 change: 1 addition & 0 deletions airflow/www/static/js/types/index.ts
@@ -88,6 +88,7 @@ interface TaskInstance {
interface Trigger {
classpath: string | null;
createdDate: string | null;
queue: string | null;
}

interface Job {
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
194706fc390025f473f73ce934bfe4b394b50ce76748e5df33ae643e38538357
f9c060db9564968af00b3c1e7901fe9d5dc4fa859b497b3bf3c801ce1e26b016
