From 76ca5f9689fb7ecaad0b87118af08cf70c069319 Mon Sep 17 00:00:00 2001 From: GPK Date: Fri, 27 Sep 2024 14:24:42 +0100 Subject: [PATCH] Pre commit script to validate template fields (#42284) --- .pre-commit-config.yaml | 7 + contributing-docs/08_static_code_checks.rst | 2 + .../doc/images/output_static-checks.svg | 12 +- .../doc/images/output_static-checks.txt | 2 +- .../src/airflow_breeze/pre_commit_ids.py | 1 + .../pre_commit/check_provider_yaml_files.py | 15 +- .../ci/pre_commit/check_template_fields.py | 40 ++++ .../ci/pre_commit/common_precommit_utils.py | 17 ++ scripts/ci/pre_commit/migration_reference.py | 14 +- scripts/ci/pre_commit/update_er_diagram.py | 13 +- .../ci/pre_commit/update_fastapi_api_spec.py | 13 +- .../in_container/run_template_fields_check.py | 180 ++++++++++++++++++ 12 files changed, 279 insertions(+), 37 deletions(-) create mode 100755 scripts/ci/pre_commit/check_template_fields.py create mode 100644 scripts/in_container/run_template_fields_check.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 942b34ca2e6d5..2263086335bc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1343,6 +1343,13 @@ repos: files: ^airflow/providers/.*/provider\.yaml$ additional_dependencies: ['rich>=12.4.4'] require_serial: true + - id: check-template-fields-valid + name: Check templated fields mapped in operators/sensors + language: python + entry: ./scripts/ci/pre_commit/check_template_fields.py + files: ^airflow/.*/sensors/.*\.py$|^airflow/.*/operators/.*\.py$ + additional_dependencies: [ 'rich>=12.4.4' ] + require_serial: true - id: update-migration-references name: Update migration ref doc language: python diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst index 0a3dcacd9e070..d50b9db3e607f 100644 --- a/contributing-docs/08_static_code_checks.rst +++ b/contributing-docs/08_static_code_checks.rst @@ -236,6 +236,8 @@ require Breeze Docker image to be built locally. +-----------------------------------------------------------+--------------------------------------------------------+---------+ | check-template-context-variable-in-sync | Sync template context variable refs | | +-----------------------------------------------------------+--------------------------------------------------------+---------+ +| check-template-fields-valid | Check templated fields mapped in operators/sensors | * | ++-----------------------------------------------------------+--------------------------------------------------------+---------+ | check-tests-in-the-right-folders | Check if tests are in the right folders | | +-----------------------------------------------------------+--------------------------------------------------------+---------+ | check-tests-unittest-testcase | Unit tests do not inherit from unittest.TestCase | | diff --git a/dev/breeze/doc/images/output_static-checks.svg b/dev/breeze/doc/images/output_static-checks.svg index ed52a596def64..36b88513a56ba 100644 --- a/dev/breeze/doc/images/output_static-checks.svg +++ b/dev/breeze/doc/images/output_static-checks.svg @@ -356,12 +356,12 @@ check-safe-filter-usage-in-html | check-sql-dependency-common-data-structure |    check-start-date-not-used-in-defaults | check-system-tests-present |              check-system-tests-tocs | check-taskinstance-tis-attrs |                          -check-template-context-variable-in-sync | check-tests-in-the-right-folders |      -check-tests-unittest-testcase | check-urlparse-usage-in-code |                    -check-usage-of-re2-over-re | check-xml | codespell | compile-ui-assets |          -compile-ui-assets-dev | compile-www-assets | compile-www-assets-dev |             -create-missing-init-py-files-tests | debug-statements | detect-private-key |      -doctoc | end-of-file-fixer | fix-encoding-pragma | flynt |                        +check-template-context-variable-in-sync | check-template-fields-valid |           +check-tests-in-the-right-folders | check-tests-unittest-testcase |                +check-urlparse-usage-in-code | check-usage-of-re2-over-re | check-xml | codespell +| compile-ui-assets | compile-ui-assets-dev | compile-www-assets |                +compile-www-assets-dev | create-missing-init-py-files-tests | debug-statements |  +detect-private-key | doctoc | end-of-file-fixer | fix-encoding-pragma | flynt |   generate-airflow-diagrams | generate-openapi-spec | generate-pypi-readme |        identity | insert-license | kubeconform | lint-chart-schema | lint-css |          lint-dockerfile | lint-helm-chart | lint-json-schema | lint-markdown |            diff --git a/dev/breeze/doc/images/output_static-checks.txt b/dev/breeze/doc/images/output_static-checks.txt index 3a3837fbb15bb..9e3ae46130640 100644 --- a/dev/breeze/doc/images/output_static-checks.txt +++ b/dev/breeze/doc/images/output_static-checks.txt @@ -1 +1 @@ -5c6ba60b1865538bce04fc940cd240c6 +e33cdf5f43d8c63290e44e92dc19d2c4 diff --git a/dev/breeze/src/airflow_breeze/pre_commit_ids.py b/dev/breeze/src/airflow_breeze/pre_commit_ids.py index 9a48df5e3f69c..457379f5b90ba 100644 --- a/dev/breeze/src/airflow_breeze/pre_commit_ids.py +++ b/dev/breeze/src/airflow_breeze/pre_commit_ids.py @@ -83,6 +83,7 @@ "check-system-tests-tocs", "check-taskinstance-tis-attrs", "check-template-context-variable-in-sync", + "check-template-fields-valid", "check-tests-in-the-right-folders", "check-tests-unittest-testcase", "check-urlparse-usage-in-code", diff --git a/scripts/ci/pre_commit/check_provider_yaml_files.py b/scripts/ci/pre_commit/check_provider_yaml_files.py index fcbe2512910a3..f848e38afa0b2 100755 --- a/scripts/ci/pre_commit/check_provider_yaml_files.py +++ b/scripts/ci/pre_commit/check_provider_yaml_files.py @@ -17,12 +17,15 @@ # under the License. from __future__ import annotations -import os import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -33,10 +36,4 @@ warn_image_upgrade_needed=True, extra_env={"PYTHONWARNINGS": "default"}, ) -if cmd_result.returncode != 0 and os.environ.get("CI") != "true": - console.print( - "\n[yellow]If you see strange stacktraces above, especially about missing imports " - "run this command:[/]\n" - ) - console.print("[magenta]breeze ci-image build --python 3.8 --upgrade-to-newer-dependencies[/]\n") -sys.exit(cmd_result.returncode) +validate_cmd_result(cmd_result, include_ci_env_check=True) diff --git a/scripts/ci/pre_commit/check_template_fields.py b/scripts/ci/pre_commit/check_template_fields.py new file mode 100755 index 0000000000000..da0b60fbd978f --- /dev/null +++ b/scripts/ci/pre_commit/check_template_fields.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.resolve())) +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) + +initialize_breeze_precommit(__name__, __file__) +py_files_to_test = sys.argv[1:] + +cmd_result = run_command_via_breeze_shell( + ["python3", "/opt/airflow/scripts/in_container/run_template_fields_check.py", *py_files_to_test], + backend="sqlite", + warn_image_upgrade_needed=True, + extra_env={"PYTHONWARNINGS": "default"}, +) + +validate_cmd_result(cmd_result, include_ci_env_check=True) diff --git a/scripts/ci/pre_commit/common_precommit_utils.py b/scripts/ci/pre_commit/common_precommit_utils.py index 41bc3a5eeaf93..4f62c50cabeaa 100644 --- a/scripts/ci/pre_commit/common_precommit_utils.py +++ b/scripts/ci/pre_commit/common_precommit_utils.py @@ -211,3 +211,20 @@ def check_list_sorted(the_list: list[str], message: str, errors: list[str]) -> b console.print() errors.append(f"ERROR in {message}. The elements are not sorted/unique.") return False + + +def validate_cmd_result(cmd_result, include_ci_env_check=False): + if include_ci_env_check: + if cmd_result.returncode != 0 and os.environ.get("CI") != "true": + console.print( + "\n[yellow]If you see strange stacktraces above, especially about missing imports " + "run this command:[/]\n" + ) + console.print("[magenta]breeze ci-image build --python 3.8 --upgrade-to-newer-dependencies[/]\n") + + elif cmd_result.returncode != 0: + console.print( + "[warning]\nIf you see strange stacktraces above, " + "run `breeze ci-image build --python 3.8` and try again." + ) + sys.exit(cmd_result.returncode) diff --git a/scripts/ci/pre_commit/migration_reference.py b/scripts/ci/pre_commit/migration_reference.py index 34d3a94c6a90d..505bea5ca91af 100755 --- a/scripts/ci/pre_commit/migration_reference.py +++ b/scripts/ci/pre_commit/migration_reference.py @@ -21,7 +21,11 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -29,9 +33,5 @@ ["python3", "/opt/airflow/scripts/in_container/run_migration_reference.py"], backend="sqlite", ) -if cmd_result.returncode != 0: - console.print( - "[warning]\nIf you see strange stacktraces above, " - "run `breeze ci-image build --python 3.8` and try again." - ) -sys.exit(cmd_result.returncode) + +validate_cmd_result(cmd_result) diff --git a/scripts/ci/pre_commit/update_er_diagram.py b/scripts/ci/pre_commit/update_er_diagram.py index e660b47c6e6ae..c4f3cb797cf21 100755 --- a/scripts/ci/pre_commit/update_er_diagram.py +++ b/scripts/ci/pre_commit/update_er_diagram.py @@ -21,7 +21,11 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -36,9 +40,4 @@ }, ) -if cmd_result.returncode != 0: - console.print( - "[warning]\nIf you see strange stacktraces above, " - "run `breeze ci-image build --python 3.8` and try again." - ) - sys.exit(cmd_result.returncode) +validate_cmd_result(cmd_result) diff --git a/scripts/ci/pre_commit/update_fastapi_api_spec.py b/scripts/ci/pre_commit/update_fastapi_api_spec.py index 15ccaa5ac209e..3d7731c7ef2e2 100755 --- a/scripts/ci/pre_commit/update_fastapi_api_spec.py +++ b/scripts/ci/pre_commit/update_fastapi_api_spec.py @@ -21,7 +21,11 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) -from common_precommit_utils import console, initialize_breeze_precommit, run_command_via_breeze_shell +from common_precommit_utils import ( + initialize_breeze_precommit, + run_command_via_breeze_shell, + validate_cmd_result, +) initialize_breeze_precommit(__name__, __file__) @@ -31,9 +35,4 @@ skip_environment_initialization=False, ) -if cmd_result.returncode != 0: - console.print( - "[warning]\nIf you see strange stacktraces above, " - "run `breeze ci-image build --python 3.8` and try again." - ) -sys.exit(cmd_result.returncode) +validate_cmd_result(cmd_result) diff --git a/scripts/in_container/run_template_fields_check.py b/scripts/in_container/run_template_fields_check.py new file mode 100644 index 0000000000000..202dce35c5745 --- /dev/null +++ b/scripts/in_container/run_template_fields_check.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import importlib.util +import inspect +import itertools +import pathlib +import sys +import warnings + +import yaml +from rich.console import Console + +try: + from yaml import CSafeLoader as SafeLoader +except ImportError: + from yaml import SafeLoader # type: ignore + +console = Console(width=400, color_system="standard") +ROOT_DIR = pathlib.Path(__file__).resolve().parents[2] + +provider_files_pattern = pathlib.Path(ROOT_DIR, "airflow", "providers").rglob("provider.yaml") +errors: list[str] = [] + +OPERATORS: list[str] = ["sensors", "operators"] +CLASS_IDENTIFIERS: list[str] = ["sensor", "operator"] + +TEMPLATE_TYPES: list[str] = ["template_fields"] + + +class InstanceFieldExtractor(ast.NodeVisitor): + def __init__(self): + self.current_class = None + self.instance_fields = [] + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + if node.name == "__init__": + self.generic_visit(node) + return node + + def visit_Assign(self, node: ast.Assign) -> ast.Assign: + fields = [] + for target in node.targets: + if isinstance(target, ast.Attribute): + fields.append(target.attr) + if fields: + self.instance_fields.extend(fields) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: + if isinstance(node.target, ast.Attribute): + self.instance_fields.append(node.target.attr) + return node + + +def get_template_fields_and_class_instance_fields(cls): + """ + 1.This method retrieves the operator class and obtains all its parent classes using the method resolution order (MRO). + 2. It then gathers the templated fields declared in both the operator class and its parent classes. + 3. Finally, it retrieves the instance fields of the operator class, specifically the self.fields attributes. + """ + all_template_fields = [] + class_instance_fields = [] + + all_classes = cls.__mro__ + for current_class in all_classes: + if current_class.__init__ is not object.__init__: + cls_attr = current_class.__dict__ + for template_type in TEMPLATE_TYPES: + fields = cls_attr.get(template_type) + if fields: + all_template_fields.extend(fields) + + tree = ast.parse(inspect.getsource(current_class)) + visitor = InstanceFieldExtractor() + visitor.visit(tree) + if visitor.instance_fields: + class_instance_fields.extend(visitor.instance_fields) + return all_template_fields, class_instance_fields + + +def load_yaml_data() -> dict: + """ + It loads all the provider YAML files and retrieves the module referenced within each YAML file. + """ + package_paths = sorted(str(path) for path in provider_files_pattern) + result = {} + for provider_yaml_path in package_paths: + with open(provider_yaml_path) as yaml_file: + provider = yaml.load(yaml_file, SafeLoader) + rel_path = pathlib.Path(provider_yaml_path).relative_to(ROOT_DIR).as_posix() + result[rel_path] = provider + return result + + +def get_providers_modules() -> list[str]: + modules_container = [] + result = load_yaml_data() + + for (_, provider_data), resource_type in itertools.product(result.items(), OPERATORS): + if provider_data.get(resource_type): + for data in provider_data.get(resource_type): + modules_container.extend(data.get("python-modules")) + + return modules_container + + +def is_class_eligible(name: str) -> bool: + for op in CLASS_IDENTIFIERS: + if name.lower().endswith(op): + return True + return False + + +def get_eligible_classes(all_classes): + """ + Filter the results to include only classes that end with `Sensor` or `Operator`. + + """ + + eligible_classes = [(name, cls) for name, cls in all_classes if is_class_eligible(name)] + return eligible_classes + + +def iter_check_template_fields(module: str): + """ + 1. This method imports the providers module and retrieves all the classes defined within it. + 2. It then filters and selects classes related to operators or sensors by checking if the class name ends with "Operator" or "Sensor." + 3. For each operator class, it validates the template fields by inspecting the class instance fields. + """ + with warnings.catch_warnings(record=True): + imported_module = importlib.import_module(module) + classes = inspect.getmembers(imported_module, inspect.isclass) + op_classes = get_eligible_classes(classes) + + for op_class_name, cls in op_classes: + if cls.__module__ == module: + templated_fields, class_instance_fields = get_template_fields_and_class_instance_fields(cls) + + for field in templated_fields: + if field not in class_instance_fields: + errors.append(f"{module}: {op_class_name}: {field}") + + +if __name__ == "__main__": + provider_modules = get_providers_modules() + + if len(sys.argv) > 1: + py_files = sorted(sys.argv[1:]) + modules_to_validate = [ + module_name + for pyfile in py_files + if (module_name := pyfile.rstrip(".py").replace("/", ".")) in provider_modules + ] + else: + modules_to_validate = provider_modules + + [iter_check_template_fields(module) for module in modules_to_validate] + if errors: + console.print("[red]Found Invalid template fields:") + for error in errors: + console.print(f"[red]Error:[/] {error}") + + sys.exit(len(errors))