Skip to content

Commit

Permalink
Fix bug in task_sdk's parse function (apache#44056)
Browse files Browse the repository at this point in the history
This function was added but not used or tested in the first PR that added the
task_runner code, and in the final throes of that PR we swapped from msgspec
to pydantic, and in doing so introduced a runtime error from Pydantic as it
tried to look at the type hints of the `task: BaseOperator` property.

The fix here is to call model_construct to skip pydantic validations, which is
safe here as the TI (which RuntimeTI inherits from) was validated when the
StartupDetails object was parsed+created.

And this time add tests for the function too.

In order to test this I have created the first very simple test dag in the
SDK and configured pytest to skip that entire directory when looking for
tests.
  • Loading branch information
ashb authored Nov 15, 2024
1 parent 123dadd commit 6e59137
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 4 deletions.
4 changes: 3 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ def allow_future_exec_dates(self) -> bool:

def resolve_template_files(self):
for t in self.tasks:
t.resolve_template_files()
# TODO: TaskSDK: move this on to BaseOperator and remove the check?
if hasattr(t, "resolve_template_files"):
t.resolve_template_files()

def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment:
"""Build a Jinja2 environment."""
Expand Down
5 changes: 3 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class RuntimeTaskInstance(TaskInstance):

def parse(what: StartupDetails) -> RuntimeTaskInstance:
# TODO: Task-SDK:
# Using DagBag here is aoubt 98% wrong, but it'll do for now
# Using DagBag here is about 98% wrong, but it'll do for now

from airflow.models.dagbag import DagBag

Expand All @@ -64,7 +64,8 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
task = dag.task_dict[what.ti.task_id]
if not isinstance(task, BaseOperator):
raise TypeError(f"task is of the wrong type, got {type(task)}, wanted {BaseOperator}")
return RuntimeTaskInstance(**what.ti.model_dump(exclude_unset=True), task=task)

return RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=task)


@attrs.define()
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn

import pytest
Expand All @@ -43,6 +44,9 @@ def pytest_addhooks(pluginmanager: pytest.PytestPluginManager):
def pytest_configure(config: pytest.Config) -> None:
    """Apply session-wide pytest configuration for these tests.

    Sets the ``airflow_deprecations_ignore`` ini value to an empty list and
    registers the directories that pytest must not recurse into when
    collecting tests.
    """
    config.inicfg["airflow_deprecations_ignore"] = []

    # Always skip looking for tests in these folders!
    config.addinivalue_line("norecursedirs", "tests/test_dags")
    # The example DAG files live in ``tests/dags`` (the directory returned by
    # the ``test_dags_dir`` fixture), so that directory has to be excluded too;
    # the entry above on its own does not match it.
    config.addinivalue_line("norecursedirs", "tests/dags")


class LogCapture:
# Like structlog.typing.LogCapture, but that doesn't add log_level in to the event dict
Expand All @@ -62,6 +66,11 @@ def __call__(self, _: WrappedLogger, method_name: str, event_dict: EventDict) ->
raise DropEvent


@pytest.fixture
def test_dags_dir():
    """Return the path of the ``dags`` directory located next to this file."""
    here = Path(__file__).parent
    return here / "dags"


@pytest.fixture
def captured_logs():
import structlog
Expand Down
28 changes: 28 additions & 0 deletions task_sdk/tests/dags/super_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow.sdk import BaseOperator, dag


# Minimal example DAG used by the task runner tests: the decorated function
# creates one bare ``BaseOperator`` with task_id "a". The ``test_parse`` test
# looks it up as dag_id "super_basic" / task_id "a".
@dag()
def super_basic():
    # The operator is only instantiated, never assigned; presumably the
    # decorator attaches it to the DAG being built -- confirm against ``dag``.
    BaseOperator(task_id="a")


# Call the decorated function at import time (presumably what makes the DAG
# visible to whatever loads this file -- confirm against the loader).
super_basic()
18 changes: 17 additions & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
from __future__ import annotations

import uuid
from pathlib import Path
from socket import socketpair

import pytest
from uuid6 import uuid7

from airflow.sdk.api.datamodels.ti import TaskInstance
from airflow.sdk.execution_time.comms import StartupDetails
from airflow.sdk.execution_time.task_runner import CommsDecoder
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse


class TestCommsDecoder:
Expand Down Expand Up @@ -54,3 +57,16 @@ def test_recv_StartupDetails(self):
assert decoder.request_socket is not None
assert decoder.request_socket.writable()
assert decoder.request_socket.fileno() == w2.fileno()


def test_parse(test_dags_dir: Path):
    """Check that ``parse`` returns a task instance with its task and DAG attached."""
    dag_file = test_dags_dir / "super_basic.py"
    task_instance = TaskInstance(
        id=uuid7(),
        task_id="a",
        dag_id="super_basic",
        run_id="c",
        try_number=1,
    )
    startup = StartupDetails(ti=task_instance, file=str(dag_file), requests_fd=0)

    runtime_ti = parse(startup)

    # The result must carry the operator, and that operator must know its DAG.
    assert runtime_ti.task
    assert runtime_ti.task.dag

0 comments on commit 6e59137

Please sign in to comment.