Skip to content

Commit

Permalink
Allow executors to be specified with only the class name of the Execu…
Browse files Browse the repository at this point in the history
…tor (apache#40131)

* Allow executors to be specified with only the class name of the Executor

* Fix unit test by mocking loaded Executor, so it doesn't access the DB

* Update doc string to include passing class name as a valid way to specify an executor
  • Loading branch information
syedahsn authored Jun 8, 2024
1 parent 297ad80 commit 14a613f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
7 changes: 6 additions & 1 deletion airflow/executors/executor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
# executor may have both so we need two lookup dicts.
_alias_to_executors: dict[str, ExecutorName] = {}
_module_to_executors: dict[str, ExecutorName] = {}
_classname_to_executors: dict[str, ExecutorName] = {}
# Used to cache the computed ExecutorNames so that we don't need to read/parse config more than once
_executor_names: list[ExecutorName] = []
# Used to cache executors so that we don't construct executor objects unnecessarily
Expand Down Expand Up @@ -149,6 +150,7 @@ def _get_executor_names(cls) -> list[ExecutorName]:
_alias_to_executors[executor_name.alias] = executor_name
# All executors will have a module path
_module_to_executors[executor_name.module_path] = executor_name
_classname_to_executors[executor_name.module_path.split(".")[-1]] = executor_name
# Cache the executor names, so the logic of this method only runs once
_executor_names.append(executor_name)

Expand Down Expand Up @@ -201,6 +203,8 @@ def lookup_executor_name_by_str(cls, executor_name_str: str) -> ExecutorName:
return executor_name
elif executor_name := _module_to_executors.get(executor_name_str):
return executor_name
elif executor_name := _classname_to_executors.get(executor_name_str):
return executor_name
else:
raise AirflowException(f"Unknown executor being loaded: {executor_name_str}")

Expand All @@ -212,7 +216,8 @@ def load_executor(cls, executor_name: ExecutorName | str | None) -> BaseExecutor
This supports the following formats:
* by executor name for core executor
* by ``{plugin_name}.{class_name}`` for executor from plugins
* by import path.
* by import path
* by class name of the Executor
* by ExecutorName object specification
:return: an instance of executor class via executor_name
Expand Down
25 changes: 25 additions & 0 deletions tests/executors/test_executor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.executors import executor_loader
from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName
from airflow.executors.local_executor import LocalExecutor
from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from tests.test_utils.config import conf_vars

Expand Down Expand Up @@ -334,3 +335,27 @@ def test_load_executor_alias(self):
assert isinstance(
ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor
)

@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor", autospec=True)
def test_load_custom_executor_with_classname(self, mock_executor):
with patch.object(ExecutorLoader, "block_use_of_hybrid_exec"):
with conf_vars(
{
(
"core",
"executor",
): "my_alias:airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
}
):
ExecutorLoader.init_executors()
assert isinstance(ExecutorLoader.load_executor("my_alias"), AwsEcsExecutor)
assert isinstance(ExecutorLoader.load_executor("AwsEcsExecutor"), AwsEcsExecutor)
assert isinstance(
ExecutorLoader.load_executor(
"airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
),
AwsEcsExecutor,
)
assert isinstance(
ExecutorLoader.load_executor(executor_loader._executor_names[0]), AwsEcsExecutor
)

0 comments on commit 14a613f

Please sign in to comment.