Nicer pyproject.toml dependency generation (apache#37114)
The provider dependencies generated into pyproject.toml are now
produced in a more streamlined way:

* the "empty" dependency lists are now single-line empty arrays, which
  makes GitHub's renderer of pyproject.toml happier (especially when
  showing a diff in the preview); see the first sketch below

* instead of calculating hashes to decide whether pyproject.toml needs
  regenerating, we now always regenerate it when the pre-commit runs.
  This is possible because generation is stable and always produces the
  same result from the same input, so we can safely regenerate the file
  in CI with `--all-files` and the file will not change. This way we
  avoid hash collisions when parallel changes land in different
  providers; see the second sketch below
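
The first point, as a minimal sketch (Python, illustrative only: the
helper name `emit_extra` and its argument shapes are assumptions, not
the script's real API — the actual logic lives in
`generate_dependencies()` in the diff below and also folds in devel
dependencies):

    # Illustrative sketch: render one extra, collapsing an empty
    # dependency list to a single-line empty array.
    def emit_extra(name: str, deps: list[str], source: str) -> list[str]:
        # apache-airflow>=X pins are dropped before deciding emptiness,
        # mirroring the new generate_dependencies() behaviour.
        deps = [dep for dep in deps if not dep.startswith("apache-airflow>=")]
        if not deps:
            # One line instead of "name = [" + "]" keeps GitHub's rich
            # diff of pyproject.toml readable.
            return [f"{name} = []  # source: {source}"]
        lines = [f"{name} = [  # source: {source}"]
        lines.extend(f'    "{dep}",' for dep in deps)
        lines.append("]")
        return lines

    print("\n".join(emit_extra("apache-pig", [], "airflow/providers/apache/pig/provider.yaml")))
    # apache-pig = []  # source: airflow/providers/apache/pig/provider.yaml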
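
The second point reduces to a compare-before-write helper — a sketch of
the pattern `update_pyproject_toml()` now follows (the `regenerate`
name is hypothetical):

    from pathlib import Path

    # Idempotent "always regenerate, write only on change" pattern that
    # replaced the hash bookkeeping.
    def regenerate(path: Path, new_content: str) -> bool:
        """Return True when the file content changed and was rewritten."""
        old_content = path.read_text() if path.exists() else ""
        if old_content == new_content:
            return False  # stable generation: same input, same output
        path.write_text(new_content)
        return True

    # Because generation is deterministic, running the pre-commit in CI
    # with --all-files either leaves files untouched (check passes) or
    # returns True, in which case the check prints a hint and exits 1.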
potiuk authored Jan 31, 2024
1 parent 4150314 commit e11a111
Showing 2 changed files with 37 additions and 58 deletions.
19 changes: 6 additions & 13 deletions pyproject.toml
@@ -521,7 +521,6 @@ winrm = [
# If you want to modify these - modify the corresponding provider.yaml instead.
#############################################################################################################
# START OF GENERATED DEPENDENCIES
# Hash of dependencies: ad91a0758ca9b408679bd3ea3ec22c66
airbyte = [ # source: airflow/providers/airbyte/provider.yaml
"apache-airflow[http]",
]
@@ -596,8 +595,7 @@ apache-livy = [ # source: airflow/providers/apache/livy/provider.yaml
"apache-airflow[http]",
"asgiref",
]
apache-pig = [ # source: airflow/providers/apache/pig/provider.yaml
]
apache-pig = [] # source: airflow/providers/apache/pig/provider.yaml
apache-pinot = [ # source: airflow/providers/apache/pinot/provider.yaml
"apache-airflow[common_sql]",
"pinotdb>0.4.7",
@@ -638,8 +636,7 @@ cncf-kubernetes = [ # source: airflow/providers/cncf/kubernetes/provider.yaml
cohere = [ # source: airflow/providers/cohere/provider.yaml
"cohere>=4.37",
]
common-io = [ # source: airflow/providers/common/io/provider.yaml
]
common-io = [] # source: airflow/providers/common/io/provider.yaml
common-sql = [ # source: airflow/providers/common/sql/provider.yaml
"sqlparse>=0.4.2",
]
@@ -687,8 +684,7 @@ fab = [ # source: airflow/providers/fab/provider.yaml
facebook = [ # source: airflow/providers/facebook/provider.yaml
"facebook-business>=6.0.2",
]
ftp = [ # source: airflow/providers/ftp/provider.yaml
]
ftp = [] # source: airflow/providers/ftp/provider.yaml
github = [ # source: airflow/providers/github/provider.yaml
"PyGithub!=1.58",
]
@@ -766,8 +762,7 @@ http = [ # source: airflow/providers/http/provider.yaml
"requests>=2.26.0",
"requests_toolbelt",
]
imap = [ # source: airflow/providers/imap/provider.yaml
]
imap = [] # source: airflow/providers/imap/provider.yaml
influxdb = [ # source: airflow/providers/influxdb/provider.yaml
"influxdb-client>=1.19.0",
"requests>=2.26.0",
@@ -835,8 +830,7 @@ odbc = [ # source: airflow/providers/odbc/provider.yaml
openai = [ # source: airflow/providers/openai/provider.yaml
"openai[datalib]>=1.0",
]
openfaas = [ # source: airflow/providers/openfaas/provider.yaml
]
openfaas = [] # source: airflow/providers/openfaas/provider.yaml
openlineage = [ # source: airflow/providers/openlineage/provider.yaml
"apache-airflow[common_sql]",
"attrs>=22.2",
@@ -904,8 +898,7 @@ slack = [ # source: airflow/providers/slack/provider.yaml
"apache-airflow[common_sql]",
"slack_sdk>=3.19.0",
]
smtp = [ # source: airflow/providers/smtp/provider.yaml
]
smtp = [] # source: airflow/providers/smtp/provider.yaml
snowflake = [ # source: airflow/providers/snowflake/provider.yaml
"apache-airflow[common_sql]",
"snowflake-connector-python>=2.7.8",
76 changes: 31 additions & 45 deletions scripts/ci/pre_commit/pre_commit_update_providers_dependencies.py
@@ -17,10 +17,8 @@
# under the License.
from __future__ import annotations

import hashlib
import json
import os
import re
import sys
from ast import Import, ImportFrom, NodeVisitor, parse
from collections import defaultdict
@@ -238,20 +236,25 @@ def get_python_exclusion(dependency_info: dict[str, list[str] | str]):
for dependency, dependency_info in dependencies.items():
if dependency_info["state"] in ["suspended", "removed"]:
continue
deps = dependency_info["deps"]
deps = [dep for dep in deps if not dep.startswith("apache-airflow>=")]
devel_deps = dependency_info.get("devel-deps")
if not deps and not devel_deps:
result_content.append(
f"{normalize_extra(dependency)} = [] "
f"# source: airflow/providers/{dependency.replace('.', '/')}/provider.yaml"
)
continue
result_content.append(
f"{normalize_extra(dependency)} = "
f"[ # source: airflow/providers/{dependency.replace('.', '/')}/provider.yaml"
)
deps = dependency_info["deps"]
if not isinstance(deps, list):
raise TypeError(f"Wrong type of 'deps' {deps} for {dependency} in {DEPENDENCIES_JSON_FILE_PATH}")
for dep in deps:
if dep.startswith("apache-airflow-providers-"):
dep = convert_to_extra_dependency(dep)
elif dep.startswith("apache-airflow>="):
continue
result_content.append(f' "{dep}{get_python_exclusion(dependency_info)}",')
devel_deps = dependency_info.get("devel-deps")
if devel_deps:
result_content.append(f" # Devel dependencies for the {dependency} provider")
for dep in devel_deps:
@@ -284,7 +287,7 @@ def get_dependency_type(dependency_type: str) -> ParsedDependencyTypes | None:
return None


def update_pyproject_toml(dependencies: dict[str, dict[str, list[str] | str]], dependencies_hash: str):
def update_pyproject_toml(dependencies: dict[str, dict[str, list[str] | str]]) -> bool:
file_content = PYPROJECT_TOML_FILE_PATH.read_text()
result_content: list[str] = []
copying = True
@@ -295,7 +298,6 @@ def update_pyproject_toml(dependencies: dict[str, dict[str, list[str] | str]], dependencies_hash: str):
result_content.append(line)
if line.strip().startswith(GENERATED_DEPENDENCIES_START):
copying = False
result_content.append(f"# Hash of dependencies: {dependencies_hash}")
generate_dependencies(result_content, dependencies)
elif line.strip().startswith(GENERATED_DEPENDENCIES_END):
copying = True
@@ -320,25 +322,13 @@ def update_pyproject_toml(dependencies: dict[str, dict[str, list[str] | str]], dependencies_hash: str):
if line.strip().endswith(" = ["):
FOUND_EXTRAS[current_type].append(line.split(" = [")[0].strip())
line_count += 1
PYPROJECT_TOML_FILE_PATH.write_text("\n".join(result_content) + "\n")


def calculate_my_hash():
my_file = MY_FILE.resolve()
hash_md5 = hashlib.md5()
hash_md5.update(my_file.read_bytes())
return hash_md5.hexdigest()


def calculate_dependencies_hash(dependencies: str):
my_file = MY_FILE.resolve()
hash_md5 = hashlib.md5()
hash_md5.update(my_file.read_bytes())
hash_md5.update(dependencies.encode(encoding="utf-8"))
return hash_md5.hexdigest()

result_content.append("")
new_file_content = "\n".join(result_content)
if file_content != new_file_content:
PYPROJECT_TOML_FILE_PATH.write_text(new_file_content)
return True
return False

HASH_REGEXP = re.compile(r"# Hash of dependencies: (?P<hash>[a-f0-9]+)")

if __name__ == "__main__":
find_all_providers_and_provider_files()
@@ -381,16 +371,10 @@ def calculate_dependencies_hash(dependencies: str):
)
new_dependencies = json.dumps(unique_sorted_dependencies, indent=2) + "\n"
old_md5sum = MY_MD5SUM_FILE.read_text().strip() if MY_MD5SUM_FILE.exists() else ""
new_md5sum = calculate_my_hash()
find_hash = HASH_REGEXP.findall(PYPROJECT_TOML_FILE_PATH.read_text())
dependencies_hash_from_pyproject_toml = find_hash[0] if find_hash else ""
dependencies_hash = calculate_dependencies_hash(new_dependencies)
if (
new_dependencies != old_dependencies
or new_md5sum != old_md5sum
or dependencies_hash_from_pyproject_toml != dependencies_hash
):
DEPENDENCIES_JSON_FILE_PATH.write_text(json.dumps(unique_sorted_dependencies, indent=2) + "\n")
old_content = DEPENDENCIES_JSON_FILE_PATH.read_text() if DEPENDENCIES_JSON_FILE_PATH.exists() else ""
new_content = json.dumps(unique_sorted_dependencies, indent=2) + "\n"
DEPENDENCIES_JSON_FILE_PATH.write_text(new_content)
if new_content != old_content:
if os.environ.get("CI"):
console.print()
console.print(f"There is a need to regenerate {DEPENDENCIES_JSON_FILE_PATH}")
@@ -408,13 +392,15 @@
)
console.print(f"Written {DEPENDENCIES_JSON_FILE_PATH}")
console.print()
update_pyproject_toml(unique_sorted_dependencies, dependencies_hash)
console.print(f"Written {PYPROJECT_TOML_FILE_PATH}")
if update_pyproject_toml(unique_sorted_dependencies):
if os.environ.get("CI"):
console.print(f"There is a need to regenerate {PYPROJECT_TOML_FILE_PATH}")
console.print(
f"[red]You need to run the following command locally and commit generated "
f"{PYPROJECT_TOML_FILE_PATH.relative_to(AIRFLOW_SOURCES_ROOT)} file:\n"
)
console.print("breeze static-checks --type update-providers-dependencies --all-files")
console.print()
MY_MD5SUM_FILE.write_text(new_md5sum + "\n")
sys.exit(1)
else:
console.print(
"[green]No need to regenerate dependencies!\n[/]"
f"The {DEPENDENCIES_JSON_FILE_PATH.relative_to(AIRFLOW_SOURCES_ROOT)} is up to date!\n"
)
else:
console.print(f"Written {PYPROJECT_TOML_FILE_PATH}")
console.print()
