From 5ce0d9d9c569c7172dd2284a45bf67ff1ba68b31 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 Nov 2024 10:41:25 +0100 Subject: [PATCH 01/43] build: Bump gridtools-cpp to 2.3.7 in preparation of #1648 (#1732) #1648 exposed a compilation problem with nvcc which has been fixed in https://github.com/GridTools/gridtools/pull/1811 included in gridtools 2.3.7. --- .pre-commit-config.yaml | 8 ++++---- constraints.txt | 16 ++++++++-------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 16 ++++++++-------- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2f5b73613..93ea4685f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.2 + rev: v0.7.3 ##[[[end]]] hooks: # Run the linter. @@ -97,14 +97,14 @@ repos: - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - - cmake==3.30.5 + - cmake==3.31.0.1 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - - gridtools-cpp==2.3.6 + - gridtools-cpp==2.3.7 - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 @@ -112,7 +112,7 @@ repos: - nanobind==2.2.0 - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==24.1 + - packaging==24.2 - pybind11==2.13.6 - setuptools==75.3.0 - tabulate==0.9.0 diff --git a/constraints.txt b/constraints.txt index e7acc466cd..4aca6645d5 100644 --- a/constraints.txt +++ b/constraints.txt @@ -25,7 +25,7 @@ chardet==5.2.0 # via tox charset-normalizer==3.4.0 # via requests clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via gt4py (pyproject.toml) +cmake==3.31.0.1 # via gt4py (pyproject.toml) cogapp==3.4.1 
# via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==0.16.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.7 # via ipykernel +debugpy==1.8.8 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) @@ -55,7 +55,7 @@ fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach -gridtools-cpp==2.3.6 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via gt4py (pyproject.toml) hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via pre-commit idna==3.10 # via requests @@ -66,7 +66,7 @@ inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest ipykernel==6.29.5 # via nbmake ipython==8.12.3 # via ipykernel -jedi==0.19.1 # via ipython +jedi==0.19.2 # via ipython jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat jsonschema-specifications==2023.12.1 # via jsonschema @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff -packaging==24.1 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via jedi pathspec==0.12.1 # via black pexpect==4.9.0 # via ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.3 # via bump-my-version rpds-py==0.20.1 # via 
jsonschema, referencing -ruff==0.7.2 # via -r requirements-dev.in +ruff==0.7.3 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.2 # via -r requirements-dev.in +tach==0.14.3 # via -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version @@ -174,7 +174,7 @@ virtualenv==20.27.1 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit websockets==13.1 # via dace -wheel==0.44.0 # via astunparse, pip-tools +wheel==0.45.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index f63042906c..6fd3d1af55 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -68,7 +68,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jax[cpu]==0.4.18; python_version >= "3.10" diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 666aa79107..b8779096c0 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -64,7 +64,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 diff --git a/pyproject.toml b/pyproject.toml index c9f7b3b50b..7d63f70f15 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.6,==2.*', + 'gridtools-cpp>=2.3.7,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index a036307e80..8892620786 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,7 +25,7 @@ chardet==5.2.0 # via -c constraints.txt, tox charset-normalizer==3.4.0 # via -c constraints.txt, requests clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.7 # via -c constraints.txt, ipykernel +debugpy==1.8.8 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r 
requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests @@ -66,7 +66,7 @@ inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy iniconfig==2.0.0 # via -c constraints.txt, pytest ipykernel==6.29.5 # via -c constraints.txt, nbmake ipython==8.12.3 # via -c constraints.txt, ipykernel -jedi==0.19.1 # via -c constraints.txt, ipython +jedi==0.19.2 # via -c constraints.txt, ipython jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff -packaging==24.1 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via -c constraints.txt, jedi pathspec==0.12.1 # via -c constraints.txt, black pexpect==4.9.0 # via -c constraints.txt, ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.3 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in +ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.16.0 # via -c constraints.txt, asttokens, 
astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.2 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version @@ -173,7 +173,7 @@ virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit websockets==13.1 # via -c constraints.txt, dace -wheel==0.44.0 # via -c constraints.txt, astunparse, pip-tools +wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources From 89fea8fa86eceea7e7cadb1b16924354267ebf0c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 14:44:23 +0100 Subject: [PATCH 02/43] bug[next]: Fix ITIR program hash stability (#1733) #1690 included a change to make the hash of an `itir.FencilDefinition` stable across multiple runs. 
This PR adopts the same change to an `itir.Program`, --- src/gt4py/next/iterator/ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index f50d8080eb..7098e9fa2e 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -242,7 +242,9 @@ class Program(Node, ValidatedSymbolTableTrait): body: List[Stmt] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(GTIR_BUILTINS) + ] # sorted for serialization stability # TODO(fthaler): just use hashable types in nodes (tuples instead of lists) From b60ffff3f14d272cfc5ee470c80b460358fd8add Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 17:46:41 +0100 Subject: [PATCH 03/43] build: Bump gridtools-cpp to 2.3.8 in preparation of #1648 (#1737) #1648 exposed a compilation problem with nvcc which has been fixed in #1812 included in gridtools 2.3.8. 
--- .pre-commit-config.yaml | 2 +- constraints.txt | 10 +++++----- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 10 +++++----- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93ea4685f4..f56e84f8d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -104,7 +104,7 @@ repos: - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - - gridtools-cpp==2.3.7 + - gridtools-cpp==2.3.8 - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 diff --git a/constraints.txt b/constraints.txt index 4aca6645d5..4247f4951d 100644 --- a/constraints.txt +++ b/constraints.txt @@ -47,7 +47,7 @@ exceptiongroup==1.2.2 # via hypothesis, pytest execnet==2.1.1 # via pytest-cache, pytest-xdist executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==30.8.2 # via factory-boy +faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat filelock==3.16.1 # via tox, virtualenv fonttools==4.54.1 # via matplotlib @@ -55,7 +55,7 @@ fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach -gridtools-cpp==2.3.7 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.8 # via gt4py (pyproject.toml) hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via pre-commit idna==3.10 # via requests @@ -137,7 +137,7 @@ questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach -rich-click==1.8.3 # via bump-my-version +rich-click==1.8.4 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing ruff==0.7.3 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) @@ -147,7 +147,7 @@ smmap==5.0.1 # via gitdb 
snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==3.0.1 # via -r requirements-dev.in +sphinx-rtd-theme==3.0.2 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -160,7 +160,7 @@ stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) tach==0.14.3 # via -r requirements-dev.in -tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version toolz==1.0.0 # via cytoolz diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 6fd3d1af55..4190570105 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -68,7 +68,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.7 +gridtools-cpp==2.3.8 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jax[cpu]==0.4.18; python_version >= "3.10" diff --git a/min-requirements-test.txt b/min-requirements-test.txt index b8779096c0..81a1c2dea3 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -64,7 +64,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.7 +gridtools-cpp==2.3.8 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 diff --git a/pyproject.toml b/pyproject.toml index 7d63f70f15..1504c8b17b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 
'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.7,==2.*', + 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index 8892620786..ca7eb32487 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -47,7 +47,7 @@ exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==30.8.2 # via -c constraints.txt, factory-boy +faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, tox, virtualenv fonttools==4.54.1 # via -c constraints.txt, matplotlib @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.8 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests @@ -137,7 +137,7 @@ questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach -rich-click==1.8.3 # via -c constraints.txt, bump-my-version +rich-click==1.8.4 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing ruff==0.7.3 # via -c constraints.txt, -r 
requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser @@ -146,7 +146,7 @@ smmap==5.0.1 # via -c constraints.txt, gitdb snowballstemmer==2.2.0 # via -c constraints.txt, sphinx sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis sphinx==7.1.2 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==3.0.1 # via -c constraints.txt, -r requirements-dev.in +sphinx-rtd-theme==3.0.2 # via -c constraints.txt, -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via -c constraints.txt, sphinx sphinxcontrib-devhelp==1.0.2 # via -c constraints.txt, sphinx sphinxcontrib-htmlhelp==2.0.1 # via -c constraints.txt, sphinx @@ -159,7 +159,7 @@ stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in -tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version toolz==1.0.0 # via -c constraints.txt, cytoolz From c51bdd1b6e515b2cebff466876d51b1bc0874096 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 10:35:42 +0100 Subject: [PATCH 04/43] fix[next]: Fix type preservation in CSE (#1736) The common subexpression elimination uses typing information to decide what expressions can be extracted. However, while extracting it creates new nodes and uses the inline lambda pass, which did not preserve the types. This was observed in PMAP and is fixed in this PR on a best effort basis. 
Creating a minimal reproducible example is hard and since multiple of us are considering making typing information an integral part of the IR, e.g. by attaching the computation to the node instead of having a separate pass, which would solve the problem automatically no tests have been written. --- src/gt4py/next/iterator/transforms/cse.py | 12 ++++++------ src/gt4py/next/iterator/transforms/inline_lambdas.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 6 +++--- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccc1d2195f..4932d376ad 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -14,6 +14,7 @@ import operator from typing import Callable, Iterable, TypeVar, Union, cast +import gt4py.next.iterator.ir_utils.ir_makers as im from gt4py.eve import ( NodeTranslator, NodeVisitor, @@ -241,7 +242,6 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 @@ -433,7 +433,9 @@ def predicate(subexpr: itir.Expr, num_occurences: int): if num_occurences > 1: if is_local_view: return True - else: + # condition is only necessary since typing on lambdas is not preserved during + # the transformation + elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` @@ -451,10 +453,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int): return self.generic_visit(node, **kwargs) # apply remapping - result = itir.FunCall( - fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr), - args=list(extracted.values()), - 
) + result = im.let(*extracted.items())(new_expr) + itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True) # if the node id is ignored (because its parent is eliminated), but it occurs # multiple times then we want to visit the final result once more. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 920d628166..399a7a3dc6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.type_system import inference as itir_inference # TODO(tehrengruber): Reduce complexity of the function by removing the different options here @@ -98,7 +99,7 @@ def new_name(name): new_expr.location = node.location return new_expr else: - return ir.FunCall( + new_expr = ir.FunCall( fun=ir.Lambda( params=[ param @@ -110,6 +111,8 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index edcb9b540c..66d8345b94 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. 
""" - assert isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) + assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) # type: ignore[arg-type] def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: From 998f2792de75650447a1ffd96aec1e4ebc8dc882 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 11:27:56 +0100 Subject: [PATCH 05/43] feat[next]: GTIR embedded and GTFN temporaries with new lowering (#1648) Use new lowering for GTIR embedded, and GTFN. Only the dace iterator backend continues to use the old lowering. Changes: - Use GTIR lowering for all backends except for dace - Old lowering and transformations only used in dace backend - workflows defined in [`gt4py.next.backend.LEGACY_TRANSFORMS`](https://github.com/GridTools/gt4py/pull/1648/files#diff-cf4385d02cbeacc310d4326350903b4cb6f9a61c7cd36dda162a5077ab8b8e86). Variable can be removed in a cleanup PR. - old `apply_common_transforms` in [pass_manager_legacy.py](https://github.com/GridTools/gt4py/pull/1648/files#diff-db17bff48ac16ee75ff974a1b9af98e3cf0c850971ce9898aa55b635bb046b72). Just a straight copy of the old function. No need to review, this is just to avoid deleting until gtir based dace backend is ready. - Re-add `symbolic_sizes` param. Was in temporary extraction, is now part of the domain inference. 
In preparation of icon-exclaim tests --------- Co-authored-by: Hannes Vogt --- .pre-commit-config.yaml | 1 - src/gt4py/next/backend.py | 14 +- src/gt4py/next/ffront/foast_to_past.py | 6 +- src/gt4py/next/ffront/past_to_itir.py | 2 +- .../next/iterator/ir_utils/domain_utils.py | 23 ++- src/gt4py/next/iterator/ir_utils/ir_makers.py | 5 +- .../next/iterator/transforms/__init__.py | 3 +- .../iterator/transforms/collapse_list_get.py | 52 +++--- .../iterator/transforms/collapse_tuple.py | 85 ++++++--- src/gt4py/next/iterator/transforms/cse.py | 43 +++-- .../iterator/transforms/fuse_as_fieldop.py | 44 +++-- .../next/iterator/transforms/global_tmps.py | 8 +- .../next/iterator/transforms/infer_domain.py | 140 ++++++++++---- .../iterator/transforms/inline_into_scan.py | 2 +- .../iterator/transforms/inline_lambdas.py | 7 +- .../next/iterator/transforms/inline_scalar.py | 31 +++ .../next/iterator/transforms/pass_manager.py | 176 ++++++------------ .../transforms/pass_manager_legacy.py | 175 +++++++++++++++++ .../next/iterator/transforms/remap_symbols.py | 8 +- .../next/iterator/transforms/unroll_reduce.py | 9 +- .../iterator/type_system/type_synthesizer.py | 6 +- .../codegens/gtfn/gtfn_ir.py | 58 ++++-- .../codegens/gtfn/gtfn_module.py | 17 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 26 ++- .../program_processors/formatters/lisp.py | 4 +- .../next/program_processors/runners/dace.py | 11 +- .../runners/dace_iterator/__init__.py | 13 +- .../runners/dace_iterator/workflow.py | 8 +- .../next/program_processors/runners/gtfn.py | 20 +- .../program_processors/runners/roundtrip.py | 16 +- src/gt4py/next/type_system/type_info.py | 4 +- tests/next_tests/definitions.py | 62 +++--- .../ffront_tests/ffront_test_utils.py | 1 - .../ffront_tests/test_decorator.py | 4 +- .../ffront_tests/test_execution.py | 24 ++- .../ffront_tests/test_scalar_if.py | 1 + .../test_temporaries_with_sizes.py | 26 +-- .../iterator_tests/test_builtins.py | 15 +- .../feature_tests/iterator_tests/test_scan.py | 2 + 
.../ffront_tests/test_icon_like_scan.py | 19 -- .../iterator_tests/test_anton_toy.py | 5 - .../iterator_tests/test_column_stencil.py | 51 ++--- .../iterator_tests/test_fvm_nabla.py | 1 - .../iterator_tests/test_if_stmt.py | 6 +- .../iterator_tests/test_vertical_advection.py | 58 ++---- .../test_with_toy_connectivity.py | 1 - tests/next_tests/unit_tests/conftest.py | 54 +++--- .../transforms_tests/test_collapse_tuple.py | 21 ++- .../transforms_tests/test_cse.py | 26 +-- .../transforms_tests/test_domain_inference.py | 25 ++- .../transforms_tests/test_fuse_as_fieldop.py | 25 +++ .../transforms_tests/test_unroll_reduce.py | 87 ++------- 52 files changed, 945 insertions(+), 586 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/inline_scalar.py create mode 100644 src/gt4py/next/iterator/transforms/pass_manager_legacy.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f56e84f8d9..07f75177ea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,6 @@ repos: - id: check-merge-conflict - id: check-toml - id: check-yaml - - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit ##[[[cog diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 0340d61f89..e223d7771c 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -15,6 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( + foast_to_gtir, foast_to_itir, foast_to_past, func_to_foast, @@ -76,7 +77,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] = dataclasses.field( - default_factory=foast_to_itir.adapted_foast_to_itir_factory + default_factory=foast_to_gtir.adapted_foast_to_gtir_factory ) field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field( @@ -134,6 +135,17 @@ def step_order(self, inp: INPUT_PAIR) -> 
list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() +# FIXME[#1582](havogt): remove after refactoring to GTIR +# note: this step is deliberately placed here, such that the cache is shared +_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) +LEGACY_TRANSFORMS: Transforms = Transforms( + past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), + foast_to_itir=_foast_to_itir_step, + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=_foast_to_itir_step + ), +) + # TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. Maybe: # `Backend` -> `Toolchain` diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 312ac686a2..330bc79809 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -12,7 +12,7 @@ from gt4py.eve import utils as eve_utils from gt4py.next.ffront import ( dialect_ast_enums, - foast_to_itir, + foast_to_gtir, program_ast as past, stages as ffront_stages, type_specifications as ts_ffront, @@ -68,7 +68,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: ... return a - >>> op_to_prog = OperatorToProgram(foast_to_itir.adapted_foast_to_itir_factory()) + >>> op_to_prog = OperatorToProgram(foast_to_gtir.adapted_foast_to_gtir_factory()) >>> compile_time_args = arguments.CompileTimeArgs( ... 
args=tuple(param.type for param in copy.foast_stage.foast_node.definition.params), @@ -169,7 +169,7 @@ def operator_to_program_factory( ) -> workflow.Workflow[AOT_FOP, AOT_PRG]: """Optionally wrap `OperatorToProgram` in a `CachedStep`.""" wf: workflow.Workflow[AOT_FOP, AOT_PRG] = OperatorToProgram( - foast_to_itir_step or foast_to_itir.adapted_foast_to_itir_factory() + foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 14d705576e..c0348bb5c6 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -108,7 +108,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra # FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR def past_to_itir_factory( - cached: bool = True, to_gtir: bool = False + cached: bool = True, to_gtir: bool = True ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) if cached: diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8eec405136..8f842e1c13 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Literal, Mapping +from typing import Any, Literal, Mapping, Optional import gt4py.next as gtx from gt4py.next import common @@ -93,6 +93,9 @@ def translate( ..., ], offset_provider: common.OffsetProvider, + #: A dictionary mapping axes names to their length. See + #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. 
+ symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> SymbolicDomain: dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} @@ -119,18 +122,24 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - # note: ugly but cheap re-computation, but should disappear - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) + horizontal_sizes: dict[str, itir.Expr] + if symbolic_domain_sizes is not None: + horizontal_sizes = {k: im.ref(v) for k, v in symbolic_domain_sizes.items()} + else: + # note: ugly but cheap re-computation, but should disappear + horizontal_sizes = { + k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN) + for k, v in _max_domain_sizes_by_location_type(offset_provider).items() + } old_dim = nbt_provider.origin_axis new_dim = nbt_provider.neighbor_axis assert new_dim not in new_ranges or old_dim == new_dim - # TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON? new_range = SymbolicRange( im.literal("0", itir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_dim.value]), itir.INTEGER_INDEX_BUILTIN), + horizontal_sizes[new_dim.value], ) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) @@ -140,7 +149,9 @@ def translate( raise AssertionError() return SymbolicDomain(self.grid_type, new_ranges) elif len(shift) > 2: - return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider) + return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( + shift[2:], offset_provider, symbolic_domain_sizes + ) else: raise AssertionError("Number of shifts must be a multiple of 2.") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 19e26f24b6..d7a66b8285 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -10,7 +10,6 @@ from typing import Callable, Optional, Union from 
gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Dict, Tuple from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts, type_translation @@ -412,7 +411,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( grid_type: Union[common.GridType, str], - ranges: Dict[Union[common.Dimension, str], Tuple[itir.Expr, itir.Expr]], + ranges: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]], ) -> itir.FunCall: """ >>> str( @@ -446,7 +445,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call: +def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call: """ Create an `as_fieldop` call. diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index 6f9651a397..aeccb5f26d 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -8,10 +8,9 @@ from gt4py.next.iterator.transforms.pass_manager import ( ITIRTransform, - LiftMode, apply_common_transforms, apply_fieldview_transforms, ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "LiftMode", "ITIRTransform"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "ITIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index f8a3c08e8f..4a354879ca 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py import eve -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): @@ -18,32 +19,29 @@ class 
CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): - `list_get(i, make_const_list(e))` -> `e` """ - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: node = self.generic_visit(node) - if node.fun == ir.SymRef(id="list_get"): - if isinstance(node.args[1], ir.FunCall): - if node.args[1].fun == ir.SymRef(id="neighbors"): - offset_tag = node.args[1].args[0] - offset_index = ( - ir.OffsetLiteral(value=int(node.args[0].value)) - if isinstance(node.args[0], ir.Literal) - else node.args[ - 0 - ] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn - ) - it = node.args[1].args[1] - return ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), args=[offset_tag, offset_index] - ), - args=[it], - ) - ], - ) - if node.args[1].fun == ir.SymRef(id="make_const_list"): - return node.args[1].args[0] + if cpm.is_call_to(node, "list_get"): + if cpm.is_call_to(node.args[1], "if_"): + list_idx = node.args[0] + cond, true_val, false_val = node.args[1].args + return im.if_( + cond, + self.visit(im.call("list_get")(list_idx, true_val)), + self.visit(im.call("list_get")(list_idx, false_val)), + ) + if cpm.is_call_to(node.args[1], "neighbors"): + offset_tag = node.args[1].args[0] + offset_index = ( + itir.OffsetLiteral(value=int(node.args[0].value)) + if isinstance(node.args[0], itir.Literal) + else node.args[ + 0 + ] # else-branch: e.g. 
SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn + ) + it = node.args[1].args[1] + return im.deref(im.shift(offset_tag, offset_index)(it)) + if cpm.is_call_to(node.args[1], "make_const_list"): + return node.args[1].args[0] return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b61fb2ba87..f84714e779 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -16,6 +16,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, @@ -104,9 +105,10 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider=None, + offset_provider: Optional[common.OffsetProvider] = None, + within_stencil: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes - flags=None, + flags: Optional[Flag] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, ) -> ir.Node: @@ -126,6 +128,13 @@ def apply( flags = flags or cls.flags offset_provider = offset_provider or {} + if isinstance(node, (ir.Program, ir.FencilDefinition)): + within_stencil = False + assert within_stencil in [ + True, + False, + ], "Parameter 'within_stencil' mandatory if node is not a 'Program'." + if not ignore_tuple_size: node = itir_type_inference.infer( node, @@ -136,7 +145,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, flags=flags, - ).visit(node) + ).visit(node, within_stencil=within_stencil) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. 
this is important # as otherwise two equal expressions containing a tuple will not be equal anymore @@ -150,20 +159,23 @@ def apply( return new_node - def visit_FunCall(self, node: ir.FunCall) -> ir.Node: - node = self.generic_visit(node) - return self.fp_transform(node) + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + if cpm.is_call_to(node, "as_fieldop"): + kwargs = {**kwargs, "within_stencil": True} + + node = self.generic_visit(node, **kwargs) + return self.fp_transform(node, **kwargs) - def fp_transform(self, node: ir.Node) -> ir.Node: + def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: while True: - new_node = self.transform(node) + new_node = self.transform(node, **kwargs) if new_node is None: break assert new_node != node node = new_node return node - def transform(self, node: ir.Node) -> Optional[ir.Node]: + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: if not isinstance(node, ir.FunCall): return None @@ -171,12 +183,14 @@ def transform(self, node: ir.Node) -> Optional[ir.Node]: if self.flags & transformation: assert isinstance(transformation.name, str) method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node) + result = method(node, **kwargs) if result is not None: return result return None - def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_collapse_make_tuple_tuple_get( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="make_tuple") and all( isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") for arg in node.args @@ -202,7 +216,9 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ return first_expr return None - def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_collapse_tuple_get_make_tuple( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: if ( node.fun == 
ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) @@ -219,7 +235,7 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ return node.args[1].args[idx] return None - def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture @@ -228,7 +244,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let - self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let + self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr), **kwargs) # type: ignore[attr-defined] # ensured by is_let ) )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let @@ -238,12 +254,12 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: cond, true_branch, false_branch = node.args[1].args return im.if_( cond, - self.fp_transform(im.tuple_get(idx.value, true_branch)), - self.fp_transform(im.tuple_get(idx.value, false_branch)), + self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs), + self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs), ) return None - def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` @@ -258,21 +274,27 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> 
Optional[ir. new_args.append(arg) if bound_vars: - return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) + return self.fp_transform( + im.let(*bound_vars.items())(im.call(node.fun)(*new_args)), **kwargs + ) return None - def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] if any(eligible_params): - return self.visit(inline_lambda(node, eligible_params=eligible_params)) + return self.visit(inline_lambda(node, eligible_params=eligible_params), **kwargs) return None - def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + if kwargs["within_stencil"]: + # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation + # in local-view for now. Revisit. + return None + if not cpm.is_call_to(node, "if_"): - # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. # TODO(tehrengruber): Only inline if type of branch value is a tuple. 
# Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -281,12 +303,16 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N for i, arg in enumerate(node.args): if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args - new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) - new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) + new_true_branch = self.fp_transform( + _with_altered_arg(node, i, true_branch), **kwargs + ) + new_false_branch = self.fp_transform( + _with_altered_arg(node, i, false_branch), **kwargs + ) return im.if_(cond, new_true_branch, new_false_branch) return None - def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} @@ -304,12 +330,15 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: if outer_vars: return self.fp_transform( im.let(*outer_vars.items())( - self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) - ) + self.fp_transform( + im.let(*inner_vars.items())(original_inner_expr), **kwargs + ) + ), + **kwargs, ) return None - def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 4932d376ad..38ea1fd53d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ 
b/src/gt4py/next/iterator/transforms/cse.py @@ -31,6 +31,21 @@ from gt4py.next.type_system import type_info, type_specifications as ts +def _is_trivial_tuple_expr(node: itir.Expr): + """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" + if cpm.is_call_to(node, "make_tuple") and all( + isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_tuple_expr(arg) + for arg in node.args + ): + return True + if cpm.is_call_to(node, "tuple_get") and ( + isinstance(node.args[1], (itir.SymRef, itir.Literal)) + or _is_trivial_tuple_expr(node.args[1]) + ): + return True + return False + + @dataclasses.dataclass class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type", "domain") @@ -373,7 +388,7 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): >>> x = itir.SymRef(id="x") >>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b]) >>> expr = plus(plus(x, x), plus(x, x)) - >>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True)) + >>> print(CommonSubexpressionElimination.apply(expr, within_stencil=True)) (λ(_cs_1) → _cs_1 + _cs_1)(x + x) The pass visits the tree top-down starting from the root node, e.g. an itir.Program. @@ -395,33 +410,33 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): def apply( cls, node: ProgramOrExpr, - is_local_view: bool | None = None, + within_stencil: bool | None = None, offset_provider: common.OffsetProvider | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) if is_program: - assert is_local_view is None - is_local_view = False + assert within_stencil is None + within_stencil = False else: assert ( - is_local_view is not None - ), "The expression's context must be specified using `is_local_view`." + within_stencil is not None + ), "The expression's context must be specified using `within_stencil`." 
offset_provider = offset_provider or {} node = itir_type_inference.infer( node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program ) - return cls().visit(node, is_local_view=is_local_view) + return cls().visit(node, within_stencil=within_stencil) def generic_visit(self, node, **kwargs): if cpm.is_call_to("as_fieldop", node): - assert not kwargs.get("is_local_view") - is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view") + assert not kwargs.get("within_stencil") + within_stencil = cpm.is_call_to("as_fieldop", node) or kwargs.get("within_stencil") - return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view})) + return super().generic_visit(node, **(kwargs | {"within_stencil": within_stencil})) def visit_FunCall(self, node: itir.FunCall, **kwargs): - is_local_view = kwargs["is_local_view"] + within_stencil = kwargs["within_stencil"] if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): return node @@ -431,7 +446,7 @@ def predicate(subexpr: itir.Expr, num_occurences: int): # view, even though the syntactic context `node` is in field view. # note: what is extracted is sketched in the docstring above. keep it updated. if num_occurences > 1: - if is_local_view: + if within_stencil: return True # condition is only necessary since typing on lambdas is not preserved during # the transformation @@ -439,11 +454,13 @@ def predicate(subexpr: itir.Expr, num_occurences: int): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` + # only extract if subexpression is not a trivial tuple expression, e.g., + # `make_tuple(a, b)`, as this would result in a more costly temporary.
assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType) for stype in type_info.primitive_constituents(subexpr.type) - ): + ) and not _is_trivial_tuple_expr(subexpr): return True return False diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 51bbd91d83..da238733da 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -13,7 +13,12 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts +from gt4py.next.iterator.transforms import ( + inline_center_deref_lift_vars, + inline_lambdas, + inline_lifts, + trace_shifts, +) from gt4py.next.iterator.type_system import ( inference as type_inference, type_specifications as it_ts, @@ -54,6 +59,14 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: return expr +def _is_tuple_expr_of_literals(expr: itir.Expr): + if cpm.is_call_to(expr, "make_tuple"): + return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) + if cpm.is_call_to(expr, "tuple_get"): + return _is_tuple_expr_of_literals(expr.args[1]) + return isinstance(expr, itir.Literal) + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -153,11 +166,15 @@ def visit_FunCall(self, node: itir.FunCall): for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.extract_dtype(arg.type) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = isinstance(arg, itir.Literal) or ( + should_inline = _is_tuple_expr_of_literals(arg) or ( isinstance(arg, itir.FunCall) - and 
(cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "if_")) + and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) if should_inline: @@ -168,7 +185,7 @@ def visit_FunCall(self, node: itir.FunCall): type_ = arg.type arg = im.op_as_fieldop("if_")(*arg.args) arg.type = type_ - elif isinstance(arg, itir.Literal): + elif _is_tuple_expr_of_literals(arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: raise NotImplementedError() @@ -179,6 +196,7 @@ def visit_FunCall(self, node: itir.FunCall): new_args = _merge_arguments(new_args, extracted_args) else: + assert not isinstance(dtype, it_ts.ListType) new_param: str if isinstance( arg, itir.SymRef @@ -189,15 +207,19 @@ def visit_FunCall(self, node: itir.FunCall): new_param = stencil_param.id new_args = _merge_arguments(new_args, {new_param: arg}) - # simplify stencil directly to keep the tree small - new_stencil_body = inline_lambdas.InlineLambdas.apply( - new_stencil_body, opcount_preserving=True - ) - new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( *new_args.values() ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = inline_lifts.InlineLifts().visit(new_node) + type_inference.copy_type(from_=node, to=new_node) return new_node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 11d3fccec1..90f8a6cded 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -175,7 +175,10 @@ def 
_transform_stmt( def create_global_tmps( - program: itir.Program, offset_provider: common.OffsetProvider + program: itir.Program, + offset_provider: common.OffsetProvider, + *, + uids: Optional[eve_utils.UIDGenerator] = None, ) -> itir.Program: """ Given an `itir.Program` create temporaries for intermediate values. @@ -186,7 +189,8 @@ def create_global_tmps( program = infer_domain.infer_program(program, offset_provider) program = type_inference.infer(program, offset_provider=offset_provider) - uids = eve_utils.UIDGenerator(prefix="__tmp") + if not uids: + uids = eve_utils.UIDGenerator(prefix="__tmp") declarations = program.declarations.copy() new_body = [] diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 2a85e6f2cf..6852b47a7a 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,8 +10,9 @@ import itertools import typing -from typing import Callable, TypeAlias +from typing import Callable, Optional, TypeAlias +from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir @@ -28,6 +29,18 @@ ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class DomainAnnexDebugger(eve.NodeVisitor): + """ + Small utility class to debug missing domain attribute in annex. + """ + + def visit_Node(self, node: itir.Node): + if cpm.is_applied_as_fieldop(node): + if not hasattr(node.annex, "domain"): + breakpoint() # noqa: T100 + return self.generic_visit(node) + + def _split_dict_by_key(pred: Callable, d: dict): """ Split dictionary into two based on predicate. 
@@ -107,6 +120,7 @@ def _extract_accessed_domains( input_ids: list[str], target_domain: domain_utils.SymbolicDomain, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> ACCESSED_DOMAINS: accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} @@ -114,7 +128,9 @@ def _extract_accessed_domains( for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ - domain_utils.SymbolicDomain.translate(target_domain, shift, offset_provider) + domain_utils.SymbolicDomain.translate( + target_domain, shift, offset_provider, symbolic_domain_sizes + ) for shift in shifts_list ] # `None` means field is never accessed @@ -125,10 +141,11 @@ def _extract_accessed_domains( return typing.cast(ACCESSED_DOMAINS, accessed_domains) -def infer_as_fieldop( +def _infer_as_fieldop( applied_fieldop: itir.FunCall, target_domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") @@ -161,7 +178,7 @@ def infer_as_fieldop( input_ids.append(id_) inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( - stencil, input_ids, target_domain, offset_provider + stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s @@ -169,7 +186,7 @@ def infer_as_fieldop( transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider + in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes ) transformed_inputs.append(transformed_input) @@ -187,15 +204,16 @@ def infer_as_fieldop( return transformed_call, accessed_domains_without_tmp -def 
infer_let( +def _infer_let( let_expr: itir.FunCall, input_domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider + let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes ) let_params = {param_sym.id for param_sym in let_expr.fun.params} @@ -212,6 +230,7 @@ def infer_let( None, ), offset_provider, + symbolic_domain_sizes, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -226,10 +245,11 @@ def infer_let( return transformed_call, accessed_domains_outer -def infer_make_tuple( +def _infer_make_tuple( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] @@ -245,17 +265,20 @@ def infer_make_tuple( # e.g. 
`im.tuple_get(0, im.make_tuple(a, b), domain=domain)` domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + arg, domain[i], offset_provider, symbolic_domain_sizes + ) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) return result_expr, actual_domains -def infer_tuple_get( +def _infer_tuple_get( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "tuple_get") actual_domains: ACCESSED_DOMAINS = {} @@ -263,24 +286,29 @@ def infer_tuple_get( assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + ) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) return infered_args_expr, actual_domains -def infer_if( +def _infer_if( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] actual_domains: ACCESSED_DOMAINS = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + arg, domain, offset_provider, symbolic_domain_sizes + ) 
infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -291,25 +319,26 @@ def _infer_expr( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return infer_as_fieldop(expr, domain, offset_provider) + return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_let(expr): - return infer_let(expr, domain, offset_provider) + return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "make_tuple"): - return infer_make_tuple(expr, domain, offset_provider) + return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "tuple_get"): - return infer_tuple_get(expr, domain, offset_provider) + return _infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "if_"): - return infer_if(expr, domain, offset_provider) + return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) - or cpm.is_call_to(expr, "cast_") + or cpm.is_call_to(expr, ("cast_", "index", "unstructured_domain", "cartesian_domain")) ): return expr, {} else: @@ -320,40 +349,79 @@ def infer_expr( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + """ + Infer the domain of all field subexpressions of `expr`. + + Given an expression `expr` and the domain it is accessed at, back-propagate the domain of all + (field-typed) subexpression. 
+ + Arguments: + - expr: The expression to be inferred. + - domain: The domain `expr` is read at. + - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol + name that evaluates to the length of that axis. + + Returns: + A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) + having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to + domain they are accessed at. + """ # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider) + expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) expr.annex.domain = domain return expr, accessed_domains +def _infer_stmt( + stmt: itir.Stmt, + offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], +): + if isinstance(stmt, itir.SetAt): + transformed_call, _unused_domain = infer_expr( + stmt.expr, + domain_utils.SymbolicDomain.from_expr(stmt.domain), + offset_provider, + symbolic_domain_sizes, + ) + return itir.SetAt( + expr=transformed_call, + domain=stmt.domain, + target=stmt.target, + ) + elif isinstance(stmt, itir.IfStmt): + return itir.IfStmt( + cond=stmt.cond, + true_branch=[ + _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch + ], + false_branch=[ + _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch + ], + ) + raise ValueError(f"Unsupported stmt: {stmt}") + + def infer_program( program: itir.Program, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: - transformed_set_ats: list[itir.SetAt] = [] + """ + Infer the domain of all field subexpressions inside a program. + + See :func:`infer_expr` for more details. + """ assert ( not program.function_definitions ), "Domain propagation does not support function definitions." 
- for set_at in program.body: - assert isinstance(set_at, itir.SetAt) - - transformed_call, _unused_domain = infer_expr( - set_at.expr, domain_utils.SymbolicDomain.from_expr(set_at.domain), offset_provider - ) - transformed_set_ats.append( - itir.SetAt( - expr=transformed_call, - domain=set_at.domain, - target=set_at.target, - ), - ) - return itir.Program( id=program.id, function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=transformed_set_ats, + body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], ) diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index f899da73b1..33e36bfa4b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +# FIXME[#1582](tehrengruber): This transformation is not used anymore. Decide on its fate. from typing import Sequence, TypeGuard from gt4py import eve diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 399a7a3dc6..5ec9ec5d0b 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -111,6 +111,9 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) return new_expr @@ -120,10 +123,10 @@ class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """ Inline lambda calls by substituting every argument by its value. 
- Note: This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. + Note: This pass preserves, but doesn't use the `type` `recorded_shifts`, `domain` annex. """ - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") opcount_preserving: bool diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py new file mode 100644 index 0000000000..c6e2c38b90 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -0,0 +1,31 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +class InlineScalar(eve.NodeTranslator): + @classmethod + def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider): + program = itir_inference.infer(program, offset_provider=offset_provider) + return cls().visit(program) + + def visit_Expr(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_let(node): + eligible_params = [isinstance(arg.type, ts.ScalarType) for arg in node.args] + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_params) + return node + return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 0c08bf2b9d..52a452155a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,28 +6,30 @@ # Please, refer to the LICENSE 
file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import enum from typing import Callable, Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, infer_domain, inline_fundefs +from gt4py.next.iterator.transforms import ( + fencil_to_program, + fuse_as_fieldop, + global_tmps, + infer_domain, + inline_fundefs, + inline_lifts, +) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts +from gt4py.next.iterator.transforms.inline_scalar import InlineScalar from gt4py.next.iterator.transforms.merge_let import MergeLet from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce +from gt4py.next.iterator.type_system.inference import infer class ITIRTransform(Protocol): @@ -36,45 +38,12 @@ def __call__( ) -> itir.Program: ... 
-@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - ).visit(ir) - else: - raise ValueError() - - return ir - - -def _inline_into_scan(ir, *, max_iter=10): - for _ in range(10): - # in case there are multiple levels of lambdas around the scan we have to do multiple iterations - inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") - return ir - - # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward -# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. +# `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( ir: itir.Program | itir.FencilDefinition, *, - lift_mode=None, + extract_temporaries=False, offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, @@ -84,57 +53,52 @@ def apply_common_transforms( temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place + #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for + #: more details. 
symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: + # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this if isinstance(ir, itir.FencilDefinition): - ir = fencil_to_program.FencilToProgram().apply( - ir - ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR - else: - assert isinstance(ir, itir.Program) - # FIXME[#1582](havogt): note: currently the case when using the roundtrip backend - pass + ir = fencil_to_program.FencilToProgram.apply(ir) + assert isinstance(ir, itir.Program) - icdlv_uids = eve_utils.UIDGenerator() + tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") + mergeasfop_uids = eve_utils.UIDGenerator() - if lift_mode is None: - lift_mode = LiftMode.FORCE_INLINE - assert isinstance(lift_mode, LiftMode) ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program - ir = PropagateDeref.apply(ir) ir = NormalizeShifts().visit(ir) + # note: this increases the size of the tree + # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` + ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) + # required in order to get rid of expressions without a domain (e.g. 
when a tuple element is never accessed) + ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = infer_domain.infer_program( + ir, # type: ignore[arg-type] # always an itir.Program + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + ) + for _ in range(10): inlined = ir - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil - inlined = _inline_lifts(inlined, lift_mode) - - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, - ) - inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # still a `itir.Program` + inlined = InlineLambdas.apply(inlined, opcount_preserving=True) + inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` - inlined, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider=offset_provider) + + # This pass is required to run after CollapseTuple as otherwise we can not inline + # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns + # a list. 
Such expressions must be inlined however because no backend supports such + # field operators right now. + inlined = fuse_as_fieldop.FuseAsFieldOp.apply( + inlined, uids=mergeasfop_uids, offset_provider=offset_provider ) - # This pass is required such that a deref outside of a - # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the - # `tuple_get` is removed by the `CollapseTuple` pass. - inlined = PropagateDeref.apply(inlined) if inlined == ir: break @@ -142,48 +106,21 @@ def apply_common_transforms( else: raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - if lift_mode != LiftMode.FORCE_INLINE: - # FIXME[#1582](tehrengruber): implement new temporary pass here - raise NotImplementedError() - # ruff: noqa: ERA001 - # assert offset_provider is not None - # ir = CreateGlobalTmps().visit( - # ir, - # offset_provider=offset_provider, - # extraction_heuristics=temporary_extraction_heuristics, - # symbolic_sizes=symbolic_domain_sizes, - # ) - # - # for _ in range(10): - # inlined = InlineLifts().visit(ir) - # inlined = InlineLambdas.apply( - # inlined, opcount_preserving=True, force_inline_lift_args=True - # ) - # if inlined == ir: - # break - # ir = inlined - # else: - # raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - # - # # If after creating temporaries, the scan is not at the top, we inline. - # # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. 
- # # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - # ir = _inline_into_scan(ir) + # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined + if common_subexpression_elimination: + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) + ir = MergeLet().visit(ir) + ir = InlineLambdas.apply(ir, opcount_preserving=True) + + if extract_temporaries: + ir = infer(ir, inplace=True, offset_provider=offset_provider) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` - ir, - ignore_tuple_size=True, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) + ir = CollapseTuple.apply(ir, ignore_tuple_size=True, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -198,18 +135,13 @@ def apply_common_transforms( ir = unrolled # type: ignore[assignment] # still a `itir.Program` ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) + # this is required as nested neighbor reductions can contain lifts, e.g., + # `neighbors(V2Eₒ, ↑f(...))` + ir = inline_lifts.InlineLifts().visit(ir) ir = NormalizeShifts().visit(ir) else: raise RuntimeError("Reduction unrolling failed.") - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - 
- if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) - ir = MergeLet().visit(ir) - ir = InlineLambdas.apply( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py new file mode 100644 index 0000000000..792bb421f1 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -0,0 +1,175 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +# FIXME[#1582](tehrengruber): file should be removed after refactoring to GTIR +import enum +from typing import Callable, Optional + +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs +from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet +from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination +from gt4py.next.iterator.transforms.eta_reduction import EtaReduction +from gt4py.next.iterator.transforms.fuse_maps import FuseMaps +from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars +from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas +from gt4py.next.iterator.transforms.inline_lifts import InlineLifts +from gt4py.next.iterator.transforms.merge_let import MergeLet +from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts +from gt4py.next.iterator.transforms.propagate_deref import 
PropagateDeref +from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce + + +@enum.unique +class LiftMode(enum.Enum): + FORCE_INLINE = enum.auto() + USE_TEMPORARIES = enum.auto() + + +def _inline_lifts(ir, lift_mode): + if lift_mode == LiftMode.FORCE_INLINE: + return InlineLifts().visit(ir) + elif lift_mode == LiftMode.USE_TEMPORARIES: + return InlineLifts( + flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. + ).visit(ir) + else: + raise ValueError() + + return ir + + +def _inline_into_scan(ir, *, max_iter=10): + for _ in range(10): + # in case there are multiple levels of lambdas around the scan we have to do multiple iterations + inlined = InlineIntoScan().visit(ir) + inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") + return ir + + +def apply_common_transforms( + ir: itir.Node, + *, + lift_mode=None, + offset_provider=None, + unroll_reduce=False, + common_subexpression_elimination=True, + force_inline_lambda_args=False, + unconditionally_collapse_tuples=False, + temporary_extraction_heuristics: Optional[ + Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] + ] = None, + symbolic_domain_sizes: Optional[dict[str, str]] = None, +) -> itir.Program: + assert isinstance(ir, itir.FencilDefinition) + ir = fencil_to_program.FencilToProgram().apply(ir) + icdlv_uids = eve_utils.UIDGenerator() + + if lift_mode is None: + lift_mode = LiftMode.FORCE_INLINE + assert isinstance(lift_mode, LiftMode) + ir = MergeLet().visit(ir) + ir = inline_fundefs.InlineFundefs().visit(ir) + + ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program + ir = 
PropagateDeref.apply(ir) + ir = NormalizeShifts().visit(ir) + + for _ in range(10): + inlined = ir + + inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil + inlined = _inline_lifts(inlined, lift_mode) + + inlined = InlineLambdas.apply( + inlined, + opcount_preserving=True, + force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), + # If trivial lifts are not inlined we might create temporaries for constants. In all + # other cases we want it anyway. + force_inline_trivial_lift_args=True, + ) + inlined = ConstantFolding.apply(inlined) + # This pass is required to be in the loop such that when an `if_` call with tuple arguments + # is constant-folded the surrounding tuple_get calls can be removed. + inlined = CollapseTuple.apply( + inlined, + offset_provider=offset_provider, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) + # This pass is required such that a deref outside of a + # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the + # `tuple_get` is removed by the `CollapseTuple` pass. + inlined = PropagateDeref.apply(inlined) + + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + + if lift_mode != LiftMode.FORCE_INLINE: + raise NotImplementedError() + + # Since `CollapseTuple` relies on the type inference which does not support returning tuples + # larger than the number of closure outputs as given by the unconditional collapse, we can + # only run the unconditional version here instead of in the loop above. 
+ if unconditionally_collapse_tuples: + ir = CollapseTuple.apply( + ir, + ignore_tuple_size=True, + offset_provider=offset_provider, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) + + if lift_mode == LiftMode.FORCE_INLINE: + ir = _inline_into_scan(ir) + + ir = NormalizeShifts().visit(ir) + + ir = FuseMaps().visit(ir) + ir = CollapseListGet().visit(ir) + + if unroll_reduce: + for _ in range(10): + unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + if unrolled == ir: + break + ir = unrolled + ir = CollapseListGet().visit(ir) + ir = NormalizeShifts().visit(ir) + ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) + ir = NormalizeShifts().visit(ir) + else: + raise RuntimeError("Reduction unrolling failed.") + + ir = EtaReduction().visit(ir) + ir = ScanEtaReduction().visit(ir) + + if common_subexpression_elimination: + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = MergeLet().visit(ir) + + ir = InlineLambdas.apply( + ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args + ) + + assert isinstance(ir, itir.Program) + return ir diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 02180a3699..08d896121d 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -13,8 +13,8 @@ class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. 
+ PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): return symbol_map.get(str(node.id), node) @@ -32,8 +32,8 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] class RenameSymbols(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 700b8571a5..ec9c3efb2b 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -30,7 +30,14 @@ def _is_neighbors_or_lifted_and_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunC def _get_neighbors_args(reduce_args: Iterable[itir.Expr]) -> Iterator[itir.FunCall]: - return filter(_is_neighbors_or_lifted_and_neighbors, reduce_args) + flat_reduce_args: list[itir.Expr] = [] + for arg in reduce_args: + if cpm.is_call_to(arg, "if_"): + flat_reduce_args.extend(_get_neighbors_args(arg.args[1:3])) + else: + flat_reduce_args.append(arg) + + return filter(_is_neighbors_or_lifted_and_neighbors, flat_reduce_args) def _is_list_of_funcalls(lst: list) -> TypeGuard[list[itir.FunCall]]: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6579107197..43c4465576 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -299,7 +299,11 @@ def as_fieldop( @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | 
ts.DeferredType: - if any(isinstance(f, ts.DeferredType) for f in fields): + if any( + isinstance(el, ts.DeferredType) + for f in fields + for el in type_info.primitive_constituents(f) + ): return ts.DeferredType(constraint=None) stencil_return = stencil( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 20a1a0cf76..85a100a88d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import ClassVar, Optional, Union +from typing import Callable, ClassVar, Optional, Union from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait @@ -96,25 +96,23 @@ class Backend(Node): domain: Union[SymRef, CartesianDomain, UnstructuredDomain] -def _is_ref_literal_or_tuple_expr_of_ref(expr: Expr) -> bool: +def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool: if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "tuple_get" and len(expr.args) == 2 - and _is_ref_literal_or_tuple_expr_of_ref(expr.args[1]) + and _is_tuple_expr_of(pred, expr.args[1]) ): return True if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "make_tuple" - and all(_is_ref_literal_or_tuple_expr_of_ref(arg) for arg in expr.args) + and all(_is_tuple_expr_of(pred, arg) for arg in expr.args) ): return True - if isinstance(expr, (SymRef, Literal)): - return True - return False + return pred(expr) class SidComposite(Expr): @@ -126,14 +124,32 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_ref_literal_or_tuple_expr_of_ref(el) + or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) for el in value ): raise ValueError( - "Only 'SymRef', tuple expr of 'SymRef', 
'SidFromScalar', or 'SidComposite' allowed." + "Only 'SymRef', 'Literal', tuple expr thereof, 'SidFromScalar', or 'SidComposite' allowed." ) +def _might_be_scalar_expr(expr: Expr) -> bool: + if isinstance(expr, BinaryExpr): + return all(_is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) + if isinstance(expr, UnaryExpr): + return _is_tuple_expr_of(_might_be_scalar_expr, expr.expr) + if ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id in ARITHMETIC_BUILTINS + ): + return all(_might_be_scalar_expr(arg) for arg in expr.args) + if isinstance(expr, CastExpr): + return _might_be_scalar_expr(expr.obj_expr) + if _is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): + return True + return False + + class SidFromScalar(Expr): arg: Expr @@ -141,8 +157,10 @@ class SidFromScalar(Expr): def _arg_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: Expr ) -> None: - if not _is_ref_literal_or_tuple_expr_of_ref(value): - raise ValueError("Only 'SymRef' or tuple expr of 'SymRef' allowed.") + if not _might_be_scalar_expr(value): + raise ValueError( + "Only 'SymRef', 'Literal', arithmetic op or tuple expr thereof allowed." + ) class Stmt(Node): @@ -155,6 +173,24 @@ class StencilExecution(Stmt): output: Union[SymRef, SidComposite] inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]] + @datamodels.validator("inputs") + def _arg_validator( + self: datamodels.DataModelTP, attribute: datamodels.Attribute, inputs: list[Expr] + ) -> None: + for inp in inputs: + if not _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)) + or ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id == "index" + ), + inp, + ): + raise ValueError( + "Only 'SymRef', 'SidComposite', 'SidFromScalar', 'index' call or tuple expr thereof allowed." 
+ ) + class Scan(Node): function: SymRef diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 66d74d53cc..ce459f7970 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -20,7 +20,7 @@ from gt4py.next import common from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager +from gt4py.next.iterator.transforms import pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen @@ -51,7 +51,6 @@ class GTFNTranslationStep( # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 enable_itir_transforms: bool = True use_imperative_backend: bool = False - lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None temporary_extraction_heuristics: Optional[ @@ -168,14 +167,9 @@ def _preprocess_program( program: itir.FencilDefinition | itir.Program, offset_provider: dict[str, common.Connectivity | common.Dimension], ) -> itir.Program: - if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms: - return fencil_to_program.FencilToProgram().apply( - program - ) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR - apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, - lift_mode=self.lift_mode, + extract_temporaries=True, offset_provider=offset_provider, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, @@ -203,7 
+197,12 @@ def generate_stencil_source( offset_provider: dict[str, common.Connectivity | common.Dimension], column_axis: Optional[common.Dimension], ) -> str: - new_program = self._preprocess_program(program, offset_provider) + if self.enable_itir_transforms: + new_program = self._preprocess_program(program, offset_provider) + else: + assert isinstance(program, itir.Program) + new_program = program + gtfn_ir = GTFN_lowering.apply( new_program, offset_provider=offset_provider, column_axis=column_axis ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index fb2645208c..bc2bd645e8 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,7 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, @@ -67,6 +67,27 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: _horizontal_dimension = "gtfn::unstructured::dim::horizontal" +def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool: + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "tuple_get" + and len(expr.args) == 2 + and _is_tuple_of_ref_or_literal(expr.args[1]) + ): + return True + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "make_tuple" + and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args) + ): + return True + if isinstance(expr, (itir.SymRef, itir.Literal)): + return True + return False + + def 
_get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: result = set() for node in nodes: @@ -587,6 +608,9 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: + if _is_tuple_of_ref_or_literal(node.expr): + node.expr = im.as_fieldop("deref", node.domain)(node.expr) + assert cpm.is_applied_as_fieldop(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py index 7b722a7c1a..0a8253595e 100644 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ b/src/gt4py/next/program_processors/formatters/lisp.py @@ -51,9 +51,7 @@ class ToLispLike(TemplatedGenerator): @classmethod def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms( - root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] - ) + transformed = apply_common_transforms(root, offset_provider=kwargs["offset_provider"]) generated_code = super().apply(transformed, **kwargs) try: from yasi import indent_code diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 9a45b6a29a..95186e0b5d 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -9,7 +9,6 @@ import factory from gt4py.next import backend -from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory @@ -33,7 
+32,7 @@ class Params: lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" ) - transforms = backend.DEFAULT_TRANSFORMS + transforms = backend.LEGACY_TRANSFORMS run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) @@ -59,13 +58,7 @@ class Params: lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" ) - transforms = backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(), - field_view_op_to_prog=foast_to_past.operator_to_program_factory( - foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() - ), - ) + transforms = backend.DEFAULT_TRANSFORMS gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 6383d4bb44..fc2772027e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -25,7 +25,10 @@ from gt4py.next.ffront import decorator from gt4py.next.iterator import transforms as itir_transforms from gt4py.next.iterator.ir import SymRef -from gt4py.next.iterator.transforms import program_to_fencil +from gt4py.next.iterator.transforms import ( + pass_manager_legacy as legacy_itir_transforms, + program_to_fencil, +) from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.type_system import type_specifications as ts @@ -36,14 +39,14 @@ def preprocess_program( program: itir.FencilDefinition, offset_provider: Mapping[str, Any], - lift_mode: itir_transforms.LiftMode, + lift_mode: legacy_itir_transforms.LiftMode, symbolic_domain_sizes: Optional[dict[str, str]] = None, 
temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, unroll_reduce: bool = False, ): - node = itir_transforms.apply_common_transforms( + node = legacy_itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -73,7 +76,7 @@ def build_sdfg_from_itir( auto_optimize: bool = False, on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, + lift_mode: legacy_itir_transforms.LiftMode = legacy_itir_transforms.LiftMode.FORCE_INLINE, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] @@ -234,7 +237,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: } sdfg.offset_providers_per_input_field = {} - itir_tmp = itir_transforms.apply_common_transforms( + itir_tmp = legacy_itir_transforms.apply_common_transforms( self.itir, offset_provider=offset_provider ) itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 7a442e3819..740f1979cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import program_to_fencil from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings @@ -36,7 +36,6 @@ class DaCeTranslator( 
step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): auto_optimize: bool = False - lift_mode: LiftMode = LiftMode.FORCE_INLINE device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None temporary_extraction_heuristics: Optional[ @@ -69,7 +68,6 @@ def generate_sdfg( auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, - lift_mode=self.lift_mode, symbolic_domain_sizes=self.symbolic_domain_sizes, temporary_extraction_heuristics=self.temporary_extraction_heuristics, load_sdfg_from_file=False, @@ -82,7 +80,9 @@ def __call__( ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the ITIR definition.""" program: itir.FencilDefinition | itir.Program = inp.data - assert isinstance(program, itir.FencilDefinition) + + if isinstance(program, itir.Program): + program = program_to_fencil.program_to_fencil(program) sdfg = self.generate_sdfg( program, diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 4a788bf40c..965c6417b2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config from gt4py.next.common import Connectivity, Dimension -from gt4py.next.iterator import ir as itir, transforms +from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -166,19 +166,19 @@ class Params: cached_translation = factory.Trait( translation=factory.LazyAttribute( lambda o: workflow.CachedStep( - o.translation_, + o.bare_translation, hash_function=fingerprint_compilable_program, cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), ) - translation_ = factory.SubFactory( + 
bare_translation = factory.SubFactory( gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), ) - translation = factory.LazyAttribute(lambda o: o.translation_) + translation = factory.LazyAttribute(lambda o: o.bare_translation) bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source @@ -213,12 +213,6 @@ class Params: ), name_cached="_cached", ) - use_temporaries = factory.Trait( - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place - otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES, - # otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, # noqa: ERA001 - name_temps="_with_temporaries", - ) device_type = core_defs.DeviceType.CPU hash_function = compilation_hash otf_workflow = factory.SubFactory( @@ -242,8 +236,10 @@ class Params: run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) -run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) - run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) + +run_gtfn_no_transforms = GTFNBackendFactory( + otf_workflow__bare_translation__enable_itir_transforms=False +) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 57785ceb33..4d518d7fcc 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -103,6 +103,7 @@ def fencil_generator( Arguments: ir: The iterator IR (ITIR) node. debug: Keep module source containing fencil implementation. + extract_temporaries: Extract intermediate field values into temporaries. use_embedded: Directly use builtins from embedded backend instead of generic dispatcher. Gives faster performance and is easier to debug. 
@@ -209,7 +210,7 @@ def decorated_fencil( ) -> None: if out is not None: args = (*args, out) - if not column_axis: + if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? column_axis = inp.args.column_axis fencil( *args, @@ -222,11 +223,13 @@ def decorated_fencil( return decorated_fencil +# TODO(tehrengruber): introduce factory default = next_backend.Backend( name="roundtrip", executor=Roundtrip( transforms=functools.partial( - itir_transforms.apply_common_transforms, lift_mode=itir_transforms.LiftMode.FORCE_INLINE + itir_transforms.apply_common_transforms, + extract_temporaries=False, ) ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), @@ -237,12 +240,18 @@ def decorated_fencil( executor=Roundtrip( transforms=functools.partial( itir_transforms.apply_common_transforms, - lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES, + extract_temporaries=True, ) ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) +no_transforms = next_backend.Backend( + name="roundtrip", + executor=Roundtrip(transforms=lambda o, *, offset_provider: o), + allocator=next_allocators.StandardCPUFieldBufferAllocator(), + transforms=next_backend.DEFAULT_TRANSFORMS, +) gtir = next_backend.Backend( @@ -257,3 +266,4 @@ def decorated_fencil( ), ), ) +foast_to_gtir_step = foast_to_gtir.adapted_foast_to_gtir_factory(cached=True) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5bda9a6f2e..66f8937dc5 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -459,7 +459,9 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: """ if isinstance(symbol_type, ts.DeferredType) and ( - symbol_type.constraint is None or issubclass(type_class(to_type), symbol_type.constraint) + symbol_type.constraint is None + or (isinstance(to_type, ts.DeferredType) and to_type.constraint is None) + or 
issubclass(type_class(to_type), symbol_type.constraint) ): return True elif is_concrete(symbol_type): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 1bcc3554a7..c86ba88ead 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -43,11 +43,10 @@ def short_id(self, num_components: int = 2) -> str: class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn" GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative" - GTFN_CPU_WITH_TEMPORARIES = ( - "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" - ) + GTFN_CPU_NO_TRANSFORMS = "gt4py.next.program_processors.runners.gtfn.run_gtfn_no_transforms" GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.default" + ROUNDTRIP_NO_TRANSFORMS = "gt4py.next.program_processors.runners.roundtrip.no_transforms" GTIR_EMBEDDED = "gt4py.next.program_processors.runners.roundtrip.gtir" ROUNDTRIP_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.roundtrip.with_temporaries" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -102,6 +101,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" +USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" @@ -130,13 +130,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), 
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] +# Markers to skip because of missing features in the domain inference +DOMAIN_INFERENCE_SKIP_LIST = [ + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), +] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), @@ -148,8 +153,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] -GTIR_DACE_SKIP_TEST_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), +GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), @@ -164,14 +168,22 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args ] -GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 - (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 - (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), +ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, 
UNSUPPORTED_MESSAGE), ] +GTFN_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 + (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) @@ -192,20 +204,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [(ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramFormatterId.GTFN_CPP_FORMATTER: [ - (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) - ], - ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramBackendId.GTIR_EMBEDDED: [ - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + ProgramFormatterId.GTFN_CPP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST + + [ + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], - ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ + ProgramFormatterId.LISP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST, + ProgramBackendId.ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.DOUBLE_ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: ROUNDTRIP_SKIP_LIST + + [ (ALL, XFAIL, 
UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTIR_EMBEDDED: ROUNDTRIP_SKIP_LIST, } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 0ed3365969..c64efb27d2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -49,7 +49,6 @@ def __gt_allocator__( next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, - next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, pytest.param( next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu ), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index f26424bf0e..47419c278b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -30,9 +30,9 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, (itir.FencilDefinition, itir.Program)) + assert isinstance(testee.itir, (itir.Program, itir.FencilDefinition)) assert isinstance( - testee.with_backend(cartesian_case.backend).itir, (itir.FencilDefinition, itir.Program) + testee.with_backend(cartesian_case.backend).itir, (itir.Program, itir.FencilDefinition) ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 27f94960dc..f10f195d3a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -301,6 +301,21 @@ def testee(a: cases.IJKField, b: int32) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) +@pytest.mark.uses_tuple_args +def test_double_use_scalar(cartesian_case): + # TODO(tehrengruber): This should be a regression test on ITIR level, but tracing doesn't + # work for this case. + @gtx.field_operator + def testee(a: np.int32, b: np.int32, c: cases.IField) -> cases.IField: + tmp = a * b + tmp2 = tmp * tmp + # important part here is that we use the intermediate twice so that it is + # not inlined + return tmp2 * tmp2 * c + + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * c) + + @pytest.mark.uses_scalar_in_domain_and_fo def test_scalar_in_domain_spec_and_fo_call(cartesian_case): @gtx.field_operator @@ -688,9 +703,6 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.backend == gtfn.run_gtfn_with_temporaries: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) def tridiag_forward( state: tuple[float, float], a: float, b: float, c: float, d: float @@ -789,9 +801,6 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan def test_ternary_scan(cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: return carry if 
carry > a else carry + 1.0 @@ -814,9 +823,6 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.uses_scan_without_field_args @pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] expected = np.arange(1, 1 + k_size, 1, dtype=int32) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 0efb599f9e..7ff7edf226 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -56,6 +56,7 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField cases.verify(cartesian_case, simple_if, a, b, condition, out=out, ref=a if condition else b) +# TODO(tehrengruber): test with fields on different domains @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) @pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 0305a5841a..11e28de9e1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -5,14 +5,15 @@ # # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +import platform import pytest from numpy import int32, int64 from gt4py import next as gtx from gt4py.next import backend, common -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries +from gt4py.next.iterator.transforms import apply_common_transforms +from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -34,8 +35,8 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, - executor=run_gtfn_with_temporaries.executor.replace( - translation=run_gtfn_with_temporaries.executor.translation.replace( + executor=run_gtfn.executor.replace( + translation=run_gtfn.executor.translation.replace( symbolic_domain_sizes={ "Cell": "num_cells", "Edge": "num_edges", @@ -43,7 +44,7 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): } ) ), - allocator=run_gtfn_with_temporaries.allocator, + allocator=run_gtfn.allocator, ) @@ -64,8 +65,14 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") + if platform.machine() == "x86_64": + pytest.xfail( + reason="The C++ code generated in this test contains unicode characters " + "(coming from the ssa pass) which is not supported by gcc 9 used" + "in the CI. Bumping the container version sadly did not work for" + "unrelated and unclear reasons. Since the issue is not present" + "on Alps we just skip the test for now before investing more time." 
+ ) unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes, @@ -100,12 +107,9 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") - itir_with_tmp = apply_common_transforms( testee.itir, - lift_mode=LiftMode.USE_TEMPORARIES, + extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c2f72e4ca7..3fc4ed9945 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -45,8 +45,9 @@ plus, shift, xor_, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, closure, fendef, fundef, offset from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -87,7 +88,9 @@ def dispatch(arg0): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0), domain, out) elif len(inps) == 2: @@ -102,7 +105,9 @@ def dispatch(arg0, arg1): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1), domain, out) elif len(inps) == 3: @@ 
-117,7 +122,9 @@ def dispatch(arg0, arg1, arg2): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, arg2, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1, arg2]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1, arg2), domain, out) else: raise AssertionError("Add overload.") diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index a86959d075..e462aa07eb 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -18,7 +18,9 @@ @pytest.mark.uses_index_fields +@pytest.mark.uses_scan_in_stencil def test_scan_in_stencil(program_processor): + # FIXME[#1582](tehrengruber): Remove test after scan is reworked. program_processor, validate = program_processor isize = 1 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 505879a506..19664f2dab 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -227,14 +227,6 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail( - "Needs implementation of scan projector. Breaks in type inference as executed" - "again after CollapseTuple." 
- ) if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") @@ -254,12 +246,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - cases.run( test_setup.case, solve_nonhydro_stencil_52_like, @@ -276,11 +262,6 @@ def test_solve_nonhydro_stencil_52_like(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 14271efb27..3ce9d6b470 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -78,11 +78,6 @@ def naive_lap(inp): def test_anton_toy(stencil, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn_with_temporaries.executor, - ]: - pytest.xfail("TODO: issue with temporaries that crashes the application") - if stencil is lap: pytest.xfail( "Type inference does not support calling lambdas with offset arguments of changing type." 
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 2b858f3025..f8e9f22eff 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next import field_utils from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -170,23 +170,14 @@ def test_k_level_condition(program_processor, fun, k_level, inp_function, ref_fu @fundef -def sum_scanpass(state, inp): +def ksum(state, inp): return state + deref(inp) -@fundef -def ksum(inp): - return scan(sum_scanpass, True, 0.0)(inp) - - @fendef(column_axis=KDim) def ksum_fencil(i_size, k_start, k_end, inp, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)), - ksum, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)) + set_at(as_fieldop(scan(ksum, True, 0.0), domain)(inp), domain, out) @pytest.mark.parametrize( @@ -214,19 +205,10 @@ def test_ksum_scan(program_processor, kstart, reference): assert np.allclose(reference, out.asnumpy()) -@fundef -def ksum_back(inp): - return scan(sum_scanpass, False, 0.0)(inp) - - @fendef(column_axis=KDim) def ksum_back_fencil(i_size, k_size, inp, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), - ksum_back, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)) + 
set_at(as_fieldop(scan(ksum, False, 0.0), domain)(inp), domain, out) def test_ksum_back_scan(program_processor): @@ -252,23 +234,14 @@ def test_ksum_back_scan(program_processor): @fundef -def doublesum_scanpass(state, inp0, inp1): +def kdoublesum(state, inp0, inp1): return make_tuple(tuple_get(0, state) + deref(inp0), tuple_get(1, state) + deref(inp1)) -@fundef -def kdoublesum(inp0, inp1): - return scan(doublesum_scanpass, True, make_tuple(0.0, 0))(inp0, inp1) - - @fendef(column_axis=KDim) def kdoublesum_fencil(i_size, k_start, k_end, inp0, inp1, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)), - kdoublesum, - out, - [inp0, inp1], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)) + set_at(as_fieldop(scan(kdoublesum, True, make_tuple(0.0, 0)), domain)(inp0, inp1), domain, out) @pytest.mark.parametrize( @@ -325,7 +298,8 @@ def sum_shifted(inp0, inp1): @fendef(column_axis=KDim) def sum_shifted_fencil(out, inp0, inp1, k_size): - closure(cartesian_domain(named_range(KDim, 1, k_size)), sum_shifted, out, [inp0, inp1]) + domain = cartesian_domain(named_range(KDim, 1, k_size)) + set_at(as_fieldop(sum_shifted, domain)(inp0, inp1), domain, out) def test_different_vertical_sizes(program_processor): @@ -352,7 +326,8 @@ def sum(inp0, inp1): @fendef(column_axis=KDim) def sum_fencil(out, inp0, inp1, k_size): - closure(cartesian_domain(named_range(KDim, 0, k_size)), sum, out, [inp0, inp1]) + domain = cartesian_domain(named_range(KDim, 0, k_size)) + set_at(as_fieldop(sum, domain)(inp0, inp1), domain, out) @pytest.mark.uses_origin diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 156bc1c37f..3db4497910 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ 
b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -30,7 +30,6 @@ unstructured_domain, ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.iterator.transforms.pass_manager import LiftMode from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py index 2dde7d7653..c38a29bc61 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py @@ -28,7 +28,7 @@ from gt4py.next.iterator.runtime import set_at, if_stmt, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn -from next_tests.unit_tests.conftest import program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor_no_transforms, run_processor i = offset("i") @@ -43,8 +43,8 @@ def multiply(alpha, inp): @pytest.mark.uses_ir_if_stmts @pytest.mark.parametrize("cond", [True, False]) -def test_if_stmt(program_processor, cond): - program_processor, validate = program_processor +def test_if_stmt(program_processor_no_transforms, cond): + program_processor, validate = program_processor_no_transforms size = 10 @fendef(offset_provider={"i": IDim}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index a89f250571..30ceaf9376 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -10,9 +10,9 @@ import pytest import gt4py.next as gtx 
+from gt4py.next import backend from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.runtime import set_at, fendef, fundef from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters from gt4py.next.program_processors.runners import gtfn @@ -42,22 +42,17 @@ def tridiag_backward2(x_kp1, cp, dp): @fundef -def solve_tridiag(a, b, c, d): - cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) - return scan(tridiag_backward, False, 0.0)(cpdp) - - -def tuple_get_it(i, x): - def stencil(x): - return tuple_get(i, deref(x)) - - return lift(stencil)(x) +def solve_tridiag(domain, a, b, c, d): + cpdp = as_fieldop(scan(tridiag_forward, True, make_tuple(0.0, 0.0)), domain)(a, b, c, d) + return as_fieldop(scan(tridiag_backward, False, 0.0), domain)(cpdp) @fundef -def solve_tridiag2(a, b, c, d): - cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) - return scan(tridiag_backward2, False, 0.0)(tuple_get_it(0, cpdp), tuple_get_it(1, cpdp)) +def solve_tridiag2(domain, a, b, c, d): + cpdp = as_fieldop(scan(tridiag_forward, True, make_tuple(0.0, 0.0)), domain)(a, b, c, d) + return as_fieldop(scan(tridiag_backward2, False, 0.0), domain)( + tuple_get(0, cpdp), tuple_get(1, cpdp) + ) @pytest.fixture @@ -80,40 +75,27 @@ def tridiag_reference(): @fendef def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): - closure( - cartesian_domain( - named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) - ), - solve_tridiag, - x, - [a, b, c, d], + domain = cartesian_domain( + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ) + set_at(solve_tridiag(domain, a, b, c, d), domain, x) @fendef def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): - closure( - cartesian_domain( - named_range(IDim, 0, i_size), named_range(JDim, 
0, j_size), named_range(KDim, 0, k_size) - ), - solve_tridiag2, - x, - [a, b, c, d], + domain = cartesian_domain( + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ) + set_at(solve_tridiag2(domain, a, b, c, d), domain, x) @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) -@pytest.mark.uses_lift_expressions def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn_formatters.format_cpp, - ]: - pytest.skip("gtfn does only support lifted scans when using temporaries") - if program_processor == gtfn.run_gtfn_with_temporaries: - pytest.xfail("tuple_get on columns not supported.") + + if isinstance(program_processor, backend.Backend) and "dace" in program_processor.name: + pytest.xfail("Dace ITIR backend doesn't support the IR format used in this test.") + a, b, c, d, x = tridiag_reference shape = a.shape as_3d_field = gtx.as_field.partial([IDim, JDim, KDim]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fb1d4c152..6fdc6a77a1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -383,7 +383,6 @@ def test_shift_sparse_input_field2(program_processor): if program_processor in [ gtfn.run_gtfn, gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, ]: pytest.xfail( "Bug in bindings/compilation/caching: only the first program seems to be compiled." 
diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 8a4aa50730..ca66b45d6d 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -25,7 +25,30 @@ ProgramProcessor: TypeAlias = backend.Backend | program_formatter.ProgramFormatter -@pytest.fixture( +def _program_processor(request) -> tuple[ProgramProcessor, bool]: + """ + Fixture creating program processors on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. + """ + processor_id, is_backend = request.param + if processor_id is None: + return None, is_backend + + processor = processor_id.load() + + for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( + processor_id, [] + ): + if marker == next_tests.definitions.ALL or request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=processor_id)) + + return processor, is_backend + + +program_processor = pytest.fixture( + _program_processor, params=[ (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), @@ -33,7 +56,6 @@ (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), - (next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), @@ -50,26 +72,16 @@ ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) -def program_processor(request) -> tuple[ProgramProcessor, bool]: - """ - Fixture creating program processors on-demand for tests. - - Notes: - Check ADR 15 for details on the test-exclusion matrices. 
- """ - processor_id, is_backend = request.param - if processor_id is None: - return None, is_backend - - processor = processor_id.load() - - for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( - processor_id, [] - ): - if marker == next_tests.definitions.ALL or request.node.get_closest_marker(marker): - skip_mark(msg.format(marker=marker, backend=processor_id)) - return processor, is_backend +program_processor_no_transforms = pytest.fixture( + _program_processor, + params=[ + (None, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU_NO_TRANSFORMS, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP_NO_TRANSFORMS, True), + ], + ids=lambda p: p[0].short_id() if p[0] is not None else "None", +) def run_processor( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 720076c8c2..28090ff1e2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -20,6 +20,7 @@ def test_simple_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) expected = tuple_of_size_2 @@ -37,6 +38,7 @@ def test_nested_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == tup_of_size2_from_lambda @@ -52,6 +54,7 @@ def test_different_tuples_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -65,6 +68,7 @@ def test_incompatible_order_make_tuple_tuple_get(): 
remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -76,6 +80,7 @@ def test_incompatible_size_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -87,6 +92,7 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == im.make_tuple("first", "second") @@ -99,6 +105,7 @@ def test_simple_tuple_get_make_tuple(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, allow_undeclared_symbols=True, + within_stencil=False, ) assert expected == actual @@ -111,6 +118,7 @@ def test_propagate_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert expected == actual @@ -128,6 +136,7 @@ def test_letify_make_tuple_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -141,6 +150,7 @@ def test_letify_make_tuple_with_trivial_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -154,6 +164,7 @@ def test_inline_trivial_make_tuple(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -172,6 +183,7 @@ def test_propagate_to_if_on_tuples(): 
remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -189,6 +201,7 @@ def test_propagate_to_if_on_tuples_with_let(): flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -201,6 +214,7 @@ def test_propagate_nested_lift(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -211,7 +225,10 @@ def test_if_on_tuples_with_let(): )(im.tuple_get(0, "val")) expected = im.if_("pred", 1, 3) actual = CollapseTuple.apply( - testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True + testee, + remove_letified_make_tuple_elements=False, + allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -220,5 +237,5 @@ def test_tuple_get_on_untyped_ref(): # test pass gracefully handles untyped nodes. 
testee = im.tuple_get(0, im.ref("val", ts.DeferredType(constraint=None))) - actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True) + actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) assert actual == testee diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 3204b49371..e04856b75f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -37,7 +37,7 @@ def test_trivial(): ), args=[common], ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -45,7 +45,7 @@ def test_lambda_capture(): common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) testee = ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id="x")], expr=common), args=[common]) expected = testee - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -53,7 +53,7 @@ def test_lambda_no_capture(): common = im.plus("x", "y") testee = im.call(im.lambda_("z")(im.plus("x", "y")))(im.plus("x", "y")) expected = im.let("_cs_1", common)("_cs_1") - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -65,7 +65,7 @@ def common_expr(): testee = im.call(im.lambda_("x", "y")(common_expr()))(common_expr(), common_expr()) # (λ(_cs_1) → _cs_1 + _cs_1)(x + y) expected = im.let("_cs_1", common_expr())(im.plus("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -79,7 +79,7 @@ def common_expr(): expected = im.lambda_("x")( im.let("_cs_1", common_expr())(im.plus("z", im.plus("_cs_1", "_cs_1"))) ) - actual = CSE.apply(testee, 
is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -93,7 +93,7 @@ def common_expr(): ) # (λ(_cs_1) → _cs_1(2) + _cs_1(3))(λ(a) → a + 1) expected = im.let("_cs_1", common_expr())(im.plus(im.call("_cs_1")(2), im.call("_cs_1")(3))) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -109,7 +109,7 @@ def common_expr(): expected = im.let("_cs_1", common_expr())( im.let("_cs_2", im.call("_cs_1")(2))(im.plus("_cs_2", "_cs_2")) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -133,7 +133,7 @@ def common_expr(): ) ) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -157,7 +157,7 @@ def test_if_can_deref_no_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -178,7 +178,7 @@ def test_if_can_deref_eligible_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -191,7 +191,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -268,7 +268,7 @@ def test_no_extraction_outside_asfieldop(): identity_fieldop(im.ref("a", field_type)), identity_fieldop(im.ref("b", field_type)) ) - actual = CSE.apply(testee, is_local_view=False) + 
actual = CSE.apply(testee, within_stencil=False) assert actual == testee @@ -289,5 +289,5 @@ def test_field_extraction_outside_asfieldop(): # ) expected = im.let("_cs_1", identity_fieldop(field))(plus_fieldop("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=False) + actual = CSE.apply(testee, within_stencil=False) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 50756f40e7..141091b450 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -84,9 +84,13 @@ def run_test_expr( domain: itir.FunCall, expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, + domain_utils.SymbolicDomain.from_expr(domain), + offset_provider, + symbolic_domain_sizes, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -1021,3 +1025,22 @@ def test_scan(offset_provider): {"a": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12)})}, offset_provider, ) + + +def test_symbolic_domain_sizes(unstructured_offset_provider): + stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) + domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) + symbolic_domain_sizes = {"Vertex": "num_vertices"} + + testee, expected = setup_test_as_fieldop( + stencil, + domain, + ) + run_test_expr( + testee, + expected, + domain, + {"in_field1": {Vertex: (0, im.ref("num_vertices"))}}, + unstructured_offset_provider, + 
symbolic_domain_sizes, + ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index da2c16336e..b5b9a62009 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -45,6 +45,31 @@ def test_trivial_literal(): assert actual == expected +def test_tuple_arg(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("t")(im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t"))), d)( + im.make_tuple(1, 2) + ), + 3, + ) + expected = im.as_fieldop( + im.lambda_()( + im.plus( + im.let("t", im.make_tuple(1, 2))( + im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t")) + ), + 3, + ) + ), + d, + )() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_symref_used_twice(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 09ed204a91..28bd88b853 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -26,93 +26,35 @@ def has_skip_values(request): @pytest.fixture def basic_reduction(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ) - ], - ) + return 
im.call(im.call("reduce")("foo", 0.0))(im.neighbors("Dim", "x")) @pytest.fixture def reduction_with_shift_on_second_arg(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.SymRef(id="x"), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], - ) + return im.call(im.call("reduce")("foo", 0.0))("x", im.neighbors("Dim", "y")) @pytest.fixture def reduction_with_incompatible_shifts(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim2"), ir.SymRef(id="y")], - ), - ], + return im.call(im.call("reduce")("foo", 0.0))( + im.neighbors("Dim", "x"), im.neighbors("Dim2", "y") ) @pytest.fixture def reduction_with_irrelevant_full_shift(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ - ir.OffsetLiteral(value="Dim"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="IrrelevantDim"), - ir.OffsetLiteral(value="0"), - ], - ), - args=[ir.SymRef(id="x")], - ), - ], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], + return im.call(im.call("reduce")("foo", 0.0))( + im.neighbors("Dim", im.shift("IrrelevantDim", 0)("x")), im.neighbors("Dim", "y") ) -# TODO add a test with lift +@pytest.fixture +def reduction_if(): + UIDs.reset_sequence() + return im.call(im.call("reduce")("foo", 0.0))(im.if_(True, im.neighbors("Dim", 
"x"), "y")) @pytest.mark.parametrize( @@ -121,6 +63,7 @@ def reduction_with_irrelevant_full_shift(): "basic_reduction", "reduction_with_irrelevant_full_shift", "reduction_with_shift_on_second_arg", + "reduction_if", ], ) def test_get_partial_offsets(reduction, request): @@ -178,6 +121,14 @@ def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, assert actual == expected +def test_reduction_with_if(reduction_if): + expected = _expected(reduction_if, "Dim", 2, False) + + offset_provider = {"Dim": DummyConnectivity(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider=offset_provider) + assert actual == expected + + def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) From aeff1e37bb483faebc280776e18b83287aacbe49 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 15 Nov 2024 15:07:22 +0100 Subject: [PATCH 06/43] refactor[catesian]: Type hints and code redability improvements (#1724) ## Description This PR is split off the work for the new GT4Py - DaCe bridge, which should allow to expose control flow statements (`if` and `while`) to DaCe to better use DaCe's analytics capabilities. This PR is concerned with adding type hints and generally improving code readability. Main parts are - `daceir_builder.py`: early returns and renamed variable - `sdfg_builder.py`: type hints and early returns - `tasklet_codegen.py`: type hints and early returns `TaskletCodegen` was given `sdfg_ctx`, which wasn't used. That parameter was thus removed. Parent issue: https://github.com/GEOS-ESM/NDSL/issues/53 ## Requirements - [x] All fixes and/or new features come with corresponding tests. Assumed to be covered by existing tests. - [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. 
N/A --------- Co-authored-by: Roman Cattaneo <> --- src/gt4py/cartesian/gtc/common.py | 2 +- src/gt4py/cartesian/gtc/dace/daceir.py | 2 +- .../gtc/dace/expansion/daceir_builder.py | 84 +++++++++---------- .../gtc/dace/expansion/sdfg_builder.py | 36 ++++---- .../gtc/dace/expansion/tasklet_codegen.py | 64 +++++++------- src/gt4py/cartesian/gtc/dace/utils.py | 5 +- src/gt4py/eve/trees.py | 2 +- src/gt4py/eve/visitors.py | 4 +- 8 files changed, 98 insertions(+), 101 deletions(-) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index bfe434e7f3..dcb01db7ca 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -311,7 +311,7 @@ class CartesianOffset(eve.Node): k: int @classmethod - def zero(cls) -> "CartesianOffset": + def zero(cls) -> CartesianOffset: return cls(i=0, j=0, k=0) def to_dict(self) -> Dict[str, int]: diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 0ecb02b50f..78451c30f5 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -730,7 +730,7 @@ class Literal(common.Literal, Expr): class ScalarAccess(common.ScalarAccess, Expr): - name: eve.Coerced[eve.SymbolRef] + pass class VariableKOffset(common.VariableKOffset[Expr]): diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index a8a3a3cb54..5f2007871e 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -74,8 +74,8 @@ def _get_tasklet_inout_memlets( *, get_outputs: bool, global_ctx: DaCeIRBuilder.GlobalContext, - **kwargs, -): + **kwargs: Any, +) -> List[dcir.Memlet]: access_infos = compute_dcir_access_infos( node, block_extents=global_ctx.library_node.get_extents, @@ -85,7 +85,7 @@ def _get_tasklet_inout_memlets( **kwargs, ) - res = list() + memlets: List[dcir.Memlet] = [] for name, offset, 
tasklet_symbol in _access_iter(node, get_outputs=get_outputs): access_info = access_infos[name] if not access_info.variable_offset_axes: @@ -95,26 +95,27 @@ def _get_tasklet_inout_memlets( axis, extent=(offset_dict[axis.lower()], offset_dict[axis.lower()]) ) - memlet = dcir.Memlet( - field=name, - connector=tasklet_symbol, - access_info=access_info, - is_read=not get_outputs, - is_write=get_outputs, + memlets.append( + dcir.Memlet( + field=name, + connector=tasklet_symbol, + access_info=access_info, + is_read=not get_outputs, + is_write=get_outputs, + ) ) - res.append(memlet) - return res + return memlets -def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval): - def all_statements_in_region(scope_nodes): +def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval: Any) -> bool: + def all_statements_in_region(scope_nodes: List[eve.Node]) -> bool: return all( isinstance(stmt, dcir.HorizontalRestriction) for tasklet in eve.walk_values(scope_nodes).if_isinstance(dcir.Tasklet) for stmt in tasklet.stmts ) - def all_regions_same(scope_nodes): + def all_regions_same(scope_nodes: List[eve.Node]) -> bool: return ( len( set( @@ -179,11 +180,11 @@ def _get_dcir_decl( oir_decl: oir.Decl = self.library_node.declarations[field] assert isinstance(oir_decl, oir.FieldDecl) dace_array = self.arrays[field] - for s in dace_array.strides: - for sym in dace.symbolic.symlist(s).values(): - symbol_collector.add_symbol(str(sym)) - for sym in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) + for stride in dace_array.strides: + for symbol in dace.symbolic.symlist(stride).values(): + symbol_collector.add_symbol(str(symbol)) + for symbol in access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(symbol) return dcir.FieldDecl( name=field, @@ -236,11 +237,7 @@ def push_expansion_item(self, item: Union[Map, Loop]) -> DaCeIRBuilder.Iteration if not isinstance(item, (Map, Loop)): raise ValueError - if isinstance(item, Map): - iterations = 
item.iterations - else: - iterations = [item] - + iterations = item.iterations if isinstance(item, Map) else [item] grid_subset = self.grid_subset for it in iterations: axis = it.axis @@ -267,13 +264,13 @@ def pop(self) -> DaCeIRBuilder.IterationContext: class SymbolCollector: symbol_decls: Dict[str, dcir.SymbolDecl] = dataclasses.field(default_factory=dict) - def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32): + def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32) -> None: if name not in self.symbol_decls: self.symbol_decls[name] = dcir.SymbolDecl(name=name, dtype=dtype) else: assert self.symbol_decls[name].dtype == dtype - def remove_symbol(self, name: eve.SymbolRef): + def remove_symbol(self, name: eve.SymbolRef) -> None: if name in self.symbol_decls: del self.symbol_decls[name] @@ -304,11 +301,14 @@ def visit_HorizontalRestriction( symbol_collector.add_symbol(axis.iteration_symbol()) if bound.level == common.LevelMarker.END: symbol_collector.add_symbol(axis.domain_symbol()) + return dcir.HorizontalRestriction( mask=node.mask, body=self.visit(node.body, symbol_collector=symbol_collector, **kwargs) ) - def visit_VariableKOffset(self, node: oir.VariableKOffset, **kwargs): + def visit_VariableKOffset( + self, node: oir.VariableKOffset, **kwargs: Any + ) -> dcir.VariableKOffset: return dcir.VariableKOffset(k=self.visit(node.k, **kwargs)) def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> dcir.LocalScalarDecl: @@ -419,7 +419,7 @@ def visit_HorizontalExecution( symbol_collector: DaCeIRBuilder.SymbolCollector, loop_order, k_interval, - **kwargs, + **kwargs: Any, ): # skip type checking due to https://github.com/python/mypy/issues/5485 extent = global_ctx.library_node.get_extents(node) # type: ignore @@ -581,7 +581,7 @@ def to_dataflow( nodes = flatten_list(nodes) if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): return nodes - elif not all(isinstance(n, 
(dcir.ComputationState, dcir.DomainLoop)) for n in nodes): + if not all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): raise ValueError("Can't mix dataflow and state nodes on same level.") read_memlets, write_memlets, field_memlets = union_inout_memlets(nodes) @@ -615,10 +615,10 @@ def to_state(self, nodes, *, grid_subset: dcir.GridSubset): nodes = flatten_list(nodes) if all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): return nodes - elif all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): + if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): return [dcir.ComputationState(computations=nodes, grid_subset=grid_subset)] - else: - raise ValueError("Can't mix dataflow and state nodes on same level.") + + raise ValueError("Can't mix dataflow and state nodes on same level.") def _process_map_item( self, @@ -628,8 +628,8 @@ def _process_map_item( global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs, - ): + **kwargs: Any, + ) -> List[dcir.DomainMap]: grid_subset = iteration_ctx.grid_subset read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) scope_nodes = self.to_dataflow( @@ -723,11 +723,11 @@ def _process_loop_item( scope_nodes, item: Loop, *, - global_ctx, + global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs, - ): + **kwargs: Any, + ) -> List[dcir.DomainLoop]: grid_subset = union_node_grid_subsets(list(scope_nodes)) read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) scope_nodes = self.to_state(scope_nodes, grid_subset=grid_subset) @@ -793,14 +793,14 @@ def _process_loop_item( def _process_iteration_item(self, scope, item, **kwargs): if isinstance(item, Map): return self._process_map_item(scope, item, **kwargs) - elif 
isinstance(item, Loop): + if isinstance(item, Loop): return self._process_loop_item(scope, item, **kwargs) - else: - raise ValueError("Invalid expansion specification set.") + + raise ValueError("Invalid expansion specification set.") def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs - ): + self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs: Any + ) -> dcir.NestedSDFG: overall_extent = Extent.zeros(2) for he in node.walk_values().if_isinstance(oir.HorizontalExecution): overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he)) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 7b0f0ab7c4..6728ccaa7d 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -89,7 +89,7 @@ def visit_Memlet( scope_node: dcir.ComputationNode, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, node_ctx: StencilComputationSDFGBuilder.NodeContext, - connector_prefix="", + connector_prefix: str = "", symtable: ChainMap[eve.SymbolRef, dcir.Decl], ) -> None: field_decl = symtable[node.field] @@ -139,13 +139,12 @@ def visit_Tasklet( sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, node_ctx: StencilComputationSDFGBuilder.NodeContext, symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs, + **kwargs: Any, ) -> None: code = TaskletCodegen.apply_codegen( node, read_memlets=node.read_memlets, write_memlets=node.write_memlets, - sdfg_ctx=sdfg_ctx, symtable=symtable, ) @@ -177,7 +176,7 @@ def visit_Tasklet( tasklet, tasklet, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx ) - def visit_Range(self, node: dcir.Range, **kwargs) -> Dict[str, str]: + def visit_Range(self, node: dcir.Range, **kwargs: Any) -> Dict[str, str]: start, end = node.interval.to_dace_symbolic() return {node.var: str(dace.subsets.Range([(start, end - 1, node.stride)]))} @@ 
-187,7 +186,7 @@ def visit_DomainMap( *, node_ctx: StencilComputationSDFGBuilder.NodeContext, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: ndranges = { k: v @@ -248,7 +247,7 @@ def visit_DomainLoop( node: dcir.DomainLoop, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: sdfg_ctx = sdfg_ctx.add_loop(node.index_range) self.visit(node.loop_states, sdfg_ctx=sdfg_ctx, **kwargs) @@ -259,7 +258,7 @@ def visit_ComputationState( node: dcir.ComputationState, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: sdfg_ctx.add_state() read_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} @@ -289,7 +288,7 @@ def visit_FieldDecl( *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, non_transients: Set[eve.SymbolRef], - **kwargs, + **kwargs: Any, ) -> None: assert len(node.strides) == len(node.shape) sdfg_ctx.sdfg.add_array( @@ -307,7 +306,7 @@ def visit_SymbolDecl( node: dcir.SymbolDecl, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: if node.name not in sdfg_ctx.sdfg.symbols: sdfg_ctx.sdfg.add_symbol(node.name, stype=data_type_to_dace_typeclass(node.dtype)) @@ -319,7 +318,7 @@ def visit_NestedSDFG( sdfg_ctx: Optional[StencilComputationSDFGBuilder.SDFGContext] = None, node_ctx: Optional[StencilComputationSDFGBuilder.NodeContext] = None, symtable: ChainMap[eve.SymbolRef, Any], - **kwargs, + **kwargs: Any, ) -> dace.nodes.NestedSDFG: sdfg = dace.SDFG(node.label) inner_sdfg_ctx = StencilComputationSDFGBuilder.SDFGContext( @@ -365,13 +364,12 @@ def visit_NestedSDFG( StencilComputationSDFGBuilder._add_empty_edges( nsdfg, nsdfg, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx ) - else: - nsdfg = dace.nodes.NestedSDFG( - label=sdfg.label, - sdfg=sdfg, - inputs={memlet.connector for memlet in node.read_memlets}, - outputs={memlet.connector for memlet in node.write_memlets}, - 
symbol_mapping=symbol_mapping, - ) + return nsdfg - return nsdfg + return dace.nodes.NestedSDFG( + label=sdfg.label, + sdfg=sdfg, + inputs={memlet.connector for memlet in node.read_memlets}, + outputs={memlet.connector for memlet in node.write_memlets}, + symbol_mapping=symbol_mapping, + ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 696dc27387..8033c64710 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -31,7 +31,7 @@ def _visit_offset( *, access_info: dcir.FieldAccessInfo, decl: dcir.FieldDecl, - **kwargs, + **kwargs: Any, ) -> str: int_sizes: List[Optional[int]] = [] for i, axis in enumerate(access_info.axes()): @@ -60,27 +60,27 @@ def _visit_offset( res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1]) return str(res) - def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs): + def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs: Any) -> str: return self._visit_offset(node, **kwargs) - def visit_VariableKOffset(self, node: common.CartesianOffset, **kwargs): + def visit_VariableKOffset(self, node: dcir.VariableKOffset, **kwargs: Any) -> str: return self._visit_offset(node, **kwargs) def visit_IndexAccess( self, node: dcir.IndexAccess, *, - is_target, - sdfg_ctx, + is_target: bool, symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs, - ): + **kwargs: Any, + ) -> str: if is_target: memlets = kwargs["write_memlets"] else: # if this node is not a target, it will still use the symbol of the write memlet if the # field was previously written in the same memlet. 
memlets = kwargs["read_memlets"] + kwargs["write_memlets"] + try: memlet = next(mem for mem in memlets if mem.connector == node.name) except StopIteration: @@ -101,12 +101,12 @@ def visit_IndexAccess( ) ) index_strs.extend( - self.visit(idx, sdfg_ctx=sdfg_ctx, symtable=symtable, in_idx=True, **kwargs) - for idx in node.data_index + self.visit(idx, symtable=symtable, in_idx=True, **kwargs) for idx in node.data_index ) return f"{node.name}[{','.join(index_strs)}]" - def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs): + def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs: Any) -> str: + # Visiting order matters because targets must not contain the target symbols from the left visit right = self.visit(node.right, is_target=False, **kwargs) left = self.visit(node.left, is_target=True, **kwargs) return f"{left} = {right}" @@ -120,18 +120,18 @@ def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs): def visit_BuiltInLiteral(self, builtin: common.BuiltInLiteral, **kwargs: Any) -> str: if builtin == common.BuiltInLiteral.TRUE: return "True" - elif builtin == common.BuiltInLiteral.FALSE: + if builtin == common.BuiltInLiteral.FALSE: return "False" raise NotImplementedError("Not implemented BuiltInLiteral encountered.") - def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs): + def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs: Any) -> str: value = self.visit(literal.value, in_idx=in_idx, **kwargs) if in_idx: return str(value) - else: - return "{dtype}({value})".format( - dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value - ) + + return "{dtype}({value})".format( + dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value + ) Cast = as_fmt("{dtype}({expr})") @@ -178,26 +178,26 @@ def visit_NativeFuncCall(self, call: common.NativeFuncCall, **kwargs: Any) -> st def visit_DataType(self, dtype: common.DataType, **kwargs: Any) -> str: if dtype == common.DataType.BOOL: return "dace.bool_" 
- elif dtype == common.DataType.INT8: + if dtype == common.DataType.INT8: return "dace.int8" - elif dtype == common.DataType.INT16: + if dtype == common.DataType.INT16: return "dace.int16" - elif dtype == common.DataType.INT32: + if dtype == common.DataType.INT32: return "dace.int32" - elif dtype == common.DataType.INT64: + if dtype == common.DataType.INT64: return "dace.int64" - elif dtype == common.DataType.FLOAT32: + if dtype == common.DataType.FLOAT32: return "dace.float32" - elif dtype == common.DataType.FLOAT64: + if dtype == common.DataType.FLOAT64: return "dace.float64" raise NotImplementedError("Not implemented DataType encountered.") def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str: if op == common.UnaryOperator.NOT: return " not " - elif op == common.UnaryOperator.NEG: + if op == common.UnaryOperator.NEG: return "-" - elif op == common.UnaryOperator.POS: + if op == common.UnaryOperator.POS: return "+" raise NotImplementedError("Not implemented UnaryOperator encountered.") @@ -207,16 +207,16 @@ def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str: LocalScalarDecl = as_fmt("{name}: {dtype}") - def visit_Tasklet(self, node: dcir.Tasklet, **kwargs): + def visit_Tasklet(self, node: dcir.Tasklet, **kwargs: Any) -> str: return "\n".join(self.visit(node.decls, **kwargs) + self.visit(node.stmts, **kwargs)) def _visit_conditional( self, cond: Optional[Union[dcir.Expr, common.HorizontalMask]], body: List[dcir.Stmt], - keyword, - **kwargs, - ): + keyword: str, + **kwargs: Any, + ) -> str: mask_str = "" indent = "" if cond is not None and (cond_str := self.visit(cond, is_target=False, **kwargs)): @@ -226,16 +226,16 @@ def _visit_conditional( body_code = [indent + b for b in body_code] return "\n".join([mask_str, *body_code]) - def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs): + def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs: Any) -> str: return self._visit_conditional(cond=node.mask, body=node.body, 
keyword="if", **kwargs) - def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs): + def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs: Any) -> str: return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs) - def visit_While(self, node: dcir.While, **kwargs): + def visit_While(self, node: dcir.While, **kwargs: Any) -> Any: return self._visit_conditional(cond=node.cond, body=node.body, keyword="while", **kwargs) - def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs): + def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs: Any) -> str: clauses: List[str] = [] for axis, interval in zip(dcir.Axis.dims_horizontal(), node.intervals): diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index b5c23d2735..517e80ceb3 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -333,10 +333,9 @@ def compute_dcir_access_infos( global_grid_subset=access_info.global_grid_subset, ) ) - else: - res = ctx.access_infos + return res - return res + return ctx.access_infos def make_dace_subset( diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index 27f19d2670..c8e8658413 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -32,7 +32,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 28d1e2acf6..59b4ef0881 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -45,7 +45,7 @@ class NodeVisitor: 3. ``self.generic_visit()``. This dispatching mechanism is implemented in the main :meth:`visit` - method and can be overriden in subclasses. 
Therefore, a simple way to extend + method and can be overridden in subclasses. Therefore, a simple way to extend the behavior of a visitor is by inheriting from lightweight `trait` classes with a custom ``visit()`` method, which wraps the call to the superclass' ``visit()`` and adds extra pre and post visit logic. Check :mod:`eve.traits` @@ -82,7 +82,7 @@ def apply(cls, tree, init_var, foo, bar=5, **kwargs): Notes: If you want to apply changes to nodes during the traversal, - use the :class:`NodeMutator` subclass, which handles correctly + use the :class:`NodeTranslator` subclass, which handles correctly structural modifications of the visited tree. """ From ea8d9dbfefa16ac71dee5afa48367d7018861721 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 15:57:00 +0100 Subject: [PATCH 07/43] ci: Bump gitlab ci on todi to ubuntu 22.04, cuda 12.6.2, cupy 13.3.0 (#1727) We were using ubuntu 22.04 which shipped with gcc 9.x.x. In order to get something more recent with proper utf-8 support I bumped to 22.04 on todi. On daint strange hangs occured so I kept everything as is there. 
--- ci/base.Dockerfile | 10 ++++++---- ci/cscs-ci.yml | 8 ++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index d20d9ca6ef..ea7c4722c7 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -1,5 +1,6 @@ -ARG CUDA_VERSION=12.5.0 -FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +ARG CUDA_VERSION=12.6.2 +ARG UBUNTU_VERSION=22.04 +FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 @@ -22,7 +23,7 @@ RUN apt-get update -qq && apt-get install -qq -y --no-install-recommends \ tk-dev \ libffi-dev \ liblzma-dev \ - python-openssl \ + $( [ "${UBUNTU_VERSION}" = "20.04" ] && echo "python-openssl" || echo "python3-openssl" ) \ libreadline-dev \ git \ rustc \ @@ -55,4 +56,5 @@ RUN pyenv update && \ ENV PATH="/root/.pyenv/shims:${PATH}" ARG CUPY_PACKAGE=cupy-cuda12x -RUN pip install --upgrade pip setuptools wheel tox ${CUPY_PACKAGE}==12.3.0 +ARG CUPY_VERSION=13.3.0 +RUN pip install --upgrade pip setuptools wheel tox ${CUPY_PACKAGE}==${CUPY_VERSION} diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 7fcd65106d..e2833e3cd9 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -42,17 +42,21 @@ stages: DOCKERFILE: ci/base.Dockerfile # change to 'always' if you want to rebuild, even if target tag exists already (if-not-exists is the default, i.e. 
we could also skip the variable) CSCS_REBUILD_POLICY: if-not-exists - DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "PYVERSION=$PYVERSION", "CI_PROJECT_DIR=$CI_PROJECT_DIR"]' + DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "CUPY_VERSION=$CUPY_VERSION", "UBUNTU_VERSION=$UBUNTU_VERSION", "PYVERSION=$PYVERSION", "CI_PROJECT_DIR=$CI_PROJECT_DIR"]' .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: CUDA_VERSION: 11.2.2 CUPY_PACKAGE: cupy-cuda11x + CUPY_VERSION: 12.3.0 # latest version that supports cuda 11 + UBUNTU_VERSION: 20.04 # 22.04 hangs on daint in some tests for unknown reasons. .build_baseimage_aarch64: extends: [.container-builder-cscs-gh200, .build_baseimage] variables: - CUDA_VERSION: 12.4.1 + CUDA_VERSION: 12.6.2 CUPY_PACKAGE: cupy-cuda12x + CUPY_VERSION: 13.3.0 + UBUNTU_VERSION: 22.04 # TODO: enable CI job when Todi is back in operational state when: manual From a00154ada421a05a700343fc648693c0ce78efc8 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 18 Nov 2024 08:51:48 +0100 Subject: [PATCH 08/43] feat[next][dace]: Use offset_type to represent neighborhood information for local dimensions (#1734) This PR adopts the `offset_type` design concept implemented in #1703 for Embedded-GTIR and applies it to the DaCe-GTIR backend. The only functional change is that the if-builtin is now expected to return the exact same data type, including the same `offset_type` if a local dimension is present in the result field. This change required updates to `test_gtir_reduce_with_cond_neighbors`. 
--- .../gtir_builtin_translators.py | 155 +++++---- .../runners/dace_fieldview/gtir_dataflow.py | 315 ++++++++++-------- .../runners/dace_fieldview/gtir_sdfg.py | 18 +- .../runners/dace_fieldview/utility.py | 9 +- .../dace_tests/test_gtir_to_sdfg.py | 213 +++++------- 5 files changed, 352 insertions(+), 358 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index bb37440fe2..69aedf44d6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -44,15 +44,50 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. - gt_dtype: GT4Py type definition, which includes the field domain information. - local_offset: Provides information about the local dimension in`FieldType` data. - Set to 'None' for scalar data. Can be 'None' for `FieldType` data with - only global (horizontal or vertical) dimensions. + gt_type: GT4Py type definition, which includes the field domain information. 
""" dc_node: dace.nodes.AccessNode - gt_dtype: ts.FieldType | ts.ScalarType - local_offset: Optional[str] + gt_type: ts.FieldType | ts.ScalarType + + def get_local_view( + self, domain: FieldopDomain + ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: + """Helper method to access a field in local view, given a field operator domain.""" + if isinstance(self.gt_type, ts.ScalarType): + return gtir_dataflow.MemletExpr( + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + ) + + if isinstance(self.gt_type, ts.FieldType): + indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) + for dim, _, _ in domain + } + local_dims = [ + dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL + ] + + if len(local_dims) == 0: + return gtir_dataflow.IteratorExpr( + self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices + ) + + elif len(local_dims) == 1: + field_dtype = itir_ts.ListType( + element_type=self.gt_type.dtype, offset_type=local_dims[0] + ) + field_dims = [ + dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL + ] + return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices) + + else: + raise ValueError( + f"Unexpected data field {self.dc_node.data} with more than one local dimension." 
+ ) + + raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") FieldopDomain: TypeAlias = list[ @@ -111,31 +146,13 @@ def _parse_fieldop_arg( ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: """Helper method to visit an expression passed as argument to a field operator.""" - arg = sdfg_builder.visit( - node, - sdfg=sdfg, - head_state=state, - ) + arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) # arguments passed to field operator should be plain fields, not tuples of fields if not isinstance(arg, FieldopData): raise ValueError(f"Received {node} as argument to field operator, expected a field.") - if isinstance(arg.gt_dtype, ts.ScalarType): - return gtir_dataflow.MemletExpr(arg.dc_node, sbs.Indices([0])) - elif isinstance(arg.gt_dtype, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { - dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) - for dim, _, _ in domain - } - dims = arg.gt_dtype.dims + ( - # we add an extra anonymous dimension in the iterator definition to enable - # dereferencing elements in `ListType` - [gtx_common.Dimension("")] if isinstance(arg.gt_dtype.dtype, itir_ts.ListType) else [] - ) - return gtir_dataflow.IteratorExpr(arg.dc_node, dims, indices, arg.local_offset) - else: - raise NotImplementedError(f"Node type {type(arg.gt_dtype)} not supported.") + return arg.get_local_view(domain) def _get_field_shape( @@ -178,20 +195,27 @@ def _create_temporary_field( if isinstance(output_desc, dace.data.Array): assert isinstance(node_type.dtype, itir_ts.ListType) assert isinstance(node_type.dtype.element_type, ts.ScalarType) - field_dtype = node_type.dtype.element_type + assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by the field operator (e.g. 
`neighbors`) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): - field_dtype = node_type.dtype + assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) else: raise ValueError(f"Cannot create field for dace type {output_desc}.") # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, dace_utils.as_dace_type(field_dtype)) + temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) field_node = state.add_access(temp_name) - field_type = ts.FieldType(field_dims, node_type.dtype) - return FieldopData(field_node, field_type, local_offset=dataflow_output.result.local_offset) + if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): + field_dtype = dataflow_output.result.gt_dtype + else: + assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = dataflow_output.result.gt_dtype.element_type + assert dataflow_output.result.gt_dtype.offset_type is not None + field_dims.append(dataflow_output.result.gt_dtype.offset_type) + + return FieldopData(field_node, ts.FieldType(field_dims, field_dtype)) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -273,7 +297,6 @@ def translate_as_fieldop( if isinstance(node.type.dtype, itir_ts.ListType): assert isinstance(output_desc, dace.data.Array) - assert set(output_desc.offset) == {0} # additional local dimension for neighbors # TODO(phimuell): Investigate if we should swap the two. 
output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) @@ -383,7 +406,7 @@ def translate_broadcast_scalar( external_edges=True, ) - return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype), local_offset=None) + return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype)) def translate_if( @@ -439,14 +462,14 @@ def translate_if( head_state=false_state, ) - def make_temps(output_data: FieldopData) -> FieldopData: - desc = output_data.dc_node.desc(sdfg) - data_name, _ = sdfg.add_temp_transient_like(desc) - data_node = state.add_access(data_name) + def construct_output(inner_data: FieldopData) -> FieldopData: + inner_desc = inner_data.dc_node.desc(sdfg) + outer, _ = sdfg.add_temp_transient_like(inner_desc) + outer_node = state.add_access(outer) - return FieldopData(data_node, output_data.gt_dtype, output_data.local_offset) + return FieldopData(outer_node, inner_data.gt_type) - result_temps = gtx_utils.tree_map(make_temps)(true_br_args) + result_temps = gtx_utils.tree_map(construct_output)(true_br_args) fields: Iterable[tuple[FieldopData, FieldopData, FieldopData]] = zip( gtx_utils.flatten_nested_tuple((true_br_args,)), @@ -456,7 +479,10 @@ def make_temps(output_data: FieldopData) -> FieldopData: ) for true_br, false_br, temp in fields: - assert true_br.gt_dtype == false_br.gt_dtype + if true_br.gt_type != false_br.gt_type: + raise ValueError( + f"Different type of result fields on if-branches '{true_br.gt_type}' vs '{false_br.gt_type}'." 
+ ) true_br_node = true_br.dc_node false_br_node = false_br.dc_node @@ -482,40 +508,31 @@ def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - sym_name: str, - sym_type: ts.DataType, + data_name: str, + data_type: ts.DataType, ) -> FieldopResult: - if isinstance(sym_type, ts.FieldType): - sym_node = state.add_access(sym_name) - local_dims = [dim for dim in sym_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] - if len(local_dims) > 1: - raise ValueError(f"Field {sym_name} has more than one local dimension.") - elif len(local_dims) == 1: - # we ensure that the name of the local dimension corresponds to a valid - # connectivity-based offset provider - local_offset = next(iter(local_dims)).value - assert isinstance( - sdfg_builder.get_offset_provider(local_offset), gtx_common.Connectivity - ) - else: - local_offset = None - return FieldopData(sym_node, sym_type, local_offset) - elif isinstance(sym_type, ts.ScalarType): - if sym_name in sdfg.symbols: - sym_node = _get_symbolic_value( - sdfg, state, sdfg_builder, sym_name, sym_type, temp_name=f"__{sym_name}" + if isinstance(data_type, ts.FieldType): + data_node = state.add_access(data_name) + return FieldopData(data_node, data_type) + + elif isinstance(data_type, ts.ScalarType): + if data_name in sdfg.symbols: + data_node = _get_symbolic_value( + sdfg, state, sdfg_builder, data_name, data_type, temp_name=f"__{data_name}" ) else: - sym_node = state.add_access(sym_name) - return FieldopData(sym_node, sym_type, local_offset=None) - elif isinstance(sym_type, ts.TupleType): - tuple_fields = dace_gtir_utils.get_tuple_fields(sym_name, sym_type) + data_node = state.add_access(data_name) + return FieldopData(data_node, data_type) + + elif isinstance(data_type, ts.TupleType): + tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) return tuple( _get_data_nodes(sdfg, state, sdfg_builder, fname, ftype) for fname, ftype in tuple_fields ) + else: - raise 
NotImplementedError(f"Symbol type {type(sym_type)} not supported.") + raise NotImplementedError(f"Symbol type {type(data_type)} not supported.") def _get_symbolic_value( @@ -562,7 +579,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type, local_offset=None) + return FieldopData(data_node, data_type) def translate_make_tuple( @@ -646,7 +663,7 @@ def translate_scalar_expr( sdfg=sdfg, head_state=state, ) - if not (isinstance(arg, FieldopData) and isinstance(arg.gt_dtype, ts.ScalarType)): + if not (isinstance(arg, FieldopData) and isinstance(arg.gt_type, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") param = f"__arg{i}" args.append(arg.dc_node) @@ -691,7 +708,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type, local_offset=None) + return FieldopData(temp_node, node.type) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cf91d15aba..73b6e2ed4c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -29,6 +29,12 @@ from gt4py.next.type_system import type_info as ti, type_specifications as ts +# Magic local dimension for the result of a `make_const_list`. +# A clean implementation will probably involve to tag the `make_const_list` +# with the neighborhood it is meant to be used with. +_CONST_DIM = gtx_common.Dimension(value="_CONST_DIM", kind=gtx_common.DimensionKind.LOCAL) + + @dataclasses.dataclass(frozen=True) class ValueExpr: """ @@ -41,15 +47,12 @@ class ValueExpr: the result of a field operator, basically the data storage outside a global map. 
Args: - dc_node: Access node to the data storage, can be either a scalar or a local list. - gt_dtype: GT4Py type definition, which includes the field domain information. - local_offset: Provides information about the local dimension in`FieldType` data. - For a more detailed explanation see `gtir_builtin_translators.FieldopData`. + dc_node: Access node to the data container, can be either a scalar or a local list. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. """ dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - local_offset: Optional[str] = None @dataclasses.dataclass(frozen=True) @@ -58,15 +61,14 @@ class MemletExpr: Scalar or array data access through a memlet. Args: - dc_node: Access node to the data storage, can be either a scalar or a local list. + dc_node: Access node to the data container, can be either a scalar or a local list. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. subset: Represents the subset to use in memlet to access the above data. - local_offset: Provides information about the local dimension in`FieldType` data. - For a more detailed explanation see `gtir_builtin_translators.FieldopData`. """ dc_node: dace.nodes.AccessNode + gt_dtype: itir_ts.ListType | ts.ScalarType subset: sbs.Indices | sbs.Range - local_offset: Optional[str] = None @dataclasses.dataclass(frozen=True) @@ -87,19 +89,17 @@ class IteratorExpr: Args: field: Access node to the field this iterator operates on. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. dimensions: Field domain represented as a sorted list of dimensions, needed to order the map index variables and dereference an element in the field. indices: Maps each dimension to an index value, which could be either a symbolic value or the result of a tasklet computation like neighbors connectivity or dynamic offset. 
- local_offset: Provides information about the local dimension in`FieldType` data. - For a more detailed explanation see `gtir_builtin_translators.FieldopData`. - """ field: dace.nodes.AccessNode + gt_dtype: itir_ts.ListType | ts.ScalarType dimensions: list[gtx_common.Dimension] indices: dict[gtx_common.Dimension, DataExpr] - local_offset: Optional[str] = None class DataflowInputEdge(Protocol): @@ -383,18 +383,18 @@ def _construct_tasklet_result( dc_dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str, - local_offset: Optional[str] = None, use_array: bool = False, ) -> ValueExpr: - temp_name = self.sdfg.temp_data_name() + data_type = dace_utils.as_itir_type(dc_dtype) if use_array: # In some cases, such as result data with list-type annotation, we want # that output data is represented as an array (single-element 1D array) # in order to allow for composition of array shape in external memlets. - self.sdfg.add_array(temp_name, (1,), dc_dtype, transient=True) + temp_name, _ = self.sdfg.add_temp_transient((1,), dc_dtype) else: + temp_name = self.sdfg.temp_data_name() self.sdfg.add_scalar(temp_name, dc_dtype, transient=True) - data_type = dace_utils.as_itir_type(dc_dtype) + temp_node = self.state.add_access(temp_name) self._add_edge( src_node, @@ -403,7 +403,14 @@ def _construct_tasklet_result( None, dace.Memlet(data=temp_name, subset="0"), ) - return ValueExpr(temp_node, data_type, local_offset) + return ValueExpr( + dc_node=temp_node, + gt_dtype=( + itir_ts.ListType(element_type=data_type, offset_type=_CONST_DIM) + if use_array + else data_type + ), + ) def _visit_deref(self, node: gtir.FunCall) -> DataExpr: """ @@ -435,81 +442,87 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # deref a zero-dimensional field assert len(arg_expr.dimensions) == 0 assert isinstance(node.type, ts.ScalarType) - return MemletExpr(arg_expr.field, subset="0") + return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") + # default case: deref a field with 
one or more dimensions - assert len(field_desc.shape) == len(arg_expr.dimensions) if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): # when all indices are symblic expressions, we can perform direct field access through a memlet + if isinstance(arg_expr.gt_dtype, itir_ts.ListType): + assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 + assert arg_expr.gt_dtype.offset_type is not None + field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] + else: + assert len(field_desc.shape) == len(arg_expr.dimensions) + field_dims = arg_expr.dimensions + field_subset = sbs.Range( (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] if dim in arg_expr.indices else (0, size - 1, 1) - for dim, size in zip(arg_expr.dimensions, field_desc.shape) + for dim, size in zip(field_dims, field_desc.shape) ) - return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset) + return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) - else: - # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, - # either indirection through connectivity table or dynamic cartesian offset. - assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] - index_connectors = [ - IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - if not isinstance(index, SymbolExpr) - ] - # here `internals` refer to the names used as index in the tasklet code string: - # an index can be either a connector name (for dynamic/indirect indices) - # or a symbol value (for literal values and scalar arguments). 
- index_internals = ",".join( - str(index.value) - if isinstance(index, SymbolExpr) - else IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - ) - deref_node = self._add_tasklet( - "runtime_deref", - {"field"} | set(index_connectors), - {"val"}, - code=f"val = field[{index_internals}]", - ) - # add new termination point for the field parameter - self._add_input_data_edge( - arg_expr.field, - sbs.Range.from_array(field_desc), - deref_node, - "field", - ) + # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, + # either indirection through connectivity table or dynamic cartesian offset. + assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) + assert len(field_desc.shape) == len(arg_expr.dimensions) + field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + index_connectors = [ + IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + if not isinstance(index, SymbolExpr) + ] + # here `internals` refer to the names used as index in the tasklet code string: + # an index can be either a connector name (for dynamic/indirect indices) + # or a symbol value (for literal values and scalar arguments). 
+ index_internals = ",".join( + str(index.value) + if isinstance(index, SymbolExpr) + else IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + ) + deref_node = self._add_tasklet( + "runtime_deref", + {"field"} | set(index_connectors), + {"val"}, + code=f"val = field[{index_internals}]", + ) + # add new termination point for the field parameter + self._add_input_data_edge( + arg_expr.field, + sbs.Range.from_array(field_desc), + deref_node, + "field", + ) - for dim, index_expr in field_indices: - # add termination points for the dynamic iterator indices - deref_connector = IndexConnectorFmt.format(dim=dim.value) - if isinstance(index_expr, MemletExpr): - self._add_input_data_edge( - index_expr.dc_node, - index_expr.subset, - deref_node, - deref_connector, - ) + for dim, index_expr in field_indices: + # add termination points for the dynamic iterator indices + deref_connector = IndexConnectorFmt.format(dim=dim.value) + if isinstance(index_expr, MemletExpr): + self._add_input_data_edge( + index_expr.dc_node, + index_expr.subset, + deref_node, + deref_connector, + ) - elif isinstance(index_expr, ValueExpr): - self._add_edge( - index_expr.dc_node, - None, - deref_node, - deref_connector, - dace.Memlet(data=index_expr.dc_node.data, subset="0"), - ) - else: - assert isinstance(index_expr, SymbolExpr) + elif isinstance(index_expr, ValueExpr): + self._add_edge( + index_expr.dc_node, + None, + deref_node, + deref_connector, + dace.Memlet(data=index_expr.dc_node.data, subset="0"), + ) + else: + assert isinstance(index_expr, SymbolExpr) - return self._construct_tasklet_result( - field_desc.dtype, deref_node, "val", arg_expr.local_offset - ) + return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: - assert len(node.args) == 2 assert isinstance(node.type, itir_ts.ListType) + assert len(node.args) == 2 assert isinstance(node.args[0], gtir.OffsetLiteral) offset = 
node.args[0].value @@ -543,8 +556,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: # to the view nodes. The simplify pass will remove the redundant access nodes. field_slice = self._construct_local_view( MemletExpr( - it.field, - sbs.Range.from_string( + dc_node=it.field, + gt_dtype=node.type, + subset=sbs.Range.from_string( ",".join( it.indices[dim].value # type: ignore[union-attr] if dim != offset_provider.neighbor_axis @@ -556,8 +570,11 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) connectivity_slice = self._construct_local_view( MemletExpr( - self.state.add_access(connectivity), - sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), + dc_node=self.state.add_access(connectivity), + gt_dtype=node.type, + subset=sbs.Range.from_string( + f"{origin_index.value}, 0:{offset_provider.max_neighbors}" + ), ) ) @@ -565,8 +582,8 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: (offset_provider.max_neighbors,), field_desc.dtype ) neighbors_node = self.state.add_access(neighbors_temp) - - neighbor_idx = dace_gtir_utils.get_map_variable(offset) + offset_type = gtx_common.Dimension(offset, gtx_common.DimensionKind.LOCAL) + neighbor_idx = dace_gtir_utils.get_map_variable(offset_type) index_connector = "__index" output_connector = "__val" @@ -604,7 +621,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: external_edges=True, ) - return ValueExpr(neighbors_node, node.type, offset) + return ValueExpr( + dc_node=neighbors_node, gt_dtype=itir_ts.ListType(node.type.element_type, offset_type) + ) def _visit_map(self, node: gtir.FunCall) -> ValueExpr: """ @@ -629,8 +648,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type.element_type, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(node.type.element_type) - input_args = [self.visit(arg) for arg in node.args] - input_connectors = [f"__arg{i}" for i in range(len(input_args))] + input_connectors = 
[f"__arg{i}" for i in range(len(node.args))] output_connector = "__out" # Here we build the body of the tasklet @@ -638,27 +656,37 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: fun_python_code = gtir_python_codegen.get_source(fun_node) tasklet_expression = f"{output_connector} = {fun_python_code}" - input_local_offsets = [ - input_arg.local_offset for input_arg in input_args if input_arg.local_offset is not None - ] - if len(input_local_offsets) == 0: + input_args = [self.visit(arg) for arg in node.args] + input_connectivities: dict[gtx_common.Dimension, gtx_common.Connectivity] = {} + for input_arg in input_args: + assert isinstance(input_arg.gt_dtype, itir_ts.ListType) + assert input_arg.gt_dtype.offset_type is not None + offset_type = input_arg.gt_dtype.offset_type + if offset_type == _CONST_DIM: + # this input argument is the result of `make_const_list` + continue + offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) + assert isinstance(offset_provider, gtx_common.Connectivity) + input_connectivities[offset_type] = offset_provider + + if len(input_connectivities) == 0: raise ValueError(f"Missing information on local dimension for map node {node}.") # GT4Py guarantees that all connectivities used to generate lists of neighbors # have the same length, that is the same value of 'max_neighbors'. - local_connectivities = dace_utils.filter_connectivities( - { - offset: self.subgraph_builder.get_offset_provider(offset) - for offset in input_local_offsets - } - ) - if len(set(table.max_neighbors for table in local_connectivities.values())) != 1: - raise ValueError( - "Unexpected arguments to map expression with different local dimensions." 
+ if ( + len( + set( + (conn.has_skip_values, conn.max_neighbors) + for conn in input_connectivities.values() + ) ) - local_offset, offset_provider = next(iter(local_connectivities.items())) + != 1 + ): + raise ValueError("Unexpected arguments to map expression with different neighborhood.") + offset_type, offset_provider = next(iter(input_connectivities.items())) local_size = offset_provider.max_neighbors - map_index = dace_gtir_utils.get_map_variable(local_offset) + map_index = dace_gtir_utils.get_map_variable(offset_type) # The dataflow we build in this class has some loose connections on input edges. # These edges are described as set of nodes, that will have to be connected to @@ -668,47 +696,31 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: # than representing map-to-map edges (which require memlets with 2 pass-nodes). input_memlets = {} input_nodes = {} - skip_value_connectivities: dict[str, gtx_common.Connectivity] = {} - for conn, input_expr in zip(input_connectors, input_args): - input_node = self._construct_local_view(input_expr).dc_node + for conn, input_arg in zip(input_connectors, input_args): + input_node = self._construct_local_view(input_arg).dc_node input_desc = input_node.desc(self.sdfg) # we assume that there is a single local dimension if len(input_desc.shape) != 1: raise ValueError(f"More than one local dimension in map expression {node}.") input_size = input_desc.shape[0] if input_size == 1: + assert input_arg.gt_dtype.offset_type == _CONST_DIM input_memlets[conn] = dace.Memlet(data=input_node.data, subset="0") - elif input_size != local_size: + elif input_size == local_size: + input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index) + else: raise ValueError( f"Argument to map node with local size {input_size}, expected {local_size}." 
) - else: - assert input_expr.local_offset - input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index) - input_nodes[input_node.data] = input_node result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) result_node = self.state.add_access(result) - skip_value_connectivities = { - offset: offset_provider - for offset, offset_provider in local_connectivities.items() - if offset_provider.has_skip_values - } - - if len(skip_value_connectivities) == 0: - result_offset = local_offset - else: - # In case one or more of input expressions contain skip values, we use + if offset_provider.has_skip_values: + # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. - # Therefore, the result of map computation will also contain skip values. - # GT4Py guarantees that the skip values are placed in the same positions - # for all input expressions. - - result_offset, offset_provider = next(iter(skip_value_connectivities.items())) - - connectivity = dace_utils.connectivity_identifier(result_offset) + connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_desc = self.sdfg.arrays[connectivity] connectivity_desc.transient = False @@ -716,8 +728,13 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: connectivity_slice = self._construct_local_view( MemletExpr( - self.state.add_access(connectivity), - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + dc_node=self.state.add_access(connectivity), + gt_dtype=itir_ts.ListType( + element_type=node.type.element_type, offset_type=offset_type + ), + subset=sbs.Range.from_string( + f"{origin_map_index}, 0:{offset_provider.max_neighbors}" + ), ) ) @@ -749,7 +766,10 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: external_edges=True, ) - return ValueExpr(result_node, dc_dtype, result_offset) + return ValueExpr( + dc_node=result_node, + 
gt_dtype=itir_ts.ListType(node.type.element_type, offset_type), + ) def _make_reduce_with_skip_values( self, @@ -774,8 +794,12 @@ def _make_reduce_with_skip_values( """ origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) - assert input_expr.local_offset is not None - connectivity = dace_utils.connectivity_identifier(input_expr.local_offset) + assert ( + isinstance(input_expr.gt_dtype, itir_ts.ListType) + and input_expr.gt_dtype.offset_type is not None + ) + offset_type = input_expr.gt_dtype.offset_type + connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_node = self.state.add_access(connectivity) connectivity_desc = connectivity_node.desc(self.sdfg) connectivity_desc.transient = False @@ -881,8 +905,12 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: input_expr = self.visit(node.args[0]) assert isinstance(input_expr, (MemletExpr, ValueExpr)) - assert input_expr.local_offset is not None - offset_provider = self.subgraph_builder.get_offset_provider(input_expr.local_offset) + assert ( + isinstance(input_expr.gt_dtype, itir_ts.ListType) + and input_expr.gt_dtype.offset_type is not None + ) + offset_type = input_expr.gt_dtype.offset_type + offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) assert isinstance(offset_provider, gtx_common.Connectivity) if offset_provider.has_skip_values: @@ -998,9 +1026,13 @@ def _make_cartesian_shift( # a new iterator with a shifted index along one dimension return IteratorExpr( - it.field, - it.dimensions, - {dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items()}, + field=it.field, + gt_dtype=it.gt_dtype, + dimensions=it.dimensions, + indices={ + dim: (new_index if dim == offset_dim else index) + for dim, index in it.indices.items() + }, ) def _make_dynamic_neighbor_offset( @@ -1068,8 +1100,9 @@ def _make_unstructured_shift( if isinstance(offset_expr, SymbolExpr): # use memlet to retrieve the neighbor index 
shifted_indices[neighbor_dim] = MemletExpr( - offset_table_node, - sbs.Indices([origin_index.value, offset_expr.value]), + dc_node=offset_table_node, + gt_dtype=it.gt_dtype, + subset=sbs.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1077,7 +1110,9 @@ def _make_unstructured_shift( offset_expr, offset_table_node, origin_index ) - return IteratorExpr(it.field, it.dimensions, shifted_indices) + return IteratorExpr( + field=it.field, gt_dtype=it.gt_dtype, dimensions=it.dimensions, indices=shifted_indices + ) def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index da940e883c..ad8f490f12 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -344,9 +344,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return gtir_builtin_translators.FieldopData( - temp_node, field.gt_dtype, field.local_offset - ) + return gtir_builtin_translators.FieldopData(temp_node, field.gt_type) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -489,9 +487,9 @@ def visit_SetAt( target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient - if isinstance(target.gt_dtype, ts.FieldType): + if isinstance(target.gt_type, ts.FieldType): subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_dtype.dims + f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims ) else: assert len(domain) == 0 @@ -582,7 +580,7 @@ def visit_Lambda( sym: self.global_symbols[sym] for sym in 
symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_dtype + pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type for pname, arg in lambda_args_mapping } @@ -742,9 +740,7 @@ def construct_output_for_nested_sdfg( head_state.add_edge( nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - outer_data = gtir_builtin_translators.FieldopData( - outer_node, inner_data.gt_dtype, inner_data.local_offset - ) + outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned @@ -753,9 +749,7 @@ def construct_output_for_nested_sdfg( outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = gtir_builtin_translators.FieldopData( - outer_node, inner_data.gt_dtype, inner_data.local_offset - ) + outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. 
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index baae8a6ccd..caec6cd87e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -19,15 +19,10 @@ from gt4py.next.type_system import type_specifications as ts -def get_map_variable(dim: gtx_common.Dimension | str) -> str: +def get_map_variable(dim: gtx_common.Dimension) -> str: """ Format map variable name based on the naming convention for application-specific SDFG transformations. """ - if not isinstance(dim, gtx_common.Dimension): - if len(dim) != 0: - dim = gtx_common.Dimension(dim, gtx_common.DimensionKind.LOCAL) - else: - raise ValueError("Dimension name cannot be empty.") suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" return f"i_{dim.value}_gtx_{dim.kind}{suffix}" @@ -68,7 +63,7 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. 
""" return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_dtype for d in data] + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index cc72adae4f..a94157ecd2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1352,7 +1352,9 @@ def test_gtir_reduce_with_skip_values(): e = np.random.rand(SKIP_VALUE_MESH.num_edges) v_ref = [ functools.reduce( - lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + lambda x, y: x + y, + [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], + init_value, ) for v2e_neighbors in connectivity_V2E.table ] @@ -1394,120 +1396,74 @@ def test_gtir_reduce_dot_product(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - # create mesh with skip values - connectivity_V2E_skip = copy.deepcopy(connectivity_V2E) - connectivity_V2E_skip.has_skip_values = True - connectivity_V2E_skip.table = np.asarray( - [ - [x if i != skip_idx else gtx_common._DEFAULT_SKIP_VALUE for i, x in enumerate(row)] - for skip_idx, row in zip( - np.random.randint(0, connectivity_V2E.max_neighbors, size=SIMPLE_MESH.num_vertices), - connectivity_V2E.table, - strict=True, - ) - ], - dtype=connectivity_V2E.table.dtype, - ) - # safety check that the connectivity table actually contains skip values - assert len(np.where(connectivity_V2E.table == 
gtx_common._DEFAULT_SKIP_VALUE)) != 0 - - offset_provider = SIMPLE_MESH_OFFSET_PROVIDER | { - "V2E_skip": connectivity_V2E_skip, - } - - V2E_SKIP_SYMBOLS = dict( - __connectivity_V2E_skip_size_0=SIMPLE_MESH.num_vertices, - __connectivity_V2E_skip_size_1=connectivity_V2E_skip.max_neighbors, - __connectivity_V2E_skip_stride_0=connectivity_V2E_skip.max_neighbors, - __connectivity_V2E_skip_stride_1=1, - ) - - e = np.random.rand(SIMPLE_MESH.num_edges) - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + e = np.random.rand(SKIP_VALUE_MESH.num_edges) + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ functools.reduce( lambda x, y: x + y, map( lambda x: 0.0 if x[1] == gtx_common._DEFAULT_SKIP_VALUE else x[0], - zip((e[v2e_neighbors] * e[v2e_skip_neighbors]) + 1.0, v2e_skip_neighbors), + zip((e[v2e_neighbors] * v2e_values) + 1.0, v2e_neighbors), ), init_value, ) - for v2e_neighbors, v2e_skip_neighbors in zip( - connectivity_V2E.table, connectivity_V2E_skip.table - ) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field) ] - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.map_("plus")( - im.map_("multiplies")( - im.neighbors("V2E", "it"), - im.neighbors("V2E_skip", "it"), + testee = gtir.Program( + id=f"reduce_dot_product", + function_definitions=[], + params=[ + gtir.Sym(id="v2e_field", type=V2E_FTYPE), + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) ), - im.call("make_const_list")(1.0), + vertex_domain, ) - ) - ), - vertex_domain, - ) - )("edges") - - stencil_fieldview = 
im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) - )( - im.op_as_fieldop(im.map_("plus"), vertex_domain)( - im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - im.as_fieldop_neighbors("V2E_skip", "edges", vertex_domain), - ), - im.op_as_fieldop("make_const_list", vertex_domain)(1.0), - ) + )( + im.op_as_fieldop(im.map_("plus"), vertex_domain)( + im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + "v2e_field", + ), + im.op_as_fieldop("make_const_list", vertex_domain)(1.0), + ) + ), + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], ) - for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): - testee = gtir.Program( - id=f"reduce_dot_product_{i}", - function_definitions=[], - params=[ - gtir.Sym(id="edges", type=EFTYPE), - gtir.Sym(id="vertices", type=VFTYPE), - gtir.Sym(id="nvertices", type=SIZE_TYPE), - ], - declarations=[], - body=[ - gtir.SetAt( - expr=stencil, - domain=vertex_domain, - target=gtir.SymRef(id="vertices"), - ) - ], - ) - - sdfg = dace_backend.build_sdfg_from_gtir(testee, offset_provider) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) - sdfg( - e, - v, - connectivity_V2E=connectivity_V2E.table, - connectivity_V2E_skip=connectivity_V2E_skip.table, - **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), - **V2E_SKIP_SYMBOLS, - ) - assert np.allclose(v, v_ref) + sdfg( + v2e_field, + e, + v, + connectivity_V2E=connectivity_V2E.table, + **make_mesh_symbols(SKIP_VALUE_MESH), + __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, + ) + assert np.allclose(v, v_ref) def test_gtir_reduce_with_cond_neighbors(): @@ -1518,6 
+1474,7 @@ def test_gtir_reduce_with_cond_neighbors(): function_definitions=[], params=[ gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="v2e_field", type=V2E_FTYPE), gtir.Sym(id="edges", type=EFTYPE), gtir.Sym(id="vertices", type=VFTYPE), gtir.Sym(id="nvertices", type=SIZE_TYPE), @@ -1535,7 +1492,7 @@ def test_gtir_reduce_with_cond_neighbors(): )( im.if_( "pred", - im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), + "v2e_field", im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) ), @@ -1545,49 +1502,45 @@ def test_gtir_reduce_with_cond_neighbors(): ], ) - connectivity_V2E_simple = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] - assert isinstance(connectivity_V2E_simple, gtx_common.NeighborTable) - connectivity_V2E_skip_values = copy.deepcopy(SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"]) - assert isinstance(connectivity_V2E_skip_values, gtx_common.NeighborTable) - assert SKIP_VALUE_MESH.num_vertices <= SIMPLE_MESH.num_vertices - connectivity_V2E_skip_values.table = np.concatenate( - ( - connectivity_V2E_skip_values.table[:, 0 : connectivity_V2E_simple.max_neighbors], - connectivity_V2E_simple.table[SKIP_VALUE_MESH.num_vertices :, :], - ), - axis=0, - ) - connectivity_V2E_skip_values.max_neighbors = connectivity_V2E_simple.max_neighbors + connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - e = np.random.rand(SIMPLE_MESH.num_edges) + v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + e = np.random.rand(SKIP_VALUE_MESH.num_edges) - for use_full in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir( - testee, - SIMPLE_MESH_OFFSET_PROVIDER | {"V2E_FULL": connectivity_V2E_skip_values}, - ) + for use_sparse in [False, True]: + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref 
= [ functools.reduce( - lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + lambda x, y: x + y, + [ + v if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 + for i, v in zip(v2e_neighbors, v2e_values, strict=True) + ], + init_value, ) - for v2e_neighbors in ( - connectivity_V2E_simple.table if use_full else connectivity_V2E_skip_values.table + if use_sparse + else functools.reduce( + lambda x, y: x + y, + [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], + init_value, ) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) ] sdfg( - np.bool_(use_full), + np.bool_(use_sparse), + v2e_field, e, v, - connectivity_V2E=connectivity_V2E_skip_values.table, - connectivity_V2E_FULL=connectivity_V2E_simple.table, + connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), - __connectivity_V2E_FULL_size_0=SIMPLE_MESH.num_edges, - __connectivity_V2E_FULL_size_1=connectivity_V2E_skip_values.max_neighbors, - __connectivity_V2E_FULL_stride_0=connectivity_V2E_skip_values.max_neighbors, - __connectivity_V2E_FULL_stride_1=1, + **make_mesh_symbols(SKIP_VALUE_MESH), + __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) From a9a99928c1b1ba5c05234e45e36ccf0ac7c79214 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 18 Nov 2024 15:21:01 +0100 Subject: [PATCH 09/43] feat[next]: Upgrade dace dependency to v1.0.0 (#1740) DaCe version upgraded to `1.0.0`. It is also constrained to `< 1.1.0 `because the plan for DaCe v1.x is to introduce some breaking changes. GPU tests still fail with GTIR DaCe backend (`test_double_use_scalar`) so they will be enabled in a separate PR. 
Additional changes: - Removed limitation on SymPy version since DaCe is now compatible with SymPy v1.13 - CUDA version upgraded from 11.2 to 11.4 to avoid this compile error in gpu build: `dace/codegen/../runtime/include/dace/math.h(499): error: A __device__ variable cannot be marked constexpr` --- .pre-commit-config.yaml | 2 +- ci/cscs-ci.yml | 2 +- constraints.txt | 13 ++++++------- min-extra-requirements-test.txt | 3 +-- pyproject.toml | 2 +- requirements-dev.txt | 13 ++++++------- .../dace_fieldview/transformations/loop_blocking.py | 4 ---- 7 files changed, 16 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07f75177ea..1c3b6e693f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.3 + rev: v0.7.4 ##[[[end]]] hooks: # Run the linter. diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index e2833e3cd9..7adb88459e 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -46,7 +46,7 @@ stages: .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: - CUDA_VERSION: 11.2.2 + CUDA_VERSION: 11.4.3 CUPY_PACKAGE: cupy-cuda11x CUPY_VERSION: 12.3.0 # latest version that supports cuda 11 UBUNTU_VERSION: 20.04 # 22.04 hangs on daint in some tests for unknown reasons. 
diff --git a/constraints.txt b/constraints.txt index 4247f4951d..b4b8bc00d4 100644 --- a/constraints.txt +++ b/constraints.txt @@ -33,7 +33,7 @@ contourpy==1.1.1 # via matplotlib coverage==7.6.1 # via -r requirements-dev.in, pytest-cov cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) -dace==0.16.1 # via gt4py (pyproject.toml) +dace==1.0.0 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in debugpy==1.8.8 # via ipykernel decorator==5.1.1 # via ipython @@ -50,7 +50,7 @@ factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat filelock==3.16.1 # via tox, virtualenv -fonttools==4.54.1 # via matplotlib +fonttools==4.55.0 # via matplotlib fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython @@ -67,7 +67,7 @@ iniconfig==2.0.0 # via pytest ipykernel==6.29.5 # via nbmake ipython==8.12.3 # via ipykernel jedi==0.19.2 # via ipython -jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx +jinja2==3.1.4 # via gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat jsonschema-specifications==2023.12.1 # via jsonschema jupyter-client==8.6.3 # via ipykernel, nbclient @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff -packaging==24.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via jedi pathspec==0.12.1 # via black pexpect==4.9.0 # via ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, 
rich-click, tach rich-click==1.8.4 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.3 # via -r requirements-dev.in +ruff==0.7.4 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil @@ -157,7 +157,7 @@ sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach -sympy==1.12.1 # via dace, gt4py (pyproject.toml) +sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) tach==0.14.3 # via -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox @@ -173,7 +173,6 @@ urllib3==2.2.3 # via requests virtualenv==20.27.1 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit -websockets==13.1 # via dace wheel==0.45.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 4190570105..57c0d3969d 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -61,7 +61,7 @@ cmake==3.22 cogapp==3.3 coverage[toml]==5.0 cytoolz==0.12.1 -dace==0.16.1 +dace==1.0.0 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -101,7 +101,6 @@ scipy==1.9.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 -sympy==1.9 tabulate==0.8.10 tach==0.10.7 tomli==2.0.1; python_version < "3.11" diff --git a/pyproject.toml b/pyproject.toml index 1504c8b17b..02d301957c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ all-cuda12 = ['gt4py[cuda12,dace,formatting,jax-cuda12,performance,testing]'] # Other extras cuda11 = ['cupy-cuda11x>=12.0'] cuda12 = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.16.1', 
'sympy>=1.9,<1.13'] # see https://github.com/spcl/dace/pull/1620 +dace = ['dace>=1.0.0,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4 formatting = ['clang-format>=9.0'] gpu = ['cupy>=12.0'] jax-cpu = ['jax[cpu]>=0.4.18; python_version>="3.10"'] diff --git a/requirements-dev.txt b/requirements-dev.txt index ca7eb32487..9f95779fd5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,7 +33,7 @@ contourpy==1.1.1 # via -c constraints.txt, matplotlib coverage[toml]==7.6.1 # via -c constraints.txt, -r requirements-dev.in, pytest-cov cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) +dace==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in debugpy==1.8.8 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython @@ -50,7 +50,7 @@ factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pyte faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, tox, virtualenv -fonttools==4.54.1 # via -c constraints.txt, matplotlib +fonttools==4.55.0 # via -c constraints.txt, matplotlib fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython @@ -67,7 +67,7 @@ iniconfig==2.0.0 # via -c constraints.txt, pytest ipykernel==6.29.5 # via -c constraints.txt, nbmake ipython==8.12.3 # via -c constraints.txt, ipykernel jedi==0.19.2 # via -c constraints.txt, ipython -jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx +jinja2==3.1.4 # via -c constraints.txt, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat jsonschema-specifications==2023.12.1 # 
via -c constraints.txt, jsonschema jupyter-client==8.6.3 # via -c constraints.txt, ipykernel, nbclient @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff -packaging==24.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via -c constraints.txt, black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via -c constraints.txt, jedi pathspec==0.12.1 # via -c constraints.txt, black pexpect==4.9.0 # via -c constraints.txt, ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.4 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in +ruff==0.7.4 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb @@ -156,7 +156,7 @@ sphinxcontrib-qthelp==1.0.3 # via -c constraints.txt, sphinx sphinxcontrib-serializinghtml==1.1.5 # via -c constraints.txt, sphinx stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach -sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) +sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in 
tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox @@ -172,7 +172,6 @@ urllib3==2.2.3 # via -c constraints.txt, requests virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -websockets==13.1 # via -c constraints.txt, dace wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index 7acd997a0d..d7326e1131 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -63,16 +63,12 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): dtype=set, allow_none=True, default=None, - optional=True, - optional_condition=lambda _: False, desc="Set of nodes that are independent of the blocking parameter.", ) dependent_nodes = dace_properties.Property( dtype=set, allow_none=True, default=None, - optional=True, - optional_condition=lambda _: False, desc="Set of nodes that are dependent on the blocking parameter.", ) From 9dbc8842a2e6e16855da030934f7aecc23f8417b Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Nov 2024 07:34:23 +0100 Subject: [PATCH 10/43] feat[next]: Enable GPU tests on GTIR DaCe backend (#1741) DaCe v1.0.0 allows to enable GPU tests on the GTIR backend. An issue was found in `test_double_use_scalar`. 
The dace gpu transformations have a bug that produces invalid code for SDFGs containing scalar expressions outside the field operator. A workaround is to run the simplify pass in order to bring the SDFG to a canonical form. The changes in test code (`test_execution.py`) are pure cleanup. --- .../program_processors/runners/dace_common/utility.py | 2 +- .../runners/dace_fieldview/workflow.py | 7 ++++++- tests/next_tests/definitions.py | 6 +----- .../feature_tests/ffront_tests/test_execution.py | 11 +---------- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index d678fdab7f..bc01e2abda 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -19,7 +19,7 @@ # regex to match the symbols for field shape and strides -FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile("__.+_(size|stride)_\d+") +FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+") def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 85ae95c432..aa4fd0cd3e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -57,7 +57,12 @@ def generate_sdfg( if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) elif on_gpu: - gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=False) + # We run simplify to bring the SDFG into a canonical form that the gpu transformations + # can handle. 
This is a workaround for an issue with scalar expressions that are + # promoted to symbolic expressions and computed on the host (CPU), but the intermediate + # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). + gtx_transformations.gt_simplify(sdfg) + gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) return sdfg diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index c86ba88ead..01fd18897d 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -193,11 +193,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST - + [ - # TODO(edopao): Enable when GPU codegen issues related to symbolic domain are fixed. - (ALL, XFAIL, UNSUPPORTED_MESSAGE), - ], + OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f10f195d3a..a5453151e6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -7,21 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import reduce -from gt4py.next.otf import languages, stages, workflow -from gt4py.next.otf.binding import interface import numpy as np import pytest -import diskcache -from gt4py.eve import SymbolName - import gt4py.next as gtx from gt4py.next import ( astype, broadcast, common, 
- constructors, errors, - field_utils, float32, float64, int32, @@ -30,8 +23,6 @@ neighbor_sum, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.program_processors.runners import gtfn -from gt4py.next.type_system import type_specifications as ts from gt4py.next import utils as gt_utils from next_tests.integration_tests import cases @@ -306,7 +297,7 @@ def test_double_use_scalar(cartesian_case): # TODO(tehrengruber): This should be a regression test on ITIR level, but tracing doesn't # work for this case. @gtx.field_operator - def testee(a: np.int32, b: np.int32, c: cases.IField) -> cases.IField: + def testee(a: int32, b: int32, c: cases.IField) -> cases.IField: tmp = a * b tmp2 = tmp * tmp # important part here is that we use the intermediate twice so that it is From 5e937363ce3427f001addecf81815894ec3b9941 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 20 Nov 2024 13:47:08 +0100 Subject: [PATCH 11/43] feat[next]: Extend the IR pass for pruning of unnecessary casts (#1728) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend the IR pass delivered in #1688 Pruning of cast expressions may appear as a `as_fieldop` expression with form `(⇑(λ(__val) → cast_(·__val, float64)))(a)`, where `a` is already a field with data type `float64` in this example. This PR adds pruning of such trivial expressions. 
--- src/gt4py/next/ffront/foast_to_gtir.py | 4 +-- .../ir_utils/common_pattern_matcher.py | 26 ++++++++++++++++ src/gt4py/next/iterator/ir_utils/ir_makers.py | 22 +++++++++++++ .../next/iterator/transforms/prune_casts.py | 31 +++++++++++-------- .../ffront_tests/test_foast_to_gtir.py | 28 +++++------------ .../transforms_tests/test_prune_casts.py | 19 ++++++++++++ .../dace_tests/test_gtir_to_sdfg.py | 4 +-- 7 files changed, 94 insertions(+), 40 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 6cf4cc67fd..2c2971f49a 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -360,9 +360,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: if isinstance(t[0], ts.FieldType): - return im.as_fieldop( - im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type))) - )(expr) + return im.cast_as_fieldop(str(new_type))(expr) else: assert isinstance(t[0], ts.ScalarType) return im.call("cast_")(expr, str(new_type)) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 16a88b282a..9df091ac2a 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -10,6 +10,7 @@ from typing import TypeGuard from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -84,3 +85,28 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC def is_ref_to(node, ref: str): return isinstance(node, itir.SymRef) and node.id == ref + + +def is_identity_as_fieldop(node: itir.Expr): + """ + Match field operators implementing element-wise copy of a field argument, + that is expressions of the form 
`as_fieldop(stencil)(*args)` + + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> node = im.as_fieldop(im.lambda_("__arg0")(im.deref("__arg0")))("a") + >>> is_identity_as_fieldop(node) + True + >>> node = im.as_fieldop("deref")("a") + >>> is_identity_as_fieldop(node) + False + """ + if not is_applied_as_fieldop(node): + return False + stencil = node.fun.args[0] # type: ignore[attr-defined] + if ( + isinstance(stencil, itir.Lambda) + and len(stencil.params) == 1 + and stencil == im.lambda_(stencil.params[0])(im.deref(stencil.params[0].id)) + ): + return True + return False diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index d7a66b8285..2864c7f727 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -497,6 +497,28 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl +def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): + """ + Promotes the function `cast_` to a field_operator. + + Args: + type_: the target type to be passed as argument to `cast_` function. + domain: the domain of the returned field. + + Returns: + A function from Fields to Field. + + Examples: + >>> str(cast_as_fieldop("float32")("a")) + '(⇑(λ(__arg0) → cast_(·__arg0, float32)))(a)' + """ + + def _impl(it: itir.Expr) -> itir.FunCall: + return op_as_fieldop(lambda v: call("cast_")(v, type_), domain)(it) + + return _impl + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py index 0720394db5..c825f68a5f 100644 --- a/src/gt4py/next/iterator/transforms/prune_casts.py +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -6,13 +6,13 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py import eve from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts -class PruneCasts(PreserveLocationVisitor, NodeTranslator): +class PruneCasts(eve.NodeTranslator): """ Removes cast expressions where the argument is already in the target type. @@ -20,23 +20,28 @@ class PruneCasts(PreserveLocationVisitor, NodeTranslator): therefore it should be applied after type-inference. """ + PRESERVED_ANNEX_ATTRS = ("domain",) + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: node = self.generic_visit(node) - if not cpm.is_call_to(node, "cast_"): - return node + if cpm.is_call_to(node, "cast_"): + value, type_constructor = node.args - value, type_constructor = node.args + assert ( + value.type + and isinstance(type_constructor, ir.SymRef) + and (type_constructor.id in ir.TYPEBUILTINS) + ) + dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) - assert ( - value.type - and isinstance(type_constructor, ir.SymRef) - and (type_constructor.id in ir.TYPEBUILTINS) - ) - dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) + if value.type == dtype: + return value - if value.type == dtype: - return value + elif cpm.is_identity_as_fieldop(node): + # pruning of cast expressions may leave some trivial `as_fieldop` expressions + # with form '(⇑(λ(__arg) → ·__arg))(a)' + return node.args[0] return node diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 4a1a7cba8e..516890ea46 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -284,9 +284,7 @@ def foo(a: gtx.Field[[TDim], float64]): parsed = 
FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - "a" - ) + reference = im.cast_as_fieldop("int32")("a") assert lowered.expr == reference @@ -312,12 +310,8 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, "a") - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, "a") - ), + im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), + im.cast_as_fieldop("int32")(im.tuple_get(1, "a")), ) assert lowered_inlined.expr == reference @@ -332,9 +326,7 @@ def foo(a: tuple[gtx.Field[[TDim], float64], float64]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, "a") - ), + im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), im.call("cast_")(im.tuple_get(1, "a"), "int32"), ) @@ -356,16 +348,10 @@ def foo( reference = im.make_tuple( im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, im.tuple_get(0, "a")) - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, im.tuple_get(0, "a")) - ), - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, "a") + im.cast_as_fieldop("int32")(im.tuple_get(0, im.tuple_get(0, "a"))), + im.cast_as_fieldop("int32")(im.tuple_get(1, im.tuple_get(0, "a"))), ), + im.cast_as_fieldop("int32")(im.tuple_get(1, "a")), ) assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 462eed8408..7c991fb9a8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from gt4py import next as gtx from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.prune_casts import PruneCasts @@ -21,3 +22,21 @@ def test_prune_casts_simple(): expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) assert actual == expected + + +def test_prune_casts_fieldop(): + IDim = gtx.Dimension("IDim") + x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) + y_ref = im.ref("y", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))) + testee = im.op_as_fieldop("plus")( + im.cast_as_fieldop("float64")(x_ref), + im.cast_as_fieldop("float64")(y_ref), + ) + testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + + expected = im.op_as_fieldop("plus")( + im.cast_as_fieldop("float64")(x_ref), + y_ref, + ) + actual = PruneCasts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index a94157ecd2..e0c0c3fa4e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -146,9 +146,7 @@ def test_gtir_cast(): body=[ gtir.SetAt( expr=im.op_as_fieldop("eq", domain)( - 
im.as_fieldop( - im.lambda_("a")(im.call("cast_")(im.deref("a"), "float32")), domain - )("x"), + im.cast_as_fieldop("float32", domain)("x"), "y", ), domain=domain, From 0a01597d0bcd5b1288ff9d42293fc1225738e977 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 21 Nov 2024 14:21:26 +0100 Subject: [PATCH 12/43] bug[next]: extract scalar value with correct dtype (#1723) credits to @egparedes for this pattern and realizing that `item()` decays to a python type. --- src/gt4py/next/embedded/nd_array_field.py | 3 ++- .../runners/dace_common/dace_backend.py | 6 +----- .../unit_tests/embedded_tests/test_nd_array_field.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 655a1137e8..9ff5feaaee 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -148,7 +148,8 @@ def as_scalar(self) -> core_defs.ScalarT: raise ValueError( f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." ) - return self.ndarray.item() + # note: `.item()` will return a Python type, therefore we use indexing with an empty tuple + return self.asnumpy()[()] # type: ignore[return-value] # should be ensured by the 0-d check @property def codomain(self) -> type[core_defs.ScalarT]: diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index bbf45a822c..db0df7d121 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -28,11 +28,7 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: return arg if len(arg.domain.dims) == 0: # Pass zero-dimensional fields as scalars. - # We need to extract the scalar value from the 0d numpy array without changing its type. 
- # Note that 'ndarray.item()' always transforms the numpy scalar to a python scalar, - # which may change its precision. To avoid this, we use here the empty tuple as index - # for 'ndarray.__getitem__()'. - return arg.asnumpy()[()] + return arg.as_scalar() # field domain offsets are not supported non_zero_offsets = [ (dim, dim_range) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 9fba633cba..063e79d92e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -264,6 +264,16 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte assert np.allclose(op_result.ndarray, expected_result) +def test_as_scalar(nd_array_implementation): + testee = common._field( + nd_array_implementation.asarray(42.0, dtype=np.float32), domain=common.Domain() + ) + + result = testee.as_scalar() + assert result == 42.0 + assert isinstance(result, np.float32) + + def product_nd_array_implementation_params(): for xp1 in nd_array_field._nd_array_implementations: for xp2 in nd_array_field._nd_array_implementations: From 1cb29e3ce7f24954a14054be51d375f9851d533c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 25 Nov 2024 11:14:53 +0100 Subject: [PATCH 13/43] build: add devcontainer setup (#1725) Add devcontainer configuration with special customizations for VS Code. 
--------- Co-authored-by: Enrique Gonzalez Paredes --- .devcontainer/.vscode/launch.json | 24 +++++++++++++++ .devcontainer/Dockerfile | 5 ++++ .devcontainer/devcontainer.json | 49 +++++++++++++++++++++++++++++++ .devcontainer/setup.sh | 10 +++++++ .gitignore | 2 +- 5 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 .devcontainer/.vscode/launch.json create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100755 .devcontainer/setup.sh diff --git a/.devcontainer/.vscode/launch.json b/.devcontainer/.vscode/launch.json new file mode 100644 index 0000000000..f682b56388 --- /dev/null +++ b/.devcontainer/.vscode/launch.json @@ -0,0 +1,24 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File (just my code)", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + }, + { + "name": "Python: Current File (all)", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000000..414f2d0292 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,5 @@ +FROM mcr.microsoft.com/devcontainers/python:1-3.10-bookworm +RUN apt-get update \ + && export DEBIAN_FRONTEND=noninteractive && apt-get install -y libboost-dev \ + && apt-get clean && rm -rf /var/cache/apt/* && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/* +RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="/bin" sh diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000..7dc4b2f08c --- /dev/null +++ 
b/.devcontainer/devcontainer.json @@ -0,0 +1,49 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "build": { + "dockerfile": "Dockerfile" + }, + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "bash .devcontainer/setup.sh", + + "containerEnv": { + "PRE_COMMIT_HOME": "/workspaces/gt4py/.caches/pre-commit" + }, + + // Configure tool-specific properties. + "customizations": { + // Configure properties specific to VS Code. + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "python.formatting.provider": "ruff", + "python.testing.pytestEnabled": true, + "python.defaultInterpreterPath": "/workspaces/gt4py/.venv/bin/python", + "files.insertFinalNewline": true, + "python.terminal.activateEnvironment": true, + "cmake.ignoreCMakeListsMissing": true + }, + "extensions": [ + "charliermarsh.ruff", + "donjayamanne.githistory", + "github.vscode-github-actions", + "lextudio.restructuredtext", + "ms-python.python", + "ms-vsliveshare.vsliveshare", + "swyddfa.esbonio" + ] + } + } + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 
+ // "remoteUser": "root" +} diff --git a/.devcontainer/setup.sh b/.devcontainer/setup.sh new file mode 100755 index 0000000000..d23dda9dea --- /dev/null +++ b/.devcontainer/setup.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +ln -sfn /workspaces/gt4py/.devcontainer/.vscode /workspaces/gt4py/.vscode +uv venv .venv +source .venv/bin/activate +uv pip install -r requirements-dev.txt +uv pip install -e . +uv pip install -i https://test.pypi.org/simple/ atlas4py +pre-commit install --install-hooks +deactivate diff --git a/.gitignore b/.gitignore index 5792b8a9b7..b1c8ed26e9 100644 --- a/.gitignore +++ b/.gitignore @@ -159,5 +159,5 @@ venv.bak/ ### Others ### .obsidian - coverage.json +.caches From d7f55522beacfc77c12964f6bbb1962899d8821d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 25 Nov 2024 14:22:24 +0100 Subject: [PATCH 14/43] feat[next]: remove NeighborTableOffsetProvider, use gtx.as_connectivity (#1729) User-facing change: use `gtx.as_connectivity` to create a connectivity/neighbor table instead of `NeighborTableOffsetProvider` which is deprecated (and the backward-compatible mechanism broken for some use-cases). The internal concepts of `Connectivity` and `NeighborTable` are updated. `ConnectivityType` is introduced which contains the compile-time info of a `Connectivity`. See ADR 19. Additionally, the compile-time info is used (instead of the run-time connectivities) in many places of the toolchain when possible. 
--- .gitpod/.vscode/launch.json | 13 +- .../0008-Mapping_Domain_to_Cpp-Backend.md | 2 +- docs/development/ADRs/0019-Connectivities.md | 55 +++++ docs/user/next/QuickstartGuide.md | 6 +- .../exercises/2_divergence_exercise.ipynb | 4 +- .../2_divergence_exercise_solution.ipynb | 4 +- .../exercises/3_gradient_exercise.ipynb | 4 +- .../3_gradient_exercise_solution.ipynb | 4 +- .../workshop/exercises/4_curl_exercise.ipynb | 4 +- .../exercises/4_curl_exercise_solution.ipynb | 4 +- .../exercises/5_vector_laplace_exercise.ipynb | 10 +- .../5_vector_laplace_exercise_solution.ipynb | 10 +- .../8_diffusion_exercise_solution.ipynb | 8 +- docs/user/next/workshop/slides/slides_2.ipynb | 10 +- src/gt4py/_core/definitions.py | 10 +- src/gt4py/next/__init__.py | 6 +- src/gt4py/next/common.py | 170 ++++++++++---- src/gt4py/next/constructors.py | 24 +- src/gt4py/next/embedded/nd_array_field.py | 35 ++- src/gt4py/next/ffront/decorator.py | 47 ++-- src/gt4py/next/ffront/experimental.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 30 +-- src/gt4py/next/iterator/embedded.py | 215 +++++++++++------- .../next/iterator/ir_utils/domain_utils.py | 26 +-- src/gt4py/next/iterator/runtime.py | 10 +- .../iterator/transforms/collapse_tuple.py | 6 +- src/gt4py/next/iterator/transforms/cse.py | 6 +- .../iterator/transforms/fuse_as_fieldop.py | 9 +- .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/inline_scalar.py | 4 +- .../next/iterator/transforms/pass_manager.py | 29 ++- .../transforms/pass_manager_legacy.py | 14 +- .../next/iterator/transforms/unroll_reduce.py | 28 +-- .../next/iterator/type_system/inference.py | 34 +-- .../iterator/type_system/type_synthesizer.py | 48 ++-- src/gt4py/next/otf/arguments.py | 54 +---- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 76 +------ .../codegens/gtfn/gtfn_module.py | 47 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 31 +-- .../runners/dace_common/dace_backend.py | 21 +- .../runners/dace_common/utility.py | 15 +- 
.../runners/dace_fieldview/gtir_dataflow.py | 75 +++--- .../runners/dace_fieldview/gtir_sdfg.py | 33 ++- .../runners/dace_fieldview/workflow.py | 6 +- .../runners/dace_iterator/__init__.py | 53 +++-- .../runners/dace_iterator/itir_to_sdfg.py | 45 ++-- .../runners/dace_iterator/itir_to_tasklet.py | 97 ++++---- .../runners/dace_iterator/utility.py | 10 +- .../runners/dace_iterator/workflow.py | 6 +- .../next/program_processors/runners/gtfn.py | 16 +- .../program_processors/runners/roundtrip.py | 16 +- .../next/type_system/type_specifications.py | 1 + .../feature_tests/dace/test_orchestration.py | 86 ++++--- .../ffront_tests/ffront_test_utils.py | 91 +++++--- .../ffront_tests/test_execution.py | 36 +-- .../ffront_tests/test_external_local_field.py | 12 +- .../ffront_tests/test_gt4py_builtins.py | 18 +- .../test_temporaries_with_sizes.py | 2 +- .../iterator_tests/test_builtins.py | 40 +--- .../test_strided_offset_provider.py | 9 +- .../ffront_tests/test_ffront_fvm_nabla.py | 64 +++--- .../multi_feature_tests/fvm_nabla_setup.py | 56 +++-- .../iterator_tests/test_fvm_nabla.py | 114 ++++------ .../test_with_toy_connectivity.py | 38 ++-- tests/next_tests/toy_connectivity.py | 7 + tests/next_tests/unit_tests/conftest.py | 25 +- .../embedded_tests/test_nd_array_field.py | 15 +- .../test_embedded_field_with_list.py | 10 +- .../iterator_tests/test_runtime_domain.py | 10 +- .../iterator_tests/test_type_inference.py | 34 +-- .../transforms_tests/test_cse.py | 14 +- .../transforms_tests/test_domain_inference.py | 13 +- .../transforms_tests/test_fuse_as_fieldop.py | 13 +- .../transforms_tests/test_global_tmps.py | 8 +- .../transforms_tests/test_prune_casts.py | 6 +- .../transforms_tests/test_unroll_reduce.py | 69 ++++-- .../gtfn_tests/test_itir_to_gtfn_ir.py | 4 +- .../runners_tests/dace_tests/test_dace.py | 24 +- .../dace_tests/test_gtir_to_sdfg.py | 134 ++++++----- .../unit_tests/test_constructors.py | 14 +- 80 files changed, 1293 insertions(+), 1170 deletions(-) create mode 
100644 docs/development/ADRs/0019-Connectivities.md diff --git a/.gitpod/.vscode/launch.json b/.gitpod/.vscode/launch.json index f682b56388..b25a182648 100644 --- a/.gitpod/.vscode/launch.json +++ b/.gitpod/.vscode/launch.json @@ -6,7 +6,7 @@ "configurations": [ { "name": "Python: Current File (just my code)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", @@ -14,11 +14,20 @@ }, { "name": "Python: Current File (all)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", "justMyCode": false + }, + { + "name": "Python: Debug Tests", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "integratedTerminal", + "justMyCode": true } ] } diff --git a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md index a1ee8575d2..1ce83431ee 100644 --- a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md +++ b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md @@ -20,7 +20,7 @@ The Python embedded execution for Iterator IR keeps track of the current locatio ### Python side -On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` (aka `NeighborTableOffsetProvider` in the current implementation) describes the mapping between location types. +On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. 
`named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` describes the mapping between location types. ### C++ side diff --git a/docs/development/ADRs/0019-Connectivities.md b/docs/development/ADRs/0019-Connectivities.md new file mode 100644 index 0000000000..76e85e49a6 --- /dev/null +++ b/docs/development/ADRs/0019-Connectivities.md @@ -0,0 +1,55 @@ +--- +tags: [] +--- + +# [Connectivities] + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2024-11-08 +- **Updated**: 2024-11-08 + +The representation of Connectivities (neighbor tables, `NeighborTableOffsetProvider`) and their identifier (offset tag, `FieldOffset`, etc.) was extended and modified based on the needs of different parts of the toolchain. Here we outline the ideas for consolidating the different closely-related concepts. + +## History + +In the early days of Iterator IR (ITIR), an `offset` was a literal in the IR. Its meaning was only provided at execution time by a mapping from `offset` tag to an entity that we labelled `OffsetProvider`. We had mainly 2 kinds of `OffsetProvider`: a `Dimension` representing a Cartesian shift and a `NeighborTableOffsetProvider` for unstructured shifts. Since the type of `offset` needs to be known for compilation (strided for Cartesian, lookup-table for unstructured), this prevents a clean interface for ahead-of-time compilation. +For the frontend type-checking we later introduce a `FieldOffset` which contained type information of the mapped dimensions. +For (field-view) embedded we introduced a `ConnectivityField` (now `Connectivity`) which could be generated from the OffsetProvider information. + +These different concepts had overlap but were not 1-to-1 replacements. + +## Decision + +We update and introduce the following concepts + +### Conceptual definitions + +**Connectivity** is a mapping from index (or product of indices) to index. 
It covers 1-to-1 mappings, e.g. Cartesian shifts, NeighborTables (2D mappings) and dynamic Cartesian shifts. + +**NeighborConnectivity** is a 2D mapping of the N neighbors of a Location A to a Location B. + +**NeighborTable** is a _NeighborConnectivity_ backed by a buffer. + +**ConnectivityType**, **NeighborConnectivityType** contains all information that is needed for compilation. + +### Full definitions + +See `next.common` module + +Note: Currently, the compiled backends supports only `NeighborConnectivity`s that are `NeighborTable`s. We do not yet encode this in the type and postpone discussion to the point where we support alternative implementations (e.g. `StridedNeighborConnectivity`). + +## Which parts of the toolchain use which concept? + +### Embedded + +Embedded execution of field-view supports any kind of `Connectivity`. +Embedded execution of iterator (local) view supports only `NeighborConnectivity`s. + +### IR transformations and compiled backends + +All transformations and code-generation should use `ConnectivityType`, not the `Connectivity` which contains the runtime mapping. + +Note, currently the `global_tmps` pass uses runtime information, therefore this is not strictly enforced. + +The only supported `Connectivity`s in compiled backends (currently) are `NeighborTable`s. diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 81604c7770..2cb6647519 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -155,8 +155,6 @@ This section approaches the pseudo-laplacian by introducing the required APIs pr - [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) - [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) -+++ - #### Defining the mesh and its connectivities The examples related to unstructured meshes use the mesh below. The edges (in blue) and the cells (in red) are numbered with zero-based indices. 
@@ -237,7 +235,7 @@ E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim,E2CDim)) Note that the field offset does not contain the actual connectivity table, that's provided through an _offset provider_: ```{code-cell} ipython3 -E2C_offset_provider = gtx.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2) +E2C_offset_provider = gtx.as_connectivity([EdgeDim, E2CDim], codomain=CellDim, data=edge_to_cell_table, skip_value=-1) ``` The field operator `nearest_cell_to_edge` below shows an example of applying this transform. There is a little twist though: the subscript in `E2C[0]` means that only the value of the first connected cell is taken, the second (if exists) is ignored. @@ -385,7 +383,7 @@ As explained in the section outline, the pseudo-laplacian needs the cell-to-edge C2EDim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL) C2E = gtx.FieldOffset("C2E", source=EdgeDim, target=(CellDim, C2EDim)) -C2E_offset_provider = gtx.NeighborTableOffsetProvider(cell_to_edge_table, CellDim, EdgeDim, 3) +C2E_offset_provider = gtx.as_connectivity([CellDim, C2EDim], codomain=EdgeDim, data=cell_to_edge_table, skip_value=-1) ``` **Weights of edge differences:** diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb index 50349e52b0..b0a1980d0f 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -113,7 +113,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git 
a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb index 6baac2b8c0..573ee6a44e 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -118,7 +118,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb index c8914120d3..2b422b1823 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -110,7 +110,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb index 5e940a4b71..85044b989f 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb +++ 
b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -123,7 +123,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb index 4a6b37baf7..dc321f1bdd 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -134,7 +134,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb index 065cf02de7..251fe8239a 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -139,7 +139,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, 
has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb index 832375a86b..30f568de6f 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -272,10 +272,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb index be846d199d..eaeb8c7b02 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb @@ -249,7 +249,7 @@ }, { "cell_type": "code", - 
"execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -293,10 +293,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb index d4bcdb33d5..b278cee26d 100644 --- a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -156,10 +156,8 @@ " dt,\n", " )\n", "\n", - " e2c2v_connectivity = gtx.NeighborTableOffsetProvider(\n", - " e2c2v_table, E, V, 4, has_skip_values=False\n", - " )\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " e2c2v_connectivity = gtx.as_connectivity([E, E2C2VDim], codomain=V, data=e2c2v_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " diffusion_step(\n", " u,\n", diff 
--git a/docs/user/next/workshop/slides/slides_2.ipynb b/docs/user/next/workshop/slides/slides_2.ipynb index 1e8925087f..c6967df4b2 100644 --- a/docs/user/next/workshop/slides/slides_2.ipynb +++ b/docs/user/next/workshop/slides/slides_2.ipynb @@ -281,17 +281,19 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "6d30a5e1", "metadata": {}, "outputs": [], "source": [ - "E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, Edge, Cell, 2)" + "E2C_offset_provider = gtx.as_connectivity(\n", + " [Edge, E2CDim], codomain=Cell, data=e2c_table, skip_value=-1\n", + ")" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "d62f6c98", "metadata": {}, "outputs": [ @@ -311,7 +313,7 @@ " return cell_field(E2C[0]) # 0th index to isolate edge dimension\n", "\n", "\n", - "@gtx.program # uses skip_values, therefore we cannot use embedded\n", + "@gtx.program\n", "def run_nearest_cell_to_edge(\n", " cell_field: gtx.Field[Dims[Cell], float64], edge_field: gtx.Field[Dims[Edge], float64]\n", "):\n", diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 9d07b2eb79..8f62788b8f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -439,13 +439,21 @@ def ndim(self) -> int: ... @property def shape(self) -> tuple[int, ...]: ... + @property + def strides(self) -> tuple[int, ...]: ... + @property def dtype(self) -> Any: ... + @property + def itemsize(self) -> int: ... + def item(self) -> Any: ... def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... + def any(self) -> bool: ... + def __getitem__(self, item: Any) -> NDArrayObject: ... def __abs__(self) -> NDArrayObject: ... @@ -496,4 +504,4 @@ def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... 
+ def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 80bb276c70..4fa9215706 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -20,6 +20,7 @@ from . import common, ffront, iterator, program_processors from .common import ( + Connectivity, Dimension, DimensionKind, Dims, @@ -39,8 +40,7 @@ from .ffront.fbuiltins import * # noqa: F403 [undefined-local-with-import-star] explicitly reexport all from fbuiltins.__all__ from .ffront.fbuiltins import FieldOffset from .iterator.embedded import ( - NeighborTableOffsetProvider, - StridedNeighborOffsetProvider, + NeighborTableOffsetProvider, # TODO(havogt): deprecated index_field, np_as_located_field, ) @@ -61,6 +61,7 @@ "Dimension", "DimensionKind", "Field", + "Connectivity", "GridType", "domain", "Domain", @@ -75,7 +76,6 @@ "as_connectivity", # from iterator "NeighborTableOffsetProvider", - "StridedNeighborOffsetProvider", "index_field", "np_as_located_field", # from ffront diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 4aa0dd03aa..9b2870e1c0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -18,7 +18,6 @@ from collections.abc import Mapping, Sequence import numpy as np -import numpy.typing as npt from gt4py._core import definitions as core_defs from gt4py.eve import utils @@ -95,7 +94,7 @@ def __str__(self) -> str: def __call__(self, val: int) -> NamedIndex: return NamedIndex(self, val) - def __add__(self, offset: int) -> ConnectivityField: + def __add__(self, offset: int) -> Connectivity: # TODO(sf-n): just to avoid circular import. Move or refactor the FieldOffset to avoid this. 
from gt4py.next.ffront import fbuiltins @@ -104,7 +103,7 @@ def __add__(self, offset: int) -> ConnectivityField: dimension_to_implicit_offset(self.value), source=self, target=(self,) )[offset] - def __sub__(self, offset: int) -> ConnectivityField: + def __sub__(self, offset: int) -> Connectivity: return self + (-offset) @@ -678,6 +677,9 @@ def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... + # TODO(havogt) + # This property is wrong, because for a function field we would not know to which NDArrayObject we want to convert + # at the very least, we need to take an allocator and rename this to `as_ndarray`. @property def ndarray(self) -> core_defs.NDArrayObject: ... @@ -688,7 +690,7 @@ def __str__(self) -> str: def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def premap(self, index_field: Connectivity | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod def restrict(self, item: AnyIndexSpec) -> Self: ... @@ -700,8 +702,8 @@ def as_scalar(self) -> core_defs.ScalarT: ... @abc.abstractmethod def __call__( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, ) -> Field: ... @abc.abstractmethod @@ -811,12 +813,64 @@ def remapping(cls) -> ConnectivityKind: return cls.ALTER_DIMS | cls.ALTER_STRUCT +@dataclasses.dataclass(frozen=True) +class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import + domain: tuple[Dimension, ...] 
+ codomain: Dimension + skip_value: Optional[core_defs.IntegralScalar] + dtype: core_defs.DType + + @property + def has_skip_values(self) -> bool: + return self.skip_value is not None + + +@dataclasses.dataclass(frozen=True) +class NeighborConnectivityType(ConnectivityType): + # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain + max_neighbors: int + + @property + def source_dim(self) -> Dimension: + return self.domain[0] + + @property + def neighbor_dim(self) -> Dimension: + return self.domain[1] + + @runtime_checkable # type: ignore[misc] # DimT should be covariant, but then it breaks in other places -class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): +class Connectivity(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: ... + def codomain(self) -> DimT: + """ + The `codomain` is the set of all indices in a certain `Dimension`. + + We use the `Dimension` itself to describe the (infinite) set of all indices. + + Note: + We could restrict the infinite codomain to only the indices that are actually contained in the mapping. + Currently, this would just complicate implementation as we do not use this information. + """ + + def __gt_type__(self) -> ConnectivityType: + if is_neighbor_connectivity(self): + return NeighborConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + max_neighbors=self.ndarray.shape[1], + ) + else: + return ConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + ) @property def kind(self) -> ConnectivityKind: @@ -831,61 +885,61 @@ def skip_value(self) -> Optional[core_defs.IntegralScalar]: ... 
# Operators def __abs__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __neg__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __invert__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __eq__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __ne__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise 
TypeError("'Connectivity' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> 
Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") # Utility function to construct a `Field` from different buffer representations. @@ -911,38 +965,58 @@ def _connectivity( domain: Optional[DomainLike] = None, dtype: Optional[core_defs.DType] = None, skip_value: Optional[core_defs.IntegralScalar] = None, -) -> ConnectivityField: +) -> Connectivity: raise NotImplementedError -@runtime_checkable -class Connectivity(Protocol): - max_neighbors: int - has_skip_values: bool - origin_axis: Dimension - neighbor_axis: Dimension - index_type: type[int] | type[np.int32] | type[np.int64] +class NeighborConnectivity(Connectivity, Protocol): + # TODO(havogt): work towards encoding this properly in the type + def __gt_type__(self) -> NeighborConnectivityType: ... + - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - """Return neighbor index.""" +def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: + if not isinstance(obj, Connectivity): + return False + domain_dims = obj.domain.dims + return ( + len(domain_dims) == 2 + and domain_dims[0].kind is DimensionKind.HORIZONTAL + and domain_dims[1].kind is DimensionKind.LOCAL + ) -@runtime_checkable -class NeighborTable(Connectivity, Protocol): - table: npt.NDArray +class NeighborTable( + NeighborConnectivity, Protocol +): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`) + @property + def ndarray(self) -> core_defs.NDArrayObject: + # Note that this property is currently already there from inheriting from `Field`, + # however this seems wrong, therefore we explicitly introduce it here (or it should come + # implicitly from the `NdArrayConnectivityField` protocol). + ... 
-OffsetProviderElem: TypeAlias = Dimension | Connectivity +def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: + return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") + + +OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity +OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] +OffsetProviderType: TypeAlias = Mapping[Tag, OffsetProviderTypeElem] + + +def offset_provider_to_type(offset_provider: OffsetProvider) -> OffsetProviderType: + return { + k: v.__gt_type__() if isinstance(v, Connectivity) else v for k, v in offset_provider.items() + } DomainDimT = TypeVar("DomainDimT", bound="Dimension") @dataclasses.dataclass(frozen=True, eq=False) -class CartesianConnectivity(ConnectivityField[Dims[DomainDimT], DimT]): +class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]): domain_dim: DomainDimT codomain: DimT offset: int = 0 @@ -981,7 +1055,7 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: return core_defs.Int32DType() # type: ignore[return-value] # This is a workaround to make this class concrete, since `codomain` is an - # abstract property of the `ConnectivityField` Protocol. + # abstract property of the `Connectivity` Protocol. 
if not TYPE_CHECKING: @functools.cached_property @@ -1024,9 +1098,9 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa def premap( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, - ) -> ConnectivityField: + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, + ) -> Connectivity: raise NotImplementedError() __call__ = premap diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index dd52559e85..7b39511674 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -290,22 +290,24 @@ def as_connectivity( *, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, - skip_value: Optional[core_defs.IntegralScalar] = None, + skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING, # TODO: copy=False -) -> common.ConnectivityField: +) -> common.Connectivity: """ - Construct a connectivity field from the given domain, codomain, and data. + Construct a `Connectivity` from the given domain, codomain, and data. Arguments: - domain: The domain of the connectivity field. It can be either a `common.DomainLike` object or a + domain: The domain of the connectivity. It can be either a `common.DomainLike` object or a sequence of `common.Dimension` objects. - codomain: The codomain dimension of the connectivity field. + codomain: The codomain dimension of the connectivity. data: The data used to construct the connectivity field. - dtype: The data type of the connectivity field. If not provided, it will be inferred from the data. - allocator: The allocator used to allocate the buffer for the connectivity field. If not provided, + dtype: The data type of the connectivity. If not provided, it will be inferred from the data. + allocator: The allocator used to allocate the buffer for the connectivity. 
If not provided, a default allocator will be used. - device: The device on which the connectivity field will be allocated. If not provided, the default + device: The device on which the connectivity will be allocated. If not provided, the default device will be used. + skip_value: The value that signals missing entries in the neighbor table. Defaults to the default + skip value if it is found in data, otherwise to `None` (= no skip value). Returns: The constructed connectivity field. @@ -313,9 +315,15 @@ def as_connectivity( Raises: ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. """ + if skip_value is eve.NOTHING: + skip_value = ( + common._DEFAULT_SKIP_VALUE if (data == common._DEFAULT_SKIP_VALUE).any() else None + ) + assert ( skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE ) # TODO(havogt): not yet configurable + skip_value = cast(Optional[core_defs.IntegralScalar], skip_value) if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9ff5feaaee..e15fb4266a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -36,7 +36,6 @@ exceptions as embedded_exceptions, ) from gt4py.next.ffront import experimental, fbuiltins -from gt4py.next.iterator import embedded as itir_embedded try: @@ -189,10 +188,10 @@ def from_array( def premap( self: NdArrayField, - *connectivities: common.ConnectivityField | fbuiltins.FieldOffset, + *connectivities: common.Connectivity | fbuiltins.FieldOffset, ) -> NdArrayField: """ - Rearrange the field content using the provided connectivity fields as index mappings. + Rearrange the field content using the provided connectivities (index mappings). 
This operation is conceptually equivalent to a regular composition of mappings `f∘c`, being `c` the `connectivity` argument and `f` the `self` data field. @@ -206,7 +205,7 @@ def premap( argument used in the right hand side of the operator should therefore have the same product of dimensions `c: S × T → A × B`. Such a mapping can also be expressed as a pair of mappings `c1: S × T → A` and `c2: S × T → B`, and this - is actually the only supported form in GT4Py because `ConnectivityField` instances + is actually the only supported form in GT4Py because `Connectivity` instances can only deal with a single dimension in its codomain. This approach makes connectivities reusable for any combination of dimensions in a field domain and matches the NumPy advanced indexing API, which basically is a @@ -261,15 +260,15 @@ def premap( """ # noqa: RUF002 # TODO(egparedes): move docstring to the `premap` builtin function when it exists - conn_fields: list[common.ConnectivityField] = [] + conn_fields: list[common.Connectivity] = [] codomains_counter: collections.Counter[common.Dimension] = collections.Counter() for connectivity in connectivities: - # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField - if not isinstance(connectivity, common.ConnectivityField): + # For neighbor reductions, a FieldOffset is passed instead of an actual Connectivity + if not isinstance(connectivity, common.Connectivity): assert isinstance(connectivity, fbuiltins.FieldOffset) connectivity = connectivity.as_connectivity_field() - assert isinstance(connectivity, common.ConnectivityField) + assert isinstance(connectivity, common.Connectivity) # Current implementation relies on skip_value == -1: # if we assume the indexed array has at least one element, @@ -318,8 +317,8 @@ def premap( def __call__( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | 
fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: return functools.reduce( lambda field, current_index_field: field.premap(current_index_field), @@ -460,7 +459,7 @@ def _dace_descriptor(self) -> Any: @dataclasses.dataclass(frozen=True) class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ - common.ConnectivityField[common.DimsT, common.DimT], + common.Connectivity[common.DimsT, common.DimT], NdArrayField[common.DimsT, core_defs.IntegralScalar], ): _codomain: common.DimT @@ -579,7 +578,7 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: __getitem__ = restrict -def _domain_premap(data: NdArrayField, *connectivities: common.ConnectivityField) -> NdArrayField: +def _domain_premap(data: NdArrayField, *connectivities: common.Connectivity) -> NdArrayField: """`premap` implementation transforming only the field domain not the data (i.e. translation and relocation).""" new_domain = data.domain for connectivity in connectivities: @@ -668,7 +667,7 @@ def _reshuffling_premap( ) -def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField) -> NdArrayField: +def _remapping_premap(data: NdArrayField, connectivity: common.Connectivity) -> NdArrayField: new_dims = {*connectivity.domain.dims} - {connectivity.codomain} if repeated_dims := (new_dims & {*data.domain.dims}): raise ValueError(f"Remapped field will contain repeated dimensions '{repeated_dims}'.") @@ -693,7 +692,7 @@ def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField if restricted_connectivity_domain != connectivity.domain else connectivity ) - assert isinstance(restricted_connectivity, common.ConnectivityField) + assert isinstance(restricted_connectivity, common.Connectivity) # 2- then compute the index array new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start @@ -971,7 +970,7 @@ def _concat_where( return cls_.from_array(result_array, 
domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] def _make_reduction( @@ -996,15 +995,15 @@ def _builtin_op( offset_definition = current_offset_provider[ axis.value ] # assumes offset and local dimension have same name - assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) + assert common.is_neighbor_table(offset_definition) new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( - slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis + slice(None) if d in [axis, offset_definition.domain.dims[0]] else xp.newaxis for d in field.domain.dims ) masked_array = xp.where( - xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, + xp.asarray(offset_definition.ndarray[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, field.ndarray, initial_value_op(field), ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index dc2421e1d2..9ce07d01bb 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -30,7 +30,6 @@ embedded as next_embedded, errors, ) -from gt4py.next.common import Connectivity, Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, @@ -82,15 +81,15 @@ class Program: definition_stage: ffront_stages.ProgramDefinition backend: Optional[next_backend.Backend] - connectivities: Optional[dict[str, Connectivity]] + connectivities: Optional[common.OffsetProviderType] = None @classmethod def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend], - grid_type: Optional[GridType] = None, - connectivities: Optional[dict[str, Connectivity]] = None, + grid_type: Optional[common.GridType] = None, + 
connectivities: Optional[common.OffsetProviderType] = None, ) -> Program: program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type) return cls(definition_stage=program_def, backend=backend, connectivities=connectivities) @@ -140,10 +139,10 @@ def _frontend_transforms(self) -> next_backend.Transforms: def with_backend(self, backend: next_backend.Backend) -> Program: return dataclasses.replace(self, backend=backend) - def with_connectivities(self, connectivities: dict[str, Connectivity]) -> Program: + def with_connectivities(self, connectivities: common.OffsetProviderType) -> Program: return dataclasses.replace(self, connectivities=connectivities) - def with_grid_type(self, grid_type: GridType) -> Program: + def with_grid_type(self, grid_type: common.GridType) -> Program: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -199,7 +198,7 @@ def itir(self) -> itir.FencilDefinition: return self._frontend_transforms.past_to_itir(no_args_past).data @functools.cached_property - def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderElem]: + def _implicit_offset_provider(self) -> dict[str, common.Dimension]: """ Add all implicit offset providers. 
@@ -226,9 +225,7 @@ def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderEle ) return implicit_offset_provider - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: offset_provider = offset_provider | self._implicit_offset_provider if self.backend is None: warnings.warn( @@ -287,19 +284,17 @@ def definition(self) -> str: def with_backend(self, backend: next_backend.Backend) -> FrozenProgram: return self.__class__(program=self.program, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FrozenProgram: + def with_grid_type(self, grid_type: common.GridType) -> FrozenProgram: return self.__class__( program=dataclasses.replace(self.program, grid_type=grid_type), backend=self.backend ) def jit( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any ) -> stages.CompiledProgram: return self.backend.jit(self.program, *args, offset_provider=offset_provider, **kwargs) - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: args, kwargs = signature.convert_to_positional(self.program, *args, **kwargs) if not self._compiled_program: @@ -328,7 +323,7 @@ class ProgramFromPast(Program): past_stage: ffront_stages.PastProgramDefinition - def __call__(self, *args: Any, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: if self.backend is None: raise NotImplementedError( "Programs created from a PAST node (without a function definition) can not be executed in embedded mode" @@ -350,7 +345,7 @@ def __post_init__(self): 
class ProgramWithBoundArgs(Program): bound_args: dict[str, typing.Union[float, int, bool]] = None - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): + def __call__(self, *args, offset_provider: common.OffsetProvider, **kwargs): type_ = self.past_stage.past_node.type new_type = ts_ffront.ProgramType( definition=ts.FunctionType( @@ -436,7 +431,7 @@ def program( *, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) backend: next_backend.Backend | eve.NOTHING = eve.NOTHING, - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, frozen: bool = False, ) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ @@ -506,7 +501,7 @@ def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend.Backend], - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, @@ -557,7 +552,7 @@ def __gt_type__(self) -> ts.CallableType: def with_backend(self, backend: next_backend.Backend) -> FieldOperator: return dataclasses.replace(self, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FieldOperator: + def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -688,33 +683,33 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast. def scan_operator( definition: types.FunctionType, *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> FieldOperator[foast.ScanOperator]: ... 
@typing.overload def scan_operator( *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... def scan_operator( definition: Optional[types.FunctionType] = None, *, - axis: Dimension, + axis: common.Dimension, forward: bool = True, init: core_defs.Scalar = 0.0, backend=eve.NOTHING, - grid_type: GridType = None, + grid_type: common.GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 8a94c20832..bd22aebe57 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -14,7 +14,7 @@ @BuiltInFunction -def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.ConnectivityField: +def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivity: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index d932431b51..b60fa63f95 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -16,7 +16,6 @@ import numpy as np from numpy import float32, float64, int32, int64 -import gt4py.next as gtx from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS @@ -55,7 +54,7 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.DimensionType elif t is FieldOffset: return ts.OffsetType - elif t is common.ConnectivityField: + elif t is common.Connectivity: return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType @@ -321,7 +320,7 @@ def __post_init__(self) -> None: def __gt_type__(self) -> ts.OffsetType: return 
ts.OffsetType(source=self.source, target=self.target) - def __getitem__(self, offset: int) -> common.ConnectivityField: + def __getitem__(self, offset: int) -> common.Connectivity: """Serve as a connectivity factory.""" from gt4py.next import embedded # avoid circular import @@ -330,22 +329,19 @@ def __getitem__(self, offset: int) -> common.ConnectivityField: assert current_offset_provider is not None offset_definition = current_offset_provider[self.value] - connectivity: common.ConnectivityField + connectivity: common.Connectivity if isinstance(offset_definition, common.Dimension): connectivity = common.CartesianConnectivity(offset_definition, offset) - elif isinstance( - offset_definition, (gtx.NeighborTableOffsetProvider, common.ConnectivityField) - ): - unrestricted_connectivity = self.as_connectivity_field() - assert unrestricted_connectivity.domain.ndim > 1 + elif isinstance(offset_definition, common.Connectivity): + assert common.is_neighbor_connectivity(offset_definition) named_index = common.NamedIndex(self.target[-1], offset) - connectivity = unrestricted_connectivity[named_index] + connectivity = offset_definition[named_index] else: raise NotImplementedError() return connectivity - def as_connectivity_field(self) -> common.ConnectivityField: + def as_connectivity_field(self) -> common.Connectivity: """Convert to connectivity field using the offset providers in current embedded execution context.""" from gt4py.next import embedded # avoid circular import @@ -356,18 +352,8 @@ def as_connectivity_field(self) -> common.ConnectivityField: cache_key = id(offset_definition) if (connectivity := self._cache.get(cache_key, None)) is None: - if isinstance(offset_definition, common.ConnectivityField): + if isinstance(offset_definition, common.Connectivity): connectivity = offset_definition - elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): - connectivity = gtx.as_connectivity( - domain=self.target, - codomain=self.source, - 
data=offset_definition.table, - dtype=offset_definition.index_type, - skip_value=( - common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None - ), - ) else: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 6221c95522..3c63ffef30 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -93,77 +93,113 @@ class SparseTag(Tag): ... -class NeighborTableOffsetProvider: +@xtyping.deprecated("Use a 'Connectivity' instead.") +def NeighborTableOffsetProvider( + table: core_defs.NDArrayObject, + origin_axis: common.Dimension, + neighbor_axis: common.Dimension, + max_neighbors: int, + has_skip_values=True, +) -> common.Connectivity: + return common._connectivity( + table, + codomain=neighbor_axis, + domain={ + origin_axis: table.shape[0], + common.Dimension( + value="_DummyLocalDim", kind=common.DimensionKind.LOCAL + ): max_neighbors, + }, + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + ) + + +# TODO(havogt): complete implementation and make available for fieldview embedded +@dataclasses.dataclass(frozen=True) +class StridedConnectivityField(common.Connectivity): + domain_dims: tuple[common.Dimension, common.Dimension] + codomain_dim: common.Dimension + _max_neighbors: int + def __init__( self, - table: core_defs.NDArrayObject, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, + domain_dims: Sequence[common.Dimension], + codomain_dim: common.Dimension, max_neighbors: int, - has_skip_values=True, - ) -> None: - self.table = table - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - assert not hasattr(table, "shape") or table.shape[1] == max_neighbors - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = table.dtype - - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - res = self.table[(primary, 
neighbor_idx)] - assert common.is_int_index(res) - return res + ): + object.__setattr__(self, "domain_dims", tuple(domain_dims)) + object.__setattr__(self, "codomain_dim", codomain_dim) + object.__setattr__(self, "_max_neighbors", max_neighbors) - if dace: - # Extension of NeighborTableOffsetProvider adding SDFGConvertible support in GT4Py Programs - def _dace_data_ptr(self) -> int: - obj = self.table - if dace.dtypes.is_array(obj): - if hasattr(obj, "__array_interface__"): - return obj.__array_interface__["data"][0] - if hasattr(obj, "__cuda_array_interface__"): - return obj.__cuda_array_interface__["data"][0] - raise ValueError("Unsupported data container.") - - def _dace_descriptor(self) -> dace.data.Data: - return dace.data.create_datadescriptor(self.table) - else: + @property + def __gt_origin__(self) -> xtyping.Never: + raise NotImplementedError + + def __gt_type__(self) -> common.NeighborConnectivityType: + return common.NeighborConnectivityType( + domain=self.domain_dims, + codomain=self.codomain_dim, + max_neighbors=self._max_neighbors, + skip_value=self.skip_value, + dtype=self.dtype, + ) - def _dace_data_ptr(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "data_ptr is only supported when the 'dace' module is available." - ) + @property + def domain(self) -> common.Domain: + return common.Domain( + dims=self.domain_dims, + ranges=(common.UnitRange.infinite(), common.unit_range(self._max_neighbors)), + ) - def _dace_descriptor(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "__descriptor__ is only supported when the 'dace' module is available." 
- ) + @property + def codomain(self) -> common.Dimension: + return self.codomain_dim - data_ptr = _dace_data_ptr - __descriptor__ = _dace_descriptor + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + @property + def ndarray(self) -> core_defs.NDArrayObject: + raise NotImplementedError -class StridedNeighborOffsetProvider: - def __init__( + def asnumpy(self) -> np.ndarray: + raise NotImplementedError + + def premap(self, index_field: common.Connectivity | fbuiltins.FieldOffset) -> common.Field: + raise NotImplementedError + + def restrict( # type: ignore[override] self, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, - max_neighbors: int, - has_skip_values=True, - ) -> None: - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = int + item: common.AnyIndexSpec, + ) -> common.Field: + if not isinstance(item, tuple) or (isinstance(item, tuple) and not len(item) == 2): + raise NotImplementedError() # TODO(havogt): add proper slicing + index = item[0] * self._max_neighbors + item[1] # type: ignore[operator, call-overload] + return ConstantField(index) + + def as_scalar(self) -> xtyping.Never: + raise NotImplementedError() + + def __call__( + self, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, + ) -> common.Field: + raise NotImplementedError() - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - return primary * self.max_neighbors + neighbor_idx + __getitem__ = restrict # type: ignore[assignment] + + def inverse_image( + self, image_range: common.UnitRange | common.NamedRange + ) -> Sequence[common.NamedRange]: + raise NotImplementedError + + @property + def skip_value( + self, + ) -> None: + return None # Offsets @@ -597,10 
+633,11 @@ def execute_shift( new_entry[i] = 0 else: offset_implementation = offset_provider[tag] - assert isinstance(offset_implementation, common.Connectivity) - cur_index = pos[offset_implementation.origin_axis.value] + assert common.is_neighbor_connectivity(offset_implementation) + source_dim = offset_implementation.__gt_type__().source_dim + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: @@ -620,22 +657,22 @@ def execute_shift( else: raise AssertionError() return new_pos - else: - assert isinstance(offset_implementation, common.Connectivity) - assert offset_implementation.origin_axis.value in pos + elif common.is_neighbor_connectivity(offset_implementation): + source_dim = offset_implementation.__gt_type__().source_dim + assert source_dim.value in pos new_pos = pos.copy() - new_pos.pop(offset_implementation.origin_axis.value) - cur_index = pos[offset_implementation.origin_axis.value] + new_pos.pop(source_dim.value) + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: return None else: - new_index = offset_implementation.mapped_index(cur_index, index) + new_index = offset_implementation[cur_index, index].as_scalar() assert new_index is not None - new_pos[offset_implementation.neighbor_axis.value] = int(new_index) + new_pos[offset_implementation.codomain.value] = int(new_index) return new_pos @@ -1196,8 +1233,8 @@ def as_scalar(self) -> core_defs.IntegralScalar: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | 
fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1322,8 +1359,8 @@ def asnumpy(self) -> np.ndarray: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1428,10 +1465,12 @@ def __gt_type__(self) -> itir_ts.ListType: assert isinstance(offset_tag, str) element_type = type_translation.from_value(self.values[0]) assert isinstance(element_type, ts.DataType) - return itir_ts.ListType( - element_type=element_type, - offset_type=common.Dimension(value=offset_tag, kind=common.DimensionKind.LOCAL), - ) + offset_provider = embedded_context.offset_provider.get() + assert offset_provider is not None + connectivity = offset_provider[offset_tag] + assert common.is_neighbor_connectivity(connectivity) + local_dim = connectivity.__gt_type__().neighbor_dim + return itir_ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1457,11 +1496,11 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[offset_str] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if (shifted := it.shift(offset_str, i)).can_deref() ), offset=offset, @@ -1533,11 +1572,11 @@ def deref(self) -> Any: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not 
None connectivity = offset_provider[self.list_offset] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if ( shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i) ).can_deref() @@ -1671,9 +1710,9 @@ def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} -def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: +def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: if isinstance(domain, runtime.CartesianDomain): - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()): + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()): raise RuntimeError( "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'." 
) @@ -1770,10 +1809,10 @@ def _fieldspec_list_to_value( offset_type = type_.offset_type assert isinstance(offset_type, common.Dimension) connectivity = offset_provider[offset_type.value] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return domain.insert( len(domain), - common.named_range((offset_type, connectivity.max_neighbors)), + common.named_range((offset_type, connectivity.__gt_type__().max_neighbors)), ), type_.element_type return domain, type_ @@ -1809,7 +1848,7 @@ def closure( ) -> None: assert embedded_context.within_valid_context() offset_provider = embedded_context.offset_provider.get() - _validate_domain(domain_, offset_provider) + _validate_domain(domain_, common.offset_provider_to_type(offset_provider)) domain: dict[Tag, range] = _dimension_to_tag(domain_) if not (isinstance(out, common.Field) or is_tuple_of_field(out)): raise TypeError("'Out' needs to be a located field.") diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8f842e1c13..f5625b509c 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -12,7 +12,6 @@ import functools from typing import Any, Literal, Mapping, Optional -import gt4py.next as gtx from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -23,20 +22,19 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di """ Extract horizontal domain sizes from an `offset_provider`. - Considers the shape of the neighbor table to get the size of each `origin_axis` and the maximum - value inside the neighbor table to get the size of each `neighbor_axis`. + Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum + value inside the neighbor table to get the size of each `codomain`. 
""" sizes = dict[str, int]() for provider in offset_provider.values(): - if isinstance(provider, gtx.NeighborTableOffsetProvider): - assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL - assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL - sizes[provider.origin_axis.value] = max( - sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] + if common.is_neighbor_connectivity(provider): + conn_type = provider.__gt_type__() + sizes[conn_type.source_dim.value] = max( + sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] ) - sizes[provider.neighbor_axis.value] = max( - sizes.get(provider.neighbor_axis.value, 0), - provider.table.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + sizes[conn_type.codomain.value] = max( + sizes.get(conn_type.codomain.value, 0), + provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject ) return sizes @@ -114,7 +112,7 @@ def translate( new_ranges[current_dim] = SymbolicRange.translate( self.ranges[current_dim], val.value ) - elif isinstance(nbt_provider, common.Connectivity): + elif common.is_neighbor_connectivity(nbt_provider): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) @@ -132,8 +130,8 @@ def translate( for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } - old_dim = nbt_provider.origin_axis - new_dim = nbt_provider.neighbor_axis + old_dim = nbt_provider.__gt_type__().source_dim + new_dim = nbt_provider.__gt_type__().codomain assert new_dim not in new_ranges or old_dim == new_dim diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ad85d154cb..d42f961202 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -12,7 +12,7 @@ import functools import types from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union 
+from typing import TYPE_CHECKING, Callable, Optional, Union import devtools @@ -127,7 +127,9 @@ def fendef( ) -def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[str, Any]): +def _deduce_domain( + domain: dict[common.Dimension, range], offset_provider_type: common.OffsetProviderType +): if isinstance(domain, UnstructuredDomain): domain_builtin = builtins.unstructured_domain elif isinstance(domain, CartesianDomain): @@ -135,7 +137,7 @@ def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[ else: domain_builtin = ( builtins.unstructured_domain - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()) + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()) else builtins.cartesian_domain ) @@ -160,7 +162,7 @@ def impl(out, *inps): elif isinstance(dom, dict): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None - dom = _deduce_domain(dom, offset_provider) + dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider)) closure(dom, self.fundef_dispatcher, out, [*inps]) return impl diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f84714e779..e71a24127f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -105,7 +105,7 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider: Optional[common.OffsetProvider] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, within_stencil: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes flags: Optional[Flag] = None, @@ -126,7 +126,7 @@ def apply( `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ flags = flags or cls.flags - 
offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} if isinstance(node, (ir.Program, ir.FencilDefinition)): within_stencil = False @@ -138,7 +138,7 @@ def apply( if not ignore_tuple_size: node = itir_type_inference.infer( node, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, allow_undeclared_symbols=allow_undeclared_symbols, ) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 38ea1fd53d..824adfdd8d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -411,7 +411,7 @@ def apply( cls, node: ProgramOrExpr, within_stencil: bool | None = None, - offset_provider: common.OffsetProvider | None = None, + offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) if is_program: @@ -422,9 +422,9 @@ def apply( within_stencil is not None ), "The expression's context must be specified using `within_stencil`." 
- offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} node = itir_type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program + node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program ) return cls().visit(node, within_stencil=within_stencil) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index da238733da..9076bf2d3f 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -11,6 +11,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import ( @@ -89,7 +90,7 @@ class FuseAsFieldOp(eve.NodeTranslator): ) >>> print( ... FuseAsFieldOp.apply( - ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True ... ) ... 
) as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) @@ -134,12 +135,14 @@ def apply( cls, node: itir.Program, *, - offset_provider, + offset_provider_type: common.OffsetProviderType, uids: Optional[eve_utils.UIDGenerator] = None, allow_undeclared_symbols=False, ): node = type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, ) if not uids: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 90f8a6cded..a6d39883e3 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -187,7 +187,9 @@ def create_global_tmps( arguments into temporaries. """ program = infer_domain.infer_program(program, offset_provider) - program = type_inference.infer(program, offset_provider=offset_provider) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if not uids: uids = eve_utils.UIDGenerator(prefix="__tmp") diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index c6e2c38b90..87b576d14d 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -17,8 +17,8 @@ class InlineScalar(eve.NodeTranslator): @classmethod - def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider): - program = itir_inference.infer(program, offset_provider=offset_provider) + def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): + program = itir_inference.infer(program, offset_provider_type=offset_provider_type) return cls().visit(program) def visit_Expr(self, node: itir.Expr): diff --git 
a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 52a452155a..ec6f89685a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -43,8 +43,8 @@ def __call__( def apply_common_transforms( ir: itir.Program | itir.FencilDefinition, *, + offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, - offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, force_inline_lambda_args=False, @@ -56,7 +56,12 @@ def apply_common_transforms( #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this if isinstance(ir, itir.FencilDefinition): ir = fencil_to_program.FencilToProgram.apply(ir) @@ -75,7 +80,7 @@ def apply_common_transforms( # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. 
when a tuple element is never accessed) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, @@ -89,15 +94,15 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program - inlined = InlineScalar.apply(inlined, offset_provider=offset_provider) + inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) # This pass is required to run after CollapseTuple as otherwise we can not inline # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns # a list. Such expressions must be inlined however because no backend supports such # field operators right now. 
inlined = fuse_as_fieldop.FuseAsFieldOp.apply( - inlined, uids=mergeasfop_uids, offset_provider=offset_provider + inlined, uids=mergeasfop_uids, offset_provider_type=offset_provider_type ) if inlined == ir: @@ -108,19 +113,21 @@ def apply_common_transforms( # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) ir = MergeLet().visit(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True) if extract_temporaries: - ir = infer(ir, inplace=True, offset_provider=offset_provider) + ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. 
if unconditionally_collapse_tuples: - ir = CollapseTuple.apply(ir, ignore_tuple_size=True, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply( + ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type + ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -129,7 +136,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled # type: ignore[assignment] # still a `itir.Program` @@ -156,6 +163,8 @@ def apply_fieldview_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + ir = CollapseTuple.apply( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) # type: ignore[assignment] # type is still `itir.Program` ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py index 792bb421f1..94c962e92d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -10,6 +10,7 @@ from typing import Callable, Optional from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet @@ -75,8 +76,13 @@ def apply_common_transforms( Callable[[itir.StencilClosure], 
Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: assert isinstance(ir, itir.FencilDefinition) + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + ir = fencil_to_program.FencilToProgram().apply(ir) icdlv_uids = eve_utils.UIDGenerator() @@ -109,7 +115,7 @@ def apply_common_transforms( # is constant-folded the surrounding tuple_get calls can be removed. inlined = CollapseTuple.apply( inlined, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -134,7 +140,7 @@ def apply_common_transforms( ir = CollapseTuple.apply( ir, ignore_tuple_size=True, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -149,7 +155,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled @@ -164,7 +170,7 @@ def apply_common_transforms( ir = ScanEtaReduction().visit(ir) if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program ir = MergeLet().visit(ir) ir = InlineLambdas.apply( diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py 
b/src/gt4py/next/iterator/transforms/unroll_reduce.py index ec9c3efb2b..042a86cd8e 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -64,16 +64,16 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: def _get_connectivity( applied_reduce_node: itir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: + offset_provider_type: common.OffsetProviderType, +) -> common.NeighborConnectivityType: """Return single connectivity that is compatible with the arguments of the reduce.""" if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - connectivities: list[common.Connectivity] = [] + connectivities: list[common.NeighborConnectivityType] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) + conn = offset_provider_type[o] + assert isinstance(conn, common.NeighborConnectivityType) connectivities.append(conn) if not connectivities: @@ -120,15 +120,15 @@ class UnrollReduce(PreserveLocationVisitor, NodeTranslator): uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) @classmethod - def apply(cls, node: itir.Node, **kwargs) -> itir.Node: - return cls().visit(node, **kwargs) - - def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - connectivity = _get_connectivity(node, offset_provider) - max_neighbors = connectivity.max_neighbors - has_skip_values = connectivity.has_skip_values + def apply(cls, node: itir.Node, offset_provider_type: common.OffsetProviderType) -> itir.Node: + return cls().visit(node, offset_provider_type=offset_provider_type) + + def _visit_reduce( + self, node: itir.FunCall, offset_provider_type: 
common.OffsetProviderType + ) -> itir.Expr: + connectivity_type = _get_connectivity(node, offset_provider_type) + max_neighbors = connectivity_type.max_neighbors + has_skip_values = connectivity_type.has_skip_values acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc")) offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i")) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 66d8345b94..987eb0f308 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -155,7 +155,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> square_func_type_synthesizer = type_synthesizer.TypeSynthesizer( ... type_synthesizer=lambda base: power(base, int_type) ... ) - >>> square_func_type_synthesizer(float_type, offset_provider={}) + >>> square_func_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such @@ -169,7 +169,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... node=square_func, ... store_inferred_type_in_node=True, ... ) - >>> o_type_synthesizer(float_type, offset_provider={}) + >>> o_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... 
pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type @@ -225,13 +225,15 @@ def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: def __call__( self, *args: type_synthesizer.TypeOrTypeSynthesizer, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: assert all( isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" - return_type_or_synthesizer = self.type_synthesizer(*args, offset_provider=offset_provider) + return_type_or_synthesizer = self.type_synthesizer( + *args, offset_provider_type=offset_provider_type + ) # return type is a typing rule by itself if isinstance(return_type_or_synthesizer, type_synthesizer.TypeSynthesizer): @@ -250,18 +252,18 @@ def __call__( def _get_dimensions_from_offset_provider( - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> dict[str, common.Dimension]: dimensions: dict[str, common.Dimension] = {} - for offset_name, provider in offset_provider.items(): + for offset_name, provider in offset_provider_type.items(): dimensions[offset_name] = common.Dimension( value=offset_name, kind=common.DimensionKind.LOCAL ) if isinstance(provider, common.Dimension): dimensions[provider.value] = provider - elif isinstance(provider, common.Connectivity): - dimensions[provider.origin_axis.value] = provider.origin_axis - dimensions[provider.neighbor_axis.value] = provider.neighbor_axis + elif isinstance(provider, common.NeighborConnectivityType): + dimensions[provider.source_dim.value] = provider.source_dim + dimensions[provider.codomain.value] = provider.codomain return dimensions @@ -318,7 +320,7 @@ class ITIRTypeInference(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) - offset_provider: common.OffsetProvider + offset_provider_type: 
common.OffsetProviderType #: Mapping from a dimension name to the actual dimension instance. dimensions: dict[str, common.Dimension] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. @@ -329,7 +331,7 @@ def apply( cls, node: T, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, inplace: bool = False, allow_undeclared_symbols: bool = False, ) -> T: @@ -340,7 +342,7 @@ def apply( node: The :class:`itir.Node` to infer the types of. Keyword Arguments: - offset_provider: Offset provider dictionary. + offset_provider_type: Offset provider dictionary. inplace: Write types directly to the given ``node`` instead of returning a copy. allow_undeclared_symbols: Allow references to symbols that don't have a corresponding declaration. This is useful for testing or inference on partially inferred sub-nodes. @@ -403,9 +405,9 @@ def apply( ) instance = cls( - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, dimensions=( - _get_dimensions_from_offset_provider(offset_provider) + _get_dimensions_from_offset_provider(offset_provider_type) | _get_dimensions_from_types( node.pre_walk_values() .if_isinstance(itir.Node) @@ -540,7 +542,7 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc for input_ in inputs ] stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider=self.offset_provider + *stencil_args, offset_provider_type=self.offset_provider_type ) return it_ts.StencilClosureType( @@ -632,7 +634,7 @@ def visit_FunCall( fun = self.visit(node.fun, ctx=ctx) args = self.visit(node.args, ctx=ctx) - result = fun(*args, offset_provider=self.offset_provider) + result = fun(*args, offset_provider_type=self.offset_provider_type) if isinstance(result, ObservableTypeSynthesizer): assert not result.node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 
43c4465576..5be9ed7438 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -35,20 +35,20 @@ class TypeSynthesizer: - isinstance checks to determine if an object is actually (meant to be) a type synthesizer and not just any callable. - writing simple type synthesizers without cluttering the signature with the additional - offset_provider argument that is only needed by some. + offset_provider_type argument that is only needed by some. """ type_synthesizer: Callable[..., TypeOrTypeSynthesizer] def __post_init__(self): - if "offset_provider" not in inspect.signature(self.type_synthesizer).parameters: + if "offset_provider_type" not in inspect.signature(self.type_synthesizer).parameters: synthesizer = self.type_synthesizer - self.type_synthesizer = lambda *args, offset_provider: synthesizer(*args) + self.type_synthesizer = lambda *args, offset_provider_type: synthesizer(*args) def __call__( - self, *args: TypeOrTypeSynthesizer, offset_provider: common.OffsetProvider + self, *args: TypeOrTypeSynthesizer, offset_provider_type: common.OffsetProviderType ) -> TypeOrTypeSynthesizer: - return self.type_synthesizer(*args, offset_provider=offset_provider) + return self.type_synthesizer(*args, offset_provider_type=offset_provider_type) TypeOrTypeSynthesizer = Union[ts.TypeSpec, TypeSynthesizer] @@ -212,7 +212,7 @@ def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) - def lift(stencil: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def apply_lift( - *its: it_ts.IteratorType, offset_provider: common.OffsetProvider + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType ) -> it_ts.IteratorType: assert all(isinstance(it, it_ts.IteratorType) for it in its) stencil_args = [ @@ -224,7 +224,7 @@ def apply_lift( ) for it in its ] - stencil_return_type = stencil(*stencil_args, offset_provider=offset_provider) + stencil_return_type = 
stencil(*stencil_args, offset_provider_type=offset_provider_type) assert isinstance(stencil_return_type, ts.DataType) position_dims = its[0].position_dims if its else [] @@ -282,7 +282,7 @@ def as_fieldop( stencil: TypeSynthesizer, domain: Optional[it_ts.DomainType] = None, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result # type. In order to still allow some passes which don't need this information to run before the @@ -308,7 +308,7 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( @@ -328,8 +328,10 @@ def scan( assert isinstance(direction, ts.ScalarType) and direction.kind == ts.ScalarKind.BOOL @TypeSynthesizer - def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) -> ts.DataType: - result = scan_pass(init, *its, offset_provider=offset_provider) + def apply_scan( + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType + ) -> ts.DataType: + result = scan_pass(init, *its, offset_provider_type=offset_provider_type) assert isinstance(result, ts.DataType) return result @@ -340,12 +342,12 @@ def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider: common.OffsetProvider + *args: it_ts.ListType, offset_provider_type: common.OffsetProviderType ) -> it_ts.ListType: assert len(args) > 0 assert all(isinstance(arg, it_ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] - el_type = op(*arg_el_types, 
offset_provider=offset_provider) + el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) return it_ts.ListType(element_type=el_type) @@ -355,15 +357,17 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider): + def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType): assert all(isinstance(arg, it_ts.ListType) for arg in args) - return op(init, *(arg.element_type for arg in args), offset_provider=offset_provider) + return op( + init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type + ) return applied_reduce @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider_type: common.OffsetProviderType) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, @@ -379,19 +383,19 @@ def apply_shift( assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( offset_axis.value, common.Dimension ) - provider = offset_provider[offset_axis.value.value] # TODO: naming - if isinstance(provider, common.Dimension): + type_ = offset_provider_type[offset_axis.value.value] + if isinstance(type_, common.Dimension): pass - elif isinstance(provider, common.Connectivity): + elif isinstance(type_, common.NeighborConnectivityType): found = False for i, dim in enumerate(new_position_dims): - if dim.value == provider.origin_axis.value: + if dim.value == type_.source_dim.value: assert not found - new_position_dims[i] = provider.neighbor_axis + new_position_dims[i] = type_.codomain found = True assert found else: - raise NotImplementedError() + raise NotImplementedError(f"{type_} is not a supported Connectivity type.") return it_ts.IteratorType( 
position_dims=new_position_dims, defined_dims=it.defined_dims, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 802ad2155f..69d8985beb 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -26,7 +26,6 @@ import typing from typing import Any, Iterable, Iterator, Optional -import numpy as np from typing_extensions import Self from gt4py.next import common @@ -49,47 +48,19 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: return cls(args=args, kwargs=kwargs) -@dataclasses.dataclass(frozen=True) -class CompileTimeConnectivity(common.Connectivity): - """Compile-time standin for a GTX connectivity, retaining everything except the connectivity tables.""" - - max_neighbors: int - has_skip_values: bool - origin_axis: common.Dimension - neighbor_axis: common.Dimension - index_type: type[int] | type[np.int32] | type[np.int64] - - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - raise NotImplementedError( - "A CompileTimeConnectivity instance should not call `mapped_index`." - ) - - @classmethod - def from_connectivity(cls, connectivity: common.Connectivity) -> Self: - return cls( - max_neighbors=connectivity.max_neighbors, - has_skip_values=connectivity.has_skip_values, - origin_axis=connectivity.origin_axis, - neighbor_axis=connectivity.neighbor_axis, - index_type=connectivity.index_type, - ) - - @property - def table(self) -> None: - return None - - @dataclasses.dataclass(frozen=True) class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" args: tuple[ts.TypeSpec, ...] 
kwargs: dict[str, ts.TypeSpec] - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] + @property + def offset_provider_type(self) -> common.OffsetProviderType: + return common.offset_provider_to_type(self.offset_provider) + @classmethod def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" @@ -98,8 +69,7 @@ def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: offset_provider = kwargs_copy.pop("offset_provider", {}) return cls( args=compile_args, - offset_provider=offset_provider, # TODO(ricoh): replace with the line below once the temporaries pass is AOT-ready. If unsure, just try it and run the tests. - # offset_provider={k: connectivity_or_dimension(v) for k, v in offset_provider.items()}, # noqa: ERA001 [commented-out-code] + offset_provider=offset_provider, column_axis=kwargs_copy.pop("column_axis", None), kwargs={ k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None @@ -138,18 +108,6 @@ def adapted_jit_to_aot_args_factory() -> ( return toolchain.ArgsOnlyAdapter(jit_to_aot_args) -def connectivity_or_dimension( - some_offset_provider: common.Connectivity | common.Dimension, -) -> CompileTimeConnectivity | common.Dimension: - match some_offset_provider: - case common.Dimension(): - return some_offset_provider - case common.Connectivity(): - return CompileTimeConnectivity.from_connectivity(some_offset_provider) - case _: - raise ValueError - - def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: for element in tuple_arg: match element: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py 
index cc57c137bf..b2aea05641 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -12,7 +12,6 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, concepts from gt4py.eve.utils import UIDGenerator -from gt4py.next import common from gt4py.next.program_processors.codegens.gtfn import gtfn_ir, gtfn_ir_common from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ( AssignStmt, @@ -84,54 +83,9 @@ def _is_reduce(node: gtfn_ir.FunCall) -> TypeGuard[gtfn_ir.FunCall]: ) -def _get_connectivity( - applied_reduce_node: gtfn_ir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: - """Return single connectivity that is compatible with the arguments of the reduce.""" - if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - - connectivities: list[common.Connectivity] = [] - for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) - connectivities.append(conn) - - if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") - - if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: - # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. 
- raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") - return connectivities[0] - - # TODO: end of code clone -def _make_dense_acess( - shift_call: gtfn_ir.FunCall, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="deref"), - args=[ - gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="shift"), args=[*shift_call.args, nbh_iter] - ) - ], - ) - - -def _make_sparse_acess( - field_ref: gtfn_ir_common.SymRef, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="tuple_get"), - args=[nbh_iter, gtfn_ir.FunCall(fun=gtfn_ir_common.SymRef(id="deref"), args=[field_ref])], - ) - - class PlugInCurrentIdx(NodeTranslator): def visit_SymRef( self, node: gtfn_ir_common.SymRef @@ -225,32 +179,6 @@ def _expand_symref( ) self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs)) - def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.SymRef: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - - connectivity = _get_connectivity(node, offset_provider) - - args = node.args - # do the following transformations to the node arguments - # dense fields: shift(dense_f, X2Y) -> deref(shift(dense_f, X2Y, nbh_iterator) - # sparse_fields: sparse_f -> tuple_get(nbh_iterator, deref(sparse_f))) - new_args = [] - nbh_iter = gtfn_ir_common.SymRef(id="nbh_iter") - for arg in args: - if isinstance(arg, gtfn_ir.FunCall) and arg.fun.id == "shift": # type: ignore - new_args.append(_make_dense_acess(arg, nbh_iter)) - if isinstance(arg, gtfn_ir_common.SymRef): - new_args.append(_make_sparse_acess(arg, nbh_iter)) - - red_idx = self.uids.sequential_id(prefix="red") - if isinstance(node.fun.args[0], gtfn_ir.Lambda): # type: ignore - self._expand_lambda(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - elif isinstance(node.fun.args[0], gtfn_ir_common.SymRef): # type: 
ignore - self._expand_symref(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - - return gtfn_ir_common.SymRef(id=red_idx) - def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr: if any(isinstance(arg, gtfn_ir.Lambda) for arg in node.args): # do not try to lower constructs that take lambdas as argument to something more readable @@ -278,7 +206,9 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common. self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr)) return gtfn_ir_common.SymRef(id=f"{lam_idx}") if _is_reduce(node): - return self.handle_Reduction(node, **kwargs) + raise AssertionError( + "Not implemented. The code-path was removed as it was not actively used and tested." + ) if isinstance(node.fun, gtfn_ir_common.SymRef) and node.fun.id == "make_tuple": tupl_id = self.uids.sequential_id(prefix="tupl") tuple_fun = self.commit_args(node, tupl_id, "make_tuple", **kwargs) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ce459f7970..f1649112a7 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -82,7 +82,7 @@ def _process_regular_arguments( self, program: itir.FencilDefinition | itir.Program, arg_types: tuple[ts.TypeSpec, ...], - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] @@ -104,22 +104,22 @@ def _process_regular_arguments( ): # translate sparse dimensions to tuple dtype dim_name = dim.value - connectivity = offset_provider[dim_name] - assert isinstance(connectivity, common.Connectivity) + connectivity = offset_provider_type[dim_name] + assert isinstance(connectivity, common.NeighborConnectivityType) size = 
connectivity.max_neighbors arg = f"gridtools::sid::dimension_to_tuple_like({arg})" arg_exprs.append(arg) return parameters, arg_exprs def _process_connectivity_args( - self, offset_provider: dict[str, common.Connectivity | common.Dimension] + self, offset_provider_type: common.OffsetProviderType ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] - for name, connectivity in offset_provider.items(): - if isinstance(connectivity, common.Connectivity): - if connectivity.index_type not in [np.int32, np.int64]: + for name, connectivity_type in offset_provider_type.items(): + if isinstance(connectivity_type, common.NeighborConnectivityType): + if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]: raise ValueError( "Neighbor table indices must be of type 'np.int32' or 'np.int64'." ) @@ -129,15 +129,8 @@ def _process_connectivity_args( interface.Parameter( name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(), type_=ts.FieldType( - dims=[ - connectivity.origin_axis, - common.Dimension( - name, kind=common.DimensionKind.LOCAL - ), # TODO(havogt): we should not use the name of the offset as the name of the local dimension - ], - dtype=ts.ScalarType( - type_translation.get_scalar_kind(connectivity.index_type) - ), + dims=list(connectivity_type.domain), + dtype=type_translation.from_dtype(connectivity_type.dtype), ), ) ) @@ -145,19 +138,19 @@ def _process_connectivity_args( # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" - f"generated::{connectivity.origin_axis.value}_t, " - f"generated::{name}_t, {connectivity.max_neighbors}" + f"generated::{connectivity_type.source_dim.value}_t, " + f"generated::{name}_t, {connectivity_type.max_neighbors}" f">(std::forward({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))" ) arg_exprs.append( f"gridtools::hymap::keys::make_values({nbtbl})" ) - elif isinstance(connectivity, common.Dimension): + elif 
isinstance(connectivity_type, common.Dimension): pass else: raise AssertionError( - f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " - f"got '{type(connectivity).__name__}'." + f"Expected offset provider type '{name}' to be a 'NeighborConnectivityType' or 'Dimension', " + f"got '{type(connectivity_type).__name__}'." ) return parameters, arg_exprs @@ -165,7 +158,7 @@ def _process_connectivity_args( def _preprocess_program( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, @@ -194,7 +187,7 @@ def _preprocess_program( def generate_stencil_source( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: if self.enable_itir_transforms: @@ -204,7 +197,9 @@ def generate_stencil_source( new_program = program gtfn_ir = GTFN_lowering.apply( - new_program, offset_provider=offset_provider, column_axis=column_axis + new_program, + offset_provider_type=common.offset_provider_to_type(offset_provider), + column_axis=column_axis, ) if self.use_imperative_backend: @@ -224,13 +219,13 @@ def __call__( # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) regular_parameters, regular_args_expr = self._process_regular_arguments( - program, inp.args.args, inp.args.offset_provider + program, inp.args.args, inp.args.offset_provider_type ) # handle connectivity parameters and arguments (i.e. 
what the user provided in the offset # provider) connectivity_parameters, connectivity_args_expr = self._process_connectivity_args( - inp.args.offset_provider + inp.args.offset_provider_type ) # combine into a format that is aligned with what the backend expects diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index bc2bd645e8..129d81d6f9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -159,7 +159,7 @@ def _collect_dimensions_from_domain( def _collect_offset_definitions( node: itir.Node, grid_type: common.GridType, - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: used_offset_tags: set[itir.OffsetLiteral] = ( node.walk_values() @@ -167,13 +167,13 @@ def _collect_offset_definitions( .filter(lambda offset_literal: isinstance(offset_literal.value, str)) .getattr("value") ).to_set() - if not used_offset_tags.issubset(set(offset_provider.keys())): + if not used_offset_tags.issubset(set(offset_provider_type.keys())): raise AssertionError("ITIR contains an offset tag without a corresponding offset provider.") offset_definitions = {} - for offset_name, dim_or_connectivity in offset_provider.items(): - if isinstance(dim_or_connectivity, common.Dimension): - dim: common.Dimension = dim_or_connectivity + for offset_name, dim_or_connectivity_type in offset_provider_type.items(): + if isinstance(dim_or_connectivity_type, common.Dimension): + dim: common.Dimension = dim_or_connectivity_type if grid_type == common.GridType.CARTESIAN: # create alias from offset to dimension offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) @@ -201,12 +201,13 @@ def _collect_offset_definitions( offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), 
alias=SymRef(id=dim.value) ) - elif isinstance(dim_or_connectivity, common.Connectivity): + elif isinstance( + connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType + ): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) - connectivity: common.Connectivity = dim_or_connectivity - for dim in [connectivity.origin_axis, connectivity.neighbor_axis]: + for dim in [connectivity_type.source_dim, connectivity_type.codomain]: if dim.kind != common.DimensionKind.HORIZONTAL: raise NotImplementedError() offset_definitions[dim.value] = TagDefinition( @@ -323,7 +324,7 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): } _unary_op_map: ClassVar[dict[str, str]] = {"not_": "!"} - offset_provider: dict + offset_provider_type: common.OffsetProviderType column_axis: Optional[common.Dimension] grid_type: common.GridType @@ -338,18 +339,18 @@ def apply( cls, node: itir.Program, *, - offset_provider: dict, + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> Program: if not isinstance(node, itir.Program): raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) grid_type = _get_gridtype(node.body) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( - offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type + offset_provider_type=offset_provider_type, column_axis=column_axis, grid_type=grid_type ).visit(node) def visit_Sym(self, node: itir.Sym, **kwargs: Any) -> Sym: @@ -484,8 +485,8 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: if "stencil" in kwargs: shift_offsets = self._collect_offset_or_axis_node(itir.OffsetLiteral, kwargs["stencil"]) 
for o in shift_offsets: - if o in self.offset_provider and isinstance( - self.offset_provider[o], common.Connectivity + if o in self.offset_provider_type and isinstance( + self.offset_provider_type[o], common.NeighborConnectivityType ): connectivities.append(SymRef(id=o)) return UnstructuredDomain( @@ -679,7 +680,7 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: function_definitions = self.visit(node.function_definitions) + extracted_functions offset_definitions = { **_collect_dimensions_from_domain(node.body), - **_collect_offset_definitions(node, self.grid_type, self.offset_provider), + **_collect_offset_definitions(node, self.grid_type, self.offset_provider_type), } return Program( id=SymbolName(node.id), diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index db0df7d121..56ba08015b 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -12,6 +12,7 @@ import dace import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common, utils as gtx_utils from . 
import utility as dace_utils @@ -65,8 +66,8 @@ def _get_args( def _ensure_is_on_device( - connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType -) -> np.typing.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: dace.dtypes.DeviceType +) -> core_defs.NDArrayObject: if device == dace.dtypes.DeviceType.GPU: if not isinstance(connectivity_arg, cp.ndarray): warnings.warn( @@ -78,7 +79,7 @@ def _ensure_is_on_device( def _get_shape_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: shape_args: dict[str, int] = {} for name, value in args.items(): @@ -103,7 +104,7 @@ def _get_shape_args( def _get_stride_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: stride_args = {} for name, value in args.items(): @@ -134,7 +135,7 @@ def get_sdfg_conn_args( sdfg: dace.SDFG, offset_provider: gtx_common.OffsetProvider, on_gpu: bool, -) -> dict[str, np.typing.NDArray]: +) -> dict[str, core_defs.NDArrayObject]: """ Extracts the connectivity tables that are used in the sdfg and ensures that the memory buffers are allocated for the target device. 
@@ -142,11 +143,11 @@ def get_sdfg_conn_args( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU connectivity_args = {} - for offset, connectivity in dace_utils.filter_connectivities(offset_provider).items(): - assert isinstance(connectivity, gtx_common.NeighborTable) - param = dace_utils.connectivity_identifier(offset) - if param in sdfg.arrays: - connectivity_args[param] = _ensure_is_on_device(connectivity.table, device) + for offset, connectivity in offset_provider.items(): + if gtx_common.is_neighbor_table(connectivity): + param = dace_utils.connectivity_identifier(offset) + if param in sdfg.arrays: + connectivity_args[param] = _ensure_is_on_device(connectivity.ndarray, device) return connectivity_args diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index bc01e2abda..29395a30c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -79,19 +79,18 @@ def debug_info( return default -def filter_connectivities( - offset_provider: gtx_common.OffsetProvider, -) -> dict[str, gtx_common.Connectivity]: +def filter_connectivity_types( + offset_provider_type: gtx_common.OffsetProviderType, +) -> dict[str, gtx_common.NeighborConnectivityType]: """ - Filter offset providers of type `Connectivity`. + Filter offset provider types of type `NeighborConnectivityType`. In other words, filter out the cartesian offset providers. - Returns a new dictionary containing only `Connectivity` values. 
""" return { - offset: table - for offset, table in offset_provider.items() - if isinstance(table, gtx_common.Connectivity) + offset: conn + for offset, conn in offset_provider_type.items() + if isinstance(conn, gtx_common.NeighborConnectivityType) } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 73b6e2ed4c..74142dec66 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -527,14 +527,14 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.args[0], gtir.OffsetLiteral) offset = node.args[0].value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider = self.subgraph_builder.get_offset_provider_type(offset) + assert isinstance(offset_provider, gtx_common.NeighborConnectivityType) it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.neighbor_axis in it.dimensions - assert offset_provider.origin_axis in it.indices - origin_index = it.indices[offset_provider.origin_axis] + assert offset_provider.codomain in it.dimensions + assert offset_provider.source_dim in it.indices + origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) @@ -561,7 +561,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: subset=sbs.Range.from_string( ",".join( it.indices[dim].value # type: ignore[union-attr] - if dim != offset_provider.neighbor_axis + if dim != offset_provider.codomain else f"0:{size}" for dim, size in zip(it.dimensions, field_desc.shape, strict=True) ) @@ -657,7 +657,9 @@ def _visit_map(self, node: gtir.FunCall) -> 
ValueExpr: tasklet_expression = f"{output_connector} = {fun_python_code}" input_args = [self.visit(arg) for arg in node.args] - input_connectivities: dict[gtx_common.Dimension, gtx_common.Connectivity] = {} + input_connectivity_types: dict[ + gtx_common.Dimension, gtx_common.NeighborConnectivityType + ] = {} for input_arg in input_args: assert isinstance(input_arg.gt_dtype, itir_ts.ListType) assert input_arg.gt_dtype.offset_type is not None @@ -665,11 +667,11 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: if offset_type == _CONST_DIM: # this input argument is the result of `make_const_list` continue - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) - input_connectivities[offset_type] = offset_provider + offset_provider_t = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_t, gtx_common.NeighborConnectivityType) + input_connectivity_types[offset_type] = offset_provider_t - if len(input_connectivities) == 0: + if len(input_connectivity_types) == 0: raise ValueError(f"Missing information on local dimension for map node {node}.") # GT4Py guarantees that all connectivities used to generate lists of neighbors @@ -678,14 +680,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: len( set( (conn.has_skip_values, conn.max_neighbors) - for conn in input_connectivities.values() + for conn in input_connectivity_types.values() ) ) != 1 ): raise ValueError("Unexpected arguments to map expression with different neighborhood.") - offset_type, offset_provider = next(iter(input_connectivities.items())) - local_size = offset_provider.max_neighbors + offset_type, offset_provider_type = next(iter(input_connectivity_types.items())) + local_size = offset_provider_type.max_neighbors map_index = dace_gtir_utils.get_map_variable(offset_type) # The dataflow we build in this class has some loose connections on input edges. 
@@ -717,14 +719,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) result_node = self.state.add_access(result) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_desc = self.sdfg.arrays[connectivity] connectivity_desc.transient = False - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) connectivity_slice = self._construct_local_view( MemletExpr( @@ -733,7 +735,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: element_type=node.type.element_type, offset_type=offset_type ), subset=sbs.Range.from_string( - f"{origin_map_index}, 0:{offset_provider.max_neighbors}" + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) ) @@ -774,7 +776,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: def _make_reduce_with_skip_values( self, input_expr: ValueExpr | MemletExpr, - offset_provider: gtx_common.Connectivity, + offset_provider_type: gtx_common.NeighborConnectivityType, reduce_init: SymbolExpr, reduce_identity: SymbolExpr, reduce_wcr: str, @@ -792,7 +794,7 @@ def _make_reduce_with_skip_values( corresponding neighbor index in the connectivity table is valid, or the identity value if the neighbor index is missing. """ - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) assert ( isinstance(input_expr.gt_dtype, itir_ts.ListType) @@ -815,7 +817,7 @@ def _make_reduce_with_skip_values( f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." 
) local_dim_index = local_dim_indices[0] - assert desc.shape[local_dim_index] == offset_provider.max_neighbors + assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors # we lower the reduction map with WCR out memlet in a nested SDFG nsdfg = dace.SDFG(name=self.unique_nsdfg_name("reduce_with_skip_values")) @@ -853,7 +855,7 @@ def _make_reduce_with_skip_values( # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. st_reduce.add_mapped_tasklet( name="reduce_with_skip_values", - map_ranges={"i": f"0:{offset_provider.max_neighbors}"}, + map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"}, inputs={ "__val": dace.Memlet(data="values", subset="i"), "__neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), @@ -882,7 +884,7 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), nsdfg_node, "neighbor_indices", ) @@ -910,12 +912,17 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: self._make_reduce_with_skip_values( - input_expr, offset_provider, reduce_init, reduce_identity, reduce_wcr, result_node + input_expr, + offset_provider_type, + reduce_init, + reduce_identity, + reduce_wcr, + result_node, ) else: @@ -1082,16 +1089,16 @@ def _make_dynamic_neighbor_offset( def _make_unstructured_shift( self, it: 
IteratorExpr, - connectivity: gtx_common.Connectivity, + connectivity: gtx_common.NeighborConnectivityType, offset_table_node: dace.nodes.AccessNode, offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.neighbor_axis in it.dimensions - neighbor_dim = connectivity.neighbor_axis + assert connectivity.codomain in it.dimensions + neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices - origin_dim = connectivity.origin_axis + origin_dim = connectivity.source_dim assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) @@ -1132,7 +1139,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: assert isinstance(offset_provider_arg, gtir.OffsetLiteral) offset = offset_provider_arg.value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset) # second argument should be the offset value, which could be a symbolic expression or a dynamic offset offset_expr = ( SymbolExpr(offset_value_arg.value, IndexDType) @@ -1140,8 +1147,8 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: else self.visit(offset_value_arg) ) - if isinstance(offset_provider, gtx_common.Dimension): - return self._make_cartesian_shift(it, offset_provider, offset_expr) + if isinstance(offset_provider_type, gtx_common.Dimension): + return self._make_cartesian_shift(it, offset_provider_type, offset_expr) else: # initially, the storage for the connectivity tables is created as transient; # when the tables are used, the storage is changed to non-transient, @@ -1151,7 +1158,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: offset_table_node = self.state.add_access(offset_table) return self._make_unstructured_shift( - it, offset_provider, offset_table_node, offset_expr + it, offset_provider_type, 
offset_table_node, offset_expr ) def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index ad8f490f12..52284edfac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -41,7 +41,7 @@ class DataflowBuilder(Protocol): """Visitor interface to build a dataflow subgraph.""" @abc.abstractmethod - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: ... + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: ... @abc.abstractmethod def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: ... @@ -155,7 +155,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): from where to continue building the SDFG. """ - offset_provider: gtx_common.OffsetProvider + offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") @@ -164,8 +164,8 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") ) - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: - return self.offset_provider[offset] + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: + return self.offset_provider_type[offset] def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -195,10 +195,10 @@ def _make_array_shape_and_strides( Two lists of symbols, one for the shape and the other for the strides of the array. 
""" dc_dtype = gtir_builtin_translators.INDEX_DTYPE - neighbor_tables = dace_utils.filter_connectivities(self.offset_provider) + neighbor_table_types = dace_utils.filter_connectivity_types(self.offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + neighbor_table_types[dim.value].max_neighbors if dim.kind == gtx_common.DimensionKind.LOCAL else dace.symbol(dace_utils.field_size_symbol_name(name, i), dc_dtype) ) @@ -374,13 +374,12 @@ def _add_sdfg_params( self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables - for offset, offset_provider in dace_utils.filter_connectivities( - self.offset_provider + for offset, connectivity_type in dace_utils.filter_connectivity_types( + self.offset_provider_type ).items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) + scalar_type = tt.from_dtype(connectivity_type.dtype) gt_type = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) # We store all connectivity tables as transient arrays here; later, while building # the field operator expressions, we change to non-transient (i.e. 
allocated externally) @@ -585,7 +584,7 @@ def visit_Lambda( } # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider, lambda_symbols) + lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -630,7 +629,7 @@ def _flatten_tuples( ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) - for offset in dace_utils.filter_connectivities(self.offset_provider) + for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) } input_memlets = {} @@ -778,7 +777,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, - offset_provider: gtx_common.OffsetProvider, + offset_provider_type: gtx_common.OffsetProviderType, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -788,15 +787,15 @@ def build_sdfg_from_gtir( Args: ir: The GTIR program node to be lowered to SDFG - offset_provider: The definitions of offset providers used by the program node + offset_provider_type: The definitions of offset providers used by the program node Returns: An SDFG in the DaCe canonical form (simplified) """ - ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) + ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider) + sdfg_genenerator = GTIRToSDFG(offset_provider_type) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index aa4fd0cd3e..40d44f5ab0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -52,7 +52,9 @@ def generate_sdfg( on_gpu: bool, ) -> dace.SDFG: ir 
= itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) - sdfg = gtir_sdfg.build_sdfg_from_gtir(ir, offset_provider=offset_provider) + sdfg = gtir_sdfg.build_sdfg_from_gtir( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) @@ -75,7 +77,7 @@ def __call__( sdfg = self.generate_sdfg( program, - inp.args.offset_provider, + inp.args.offset_provider, # TODO(havogt): should be offset_provider_type once the transformation don't require run-time info inp.args.column_axis, auto_opt=self.auto_optimize, on_gpu=(self.device_type == gtx_allocators.CUPY_DEVICE), diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fc2772027e..ef09cf51cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -9,7 +9,7 @@ import dataclasses import warnings from collections import OrderedDict -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from dataclasses import field from inspect import currentframe, getframeinfo from pathlib import Path @@ -38,7 +38,7 @@ def preprocess_program( program: itir.FencilDefinition, - offset_provider: Mapping[str, Any], + offset_provider_type: common.OffsetProviderType, lift_mode: legacy_itir_transforms.LiftMode, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ @@ -51,13 +51,13 @@ def preprocess_program( common_subexpression_elimination=False, force_inline_lambda_args=True, lift_mode=lift_mode, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, symbolic_domain_sizes=symbolic_domain_sizes, temporary_extraction_heuristics=temporary_extraction_heuristics, unroll_reduce=unroll_reduce, ) - node = 
itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) if isinstance(node, itir.Program): fencil_definition = program_to_fencil.program_to_fencil(node) @@ -72,7 +72,7 @@ def preprocess_program( def build_sdfg_from_itir( program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, auto_optimize: bool = False, on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, @@ -109,10 +109,18 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program, tmps = preprocess_program( - program, offset_provider, lift_mode, symbolic_domain_sizes, temporary_extraction_heuristics + program, + offset_provider_type, + lift_mode, + symbolic_domain_sizes, + temporary_extraction_heuristics, ) sdfg_genenerator = ItirToSDFG( - list(arg_types), offset_provider, tmps, use_field_canonical_representation, column_axis + list(arg_types), + offset_provider_type, + tmps, + use_field_canonical_representation, + column_axis, ) sdfg = sdfg_genenerator.visit(program) if sdfg is None: @@ -186,14 +194,12 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: raise ValueError( "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." 
) - offset_provider = ( - self.connectivities | self._implicit_offset_provider - ) # tables are None at this point + offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] self.itir, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, column_axis=kwargs.get("column_axis", None), ) self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ @@ -238,7 +244,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: sdfg.offset_providers_per_input_field = {} itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider=offset_provider + self.itir, offset_provider_type=offset_provider_type ) itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) for closure in itir_tmp_fencil.closures: @@ -267,7 +273,7 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ the offset providers are not part of GT4Py Program's arguments. Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. 
""" - offset_provider = self.connectivities + offset_provider_type = self.connectivities # Define DaCe symbols connectivity_table_size_symbols = { @@ -276,9 +282,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -288,9 +294,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -298,8 +304,8 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Define the storage location (e.g. 
CPU, GPU) of the connectivity tables if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider.items(): # type: ignore[union-attr] - if not hasattr(v, "table"): + for k, v in offset_provider_type.items(): # type: ignore[union-attr] + if not isinstance(v, common.NeighborConnectivityType): continue if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: Program.connectivity_tables_data_descriptors["storage"] = ( @@ -311,12 +317,15 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Build the closure dictionary closure_dict = {} - for k, v in offset_provider.items(): # type: ignore[union-attr] + for k, v in offset_provider_type.items(): # type: ignore[union-attr] conn_id = dace_utils.connectivity_identifier(k) - if hasattr(v, "table") and conn_id in self.sdfg_closure_vars["sdfg.arrays"]: + if ( + isinstance(v, common.NeighborConnectivityType) + and conn_id in self.sdfg_closure_vars["sdfg.arrays"] + ): if conn_id not in Program.connectivity_tables_data_descriptors: Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.index_type == np.int64 else dace.int32, + dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, shape=[ symbols[dace_utils.field_size_symbol_name(conn_id, 0)], symbols[dace_utils.field_size_symbol_name(conn_id, 1)], diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a0f4b83d35..823943cfd5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -7,14 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings -from typing import Any, Mapping, Optional, Sequence, cast +from typing import Optional, Sequence, cast import dace from dace.sdfg.state import LoopRegion import gt4py.eve 
as eve -from gt4py.next import Dimension, DimensionKind -from gt4py.next.common import Connectivity +from gt4py.next import Dimension, DimensionKind, common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef @@ -91,7 +90,10 @@ def _get_scan_dim( def _make_array_shape_and_strides( - name: str, dims: Sequence[Dimension], offset_provider: Mapping[str, Any], sort_dims: bool + name: str, + dims: Sequence[Dimension], + offset_provider_type: common.OffsetProviderType, + sort_dims: bool, ) -> tuple[list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -106,10 +108,10 @@ def _make_array_shape_and_strides( """ dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - neighbor_tables = dace_utils.filter_connectivities(offset_provider) + connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + connectivity_types[dim.value].max_neighbors if dim.kind == DimensionKind.LOCAL # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) @@ -144,21 +146,21 @@ class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] column_axis: Optional[Dimension] - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType unique_id: int use_field_canonical_representation: bool def __init__( self, param_types: list[ts.TypeSpec], - offset_provider: dict[str, Connectivity | Dimension], + offset_provider_type: common.OffsetProviderType, tmps: list[itir.Temporary], use_field_canonical_representation: bool, column_axis: Optional[Dimension] = None, ): self.param_types = 
param_types self.column_axis = column_axis - self.offset_provider = offset_provider + self.offset_provider_type = offset_provider_type self.storage_types = {} self.tmps = tmps self.use_field_canonical_representation = use_field_canonical_representation @@ -166,7 +168,7 @@ def __init__( def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): if isinstance(type_, ts.FieldType): shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider, sort_dimensions + name, type_.dims, self.offset_provider_type, sort_dimensions ) dtype = dace_utils.as_dace_type(type_.dtype) sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -255,7 +257,7 @@ def get_output_nodes( # Visit output node again to generate the corresponding tasklet context = Context(sdfg, state, output_symbols_pass.symbol_refs) translator = PythonTaskletCodegen( - self.offset_provider, context, self.use_field_canonical_representation + self.offset_provider_type, context, self.use_field_canonical_representation ) output_nodes = flatten_list(translator.visit(closure.output)) return {node.value.data: node.value for node in output_nodes} @@ -266,7 +268,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): entry_state = program_sdfg.add_state("program_entry", is_start_block=True) # Filter neighbor tables from offset providers. - neighbor_tables = get_used_connectivities(node, self.offset_provider) + connectivity_types = get_used_connectivities(node, self.offset_provider_type) # Add program parameters as SDFG storages. for param, type_ in zip(node.params, self.param_types): @@ -285,11 +287,10 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): last_state = entry_state # Add connectivities as SDFG storages. 
- for offset, offset_provider in neighbor_tables.items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = Dimension(offset, kind=DimensionKind.LOCAL) + for offset, connectivity_type in connectivity_types.items(): + scalar_type = tt.from_dtype(connectivity_type.dtype) type_ = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) self.add_storage( program_sdfg, @@ -362,7 +363,7 @@ def visit_StencilClosure( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -568,7 +569,7 @@ def _visit_scan_stencil_closure( ) assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -673,7 +674,7 @@ def _visit_scan_stencil_closure( connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] lambda_context, lambda_outputs = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, lambda_domain, input_arrays, connectivity_arrays, @@ -738,7 +739,7 @@ def _visit_parallel_stencil_closure( tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... 
], ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -762,7 +763,7 @@ def _visit_parallel_stencil_closure( context, results = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, index_domain, input_arrays, connectivity_arrays, @@ -788,7 +789,7 @@ def _visit_domain( lower_bound = named_range.args[1] upper_bound = named_range.args[2] translator = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, context, self.use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 991053b4a5..2b2669187a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,8 +19,8 @@ import gt4py.eve.codegen from gt4py import eve -from gt4py.next import Dimension -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity +from gt4py.next import common +from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import FunCall, Lambda from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -187,15 +187,15 @@ def _visit_lift_in_neighbors_reduction( transformer: PythonTaskletCodegen, node: itir.FunCall, node_args: Sequence[IteratorExpr | list[ValueExpr]], - offset_provider: Connectivity, + connectivity_type: common.NeighborConnectivityType, map_entry: dace.nodes.MapEntry, map_exit: dace.nodes.MapExit, neighbor_index_node: dace.nodes.AccessNode, 
neighbor_value_node: dace.nodes.AccessNode, ) -> list[ValueExpr]: assert transformer.context.reduce_identity is not None - neighbor_dim = offset_provider.neighbor_axis.value - origin_dim = offset_provider.origin_axis.value + neighbor_dim = connectivity_type.codomain.value + origin_dim = connectivity_type.source_dim.value lifted_args: list[IteratorExpr | ValueExpr] = [] for arg in node_args: @@ -232,7 +232,7 @@ def _visit_lift_in_neighbors_reduction( assert isinstance(y, ValueExpr) input_nodes[x] = y.value - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider) + neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -294,7 +294,7 @@ def _visit_lift_in_neighbors_reduction( memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), ) - if offset_provider.has_skip_values: + if connectivity_type.has_skip_values: # check neighbor validity on if/else inter-state edge # use one branch for connectivity case start_state = lift_context.body.add_state_before( @@ -333,8 +333,8 @@ def builtin_neighbors( assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value assert isinstance(offset_dim, str) - offset_provider = transformer.offset_provider[offset_dim] - if not isinstance(offset_provider, Connectivity): + connectivity_type = transformer.offset_provider_type[offset_dim] + if not isinstance(connectivity_type, common.NeighborConnectivityType): raise NotImplementedError( "Neighbor reduction only implemented for connectivity based on neighbor tables." 
) @@ -351,7 +351,7 @@ def builtin_neighbors( iterator = transformer.visit(data) assert isinstance(iterator, IteratorExpr) field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[offset_provider.origin_axis.value] + origin_index_node = iterator.indices[connectivity_type.source_dim.value] assert transformer.context.reduce_identity is not None assert transformer.context.reduce_identity.dtype == iterator.dtype @@ -361,7 +361,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_value_var, dtype=iterator.dtype, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) @@ -375,7 +375,7 @@ def builtin_neighbors( neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") me, mx = state.add_map( f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{offset_provider.max_neighbors}"}, + ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, debuginfo=di, ) @@ -414,7 +414,7 @@ def builtin_neighbors( transformer, lift_node, lift_args, - offset_provider, + connectivity_type, me, mx, neighbor_index_node, @@ -423,13 +423,13 @@ def builtin_neighbors( else: sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{offset_provider.neighbor_axis.value}_v" + connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" data_access_tasklet = state.add_tasklet( "data_access", code=f"__data = __field[{data_access_index}] " + ( f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values + if connectivity_type.has_skip_values else "" ), inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, @@ -445,7 +445,7 @@ def builtin_neighbors( ) for dim in iterator.dimensions: connector = f"{dim}_v" - if dim 
== offset_provider.neighbor_axis.value: + if dim == connectivity_type.codomain.value: state.add_edge( neighbor_index_node, None, @@ -470,7 +470,7 @@ def builtin_neighbors( src_conn="__data", ) - if not offset_provider.has_skip_values: + if not connectivity_type.has_skip_values: return [ValueExpr(neighbor_value_node, iterator.dtype)] else: """ @@ -483,7 +483,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_valid_var, dtype=dace.dtypes.bool, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) @@ -572,7 +572,7 @@ def build_if_state(arg, state): symbol_map = copy.deepcopy(transformer.context.symbol_map) node_context = Context(sdfg, state, symbol_map) node_taskgen = PythonTaskletCodegen( - transformer.offset_provider, + transformer.offset_provider_type, node_context, transformer.use_field_canonical_representation, ) @@ -884,21 +884,12 @@ def visit_SymRef(self, node: itir.SymRef): ) +@dataclasses.dataclass class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType context: Context use_field_canonical_representation: bool - def __init__( - self, - offset_provider: dict[str, Any], - context: Context, - use_field_canonical_representation: bool, - ): - self.offset_provider = offset_provider - self.context = context - self.use_field_canonical_representation = use_field_canonical_representation - def get_sorted_field_dimensions(self, dims: Sequence[str]): return sorted(dims) if self.use_field_canonical_representation else dims @@ -914,7 +905,7 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = ( - get_used_connectivities(node, self.offset_provider) if use_neighbor_tables else {} + get_used_connectivities(node, self.offset_provider_type) if use_neighbor_tables else {} ) connectivity_names = [ 
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() @@ -974,7 +965,7 @@ def visit_Lambda( reduce_identity=self.context.reduce_identity, ) lambda_taskgen = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, lambda_context, self.use_field_canonical_representation, ) @@ -1066,7 +1057,7 @@ def _visit_call(self, node: itir.FunCall): store, self.context.body.arrays[store] ) - neighbor_tables = get_used_connectivities(node.fun, self.offset_provider) + neighbor_tables = get_used_connectivities(node.fun, self.offset_provider_type) for offset in neighbor_tables.keys(): var = dace_utils.connectivity_identifier(offset) nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) @@ -1136,12 +1127,13 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] assert len(dims_not_indexed) == 1 offset = dims_not_indexed[0] - offset_provider = self.offset_provider[offset] - neighbor_dim = offset_provider.neighbor_axis.value + offset_provider_type = self.offset_provider_type[offset] + assert isinstance(offset_provider_type, common.NeighborConnectivityType) + neighbor_dim = offset_provider_type.codomain.value result_name = unique_var_name() self.context.body.add_array( - result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True + result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True ) result_array = self.context.body.arrays[result_name] result_node = self.context.state.add_access(result_name, debuginfo=di) @@ -1158,7 +1150,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # we create a mapped tasklet for array slicing index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider.max_neighbors}"} + map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} src_subset = ",".join( [f"_i_{dim}" if dim in 
iterator.indices else index_name for dim in sorted_dims] ) @@ -1212,27 +1204,30 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: offset_node = self.visit(tail[1])[0] assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - if isinstance(self.offset_provider[offset_dim], Connectivity): - offset_provider = self.offset_provider[offset_dim] + if isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): + offset_provider_type = cast( + common.NeighborConnectivityType, self.offset_provider_type[offset_dim] + ) # ensured by condition connectivity = self.context.state.add_access( dace_utils.connectivity_identifier(offset_dim), debuginfo=di ) - shifted_dim = offset_provider.origin_axis.value - target_dim = offset_provider.neighbor_axis.value + shifted_dim_tag = offset_provider_type.source_dim.value + target_dim_tag = offset_provider_type.codomain.value args = [ ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), + ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" else: - assert isinstance(self.offset_provider[offset_dim], Dimension) + shifted_dim = self.offset_provider_type[offset_dim] + assert isinstance(shifted_dim, common.Dimension) - shifted_dim = self.offset_provider[offset_dim].value - target_dim = shifted_dim - args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node] + shifted_dim_tag = shifted_dim.value + target_dim_tag = shifted_dim_tag + args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" @@ -1241,8 +1236,8 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} - 
del shifted_index[shifted_dim] - shifted_index[target_dim] = shifted_value + del shifted_index[shifted_dim_tag] + shifted_index[target_dim_tag] = shifted_value return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) @@ -1506,7 +1501,7 @@ def is_scan(node: itir.Node) -> bool: def closure_to_tasklet_sdfg( node: itir.StencilClosure, - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, domain: dict[str, str], inputs: Sequence[tuple[str, ts.TypeSpec]], connectivities: Sequence[tuple[dace.ndarray, str]], @@ -1547,7 +1542,9 @@ def closure_to_tasklet_sdfg( body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen(offset_provider, context, use_field_canonical_representation) + translator = PythonTaskletCodegen( + offset_provider_type, context, use_field_canonical_representation + ) args = [itir.SymRef(id=name) for name, _ in inputs] if is_scan(node.stencil): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index d367eb0883..72bb32f003 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -7,21 +7,21 @@ # SPDX-License-Identifier: BSD-3-Clause import itertools -from typing import Any, Mapping +from typing import Any import dace import gt4py.next.iterator.ir as itir from gt4py import eve -from gt4py.next.common import Connectivity +from gt4py.next import common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.program_processors.runners.dace_common import utility as dace_utils def get_used_connectivities( - node: itir.Node, offset_provider: Mapping[str, Any] -) -> dict[str, Connectivity]: - connectivities = dace_utils.filter_connectivities(offset_provider) + node: itir.Node, offset_provider_type: 
common.OffsetProviderType +) -> dict[str, common.NeighborConnectivityType]: + connectivities = dace_utils.filter_connectivity_types(offset_provider_type) offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 740f1979cd..653ed4719d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -52,7 +52,7 @@ def generate_sdfg( self, program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> dace.SDFG: on_gpu = ( @@ -64,7 +64,7 @@ def generate_sdfg( return build_sdfg_from_itir( program, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, @@ -87,7 +87,7 @@ def __call__( sdfg = self.generate_sdfg( program, inp.args.args, - inp.args.offset_provider, + common.offset_provider_to_type(inp.args.offset_provider), inp.args.column_axis, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 965c6417b2..1f3778f227 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -12,14 +12,12 @@ import diskcache import factory -import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.common import 
Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind @@ -63,8 +61,8 @@ def decorated_program( def _ensure_is_on_device( - connectivity_arg: npt.NDArray, device: core_defs.DeviceType -) -> npt.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType +) -> core_defs.NDArrayObject: if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp @@ -79,17 +77,17 @@ def _ensure_is_on_device( def extract_connectivity_args( offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[npt.NDArray, tuple[int, ...]]]: +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings - args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] + args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] for name, conn in offset_provider.items(): if isinstance(conn, common.Connectivity): - if not isinstance(conn, common.NeighborTable): + if not common.is_neighbor_table(conn): raise NotImplementedError( "Only 'NeighborTable' connectivities implemented at this point." ) # copying to device here is a fallback for easy testing and might be removed later - conn_arg = _ensure_is_on_device(conn.table, device) + conn_arg = _ensure_is_on_device(conn.ndarray, device) args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass @@ -125,7 +123,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: the program, sorted offset_provider, and column_axis. 
""" program: itir.FencilDefinition | itir.Program = inp.data - offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider + offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis program_hash = utils.content_hash( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 4d518d7fcc..1dd568b95a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -94,7 +94,7 @@ def fencil_generator( ir: itir.Program | itir.FencilDefinition, debug: bool, use_embedded: bool, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, transforms: itir_transforms.ITIRTransform, ) -> stages.CompiledProgram: """ @@ -111,7 +111,15 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, transforms, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash( + ( + ir, + transforms, + debug, + use_embedded, + tuple(common.offset_provider_to_type(offset_provider).items()), + ) + ) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") @@ -151,7 +159,9 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as source_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: source_file_name = source_file.name if debug: print(source_file_name) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0827d99cdc..fa8c9b9ab1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -63,6 +63,7 @@ class DimensionType(TypeSpec): 
@dataclass(frozen=True) class OffsetType(TypeSpec): + # TODO(havogt): replace by ConnectivityType source: func_common.Dimension target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 1da34db3c0..f5646c71e4 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,30 +6,32 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np -from typing import Optional from types import ModuleType +from typing import Optional + +import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend -from gt4py.next.otf import arguments +from gt4py.next import backend as next_backend, common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + E2V, + E2VDim, + Edge, + Vertex, exec_alloc_descriptor, mesh_descriptor, - Vertex, - Edge, - E2V, ) from next_tests.integration_tests.multi_feature_tests.ffront_tests.test_laplacian import ( lap_program, - laplap_program, lap_ref, + laplap_program, ) + try: import dace from gt4py.next.program_processors.runners.dace import ( @@ -57,25 +59,20 @@ def test_sdfgConvertible_laplap(cartesian_case): in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() - connectivities = {} # Dict of NeighborOffsetProviders, where self.table = None - for k, v in cartesian_case.offset_provider.items(): - if hasattr(v, "table"): - connectivities[k] = arguments.CompileTimeConnectivity( - 
v.max_neighbors, v.has_skip_values, v.origin_axis, v.neighbor_axis, v.table.dtype - ) - else: - connectivities[k] = v - # Test DaCe closure support @dace.program def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(in_field, tmp_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + in_field, tmp_field + ) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(tmp_field, out_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + tmp_field, out_field + ) sdfg() @@ -130,13 +127,13 @@ def sdfg( a, out, offset_provider=offset_provider ) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[0, 1], [1, 2], [2, 0]]), Edge, Vertex, 2, False - ) - connectivities = {} - connectivities["E2V"] = arguments.CompileTimeConnectivity( - e2v.max_neighbors, e2v.has_skip_values, e2v.origin_axis, e2v.neighbor_axis, e2v.table.dtype + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[0, 1], [1, 2], [2, 0]]), + allocator=allocator, ) + connectivities = {"E2V": e2v.__gt_type__()} offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) SDFG = sdfg.to_sdfg(connectivities=connectivities) @@ -144,6 +141,9 @@ def sdfg( a = gtx.as_field([Vertex], xp.asarray([0.0, 1.0, 2.0]), allocator=allocator) out = gtx.zeros({Edge: 3}, allocator=allocator) + e2v_ndarray_copy = ( + e2v.ndarray.copy() + ) # otherwise DaCe complains about the gt4py custom allocated view # This is a low level interface to call the compiled SDFG. # It is not supposed to be used in user code. 
# The high level interface should be provided by a DaCe Orchestrator, @@ -155,21 +155,21 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[1, 0], [2, 1], [0, 2]]), Edge, Vertex, 2, False + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[1, 0], [2, 1], [0, 2]]), + allocator=allocator, ) + e2v_ndarray_copy = e2v.ndarray.copy() offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) cSDFG( a, @@ -177,17 +177,13 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 
0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) def get_stride_from_numpy_to_dace(numpy_array: np.ndarray, axis: int) -> int: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index c64efb27d2..794dd06709 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -152,7 +152,10 @@ def num_edges(self) -> int: ... def num_levels(self) -> int: ... @property - def offset_provider(self) -> dict[str, common.Connectivity]: ... + def offset_provider(self) -> common.OffsetProvider: ... + + @property + def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_mesh() -> MeshDescriptor: @@ -211,25 +214,40 @@ def simple_mesh() -> MeshDescriptor: assert all(len(row) == 2 for row in e2v_arr) e2v_arr = np.asarray(e2v_arr, dtype=gtx.IndexType) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 4}, + skip_value=None, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 4}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 4}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="simple_mesh", num_vertices=num_vertices, num_edges=np.int32(num_edges), num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 4, has_skip_values=False - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: 
gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 4, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 4, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) @@ -287,25 +305,40 @@ def skip_value_mesh() -> MeshDescriptor: dtype=gtx.IndexType, ) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 5}, + skip_value=common._DEFAULT_SKIP_VALUE, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 3}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 3}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="skip_value_mesh", num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 5, has_skip_values=True - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 3, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 3, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a5453151e6..1a51e3667d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -89,7 +89,7 @@ def testee(a: cases.VField) -> cases.EField: cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a: a[unstructured_case.offset_provider["E2V"].table[:, 0]], + ref=lambda a: a[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -115,16 +115,16 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_flat, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_intermediate_result, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], comparison=lambda inp, tmp: np.all(inp == tmp), ) @@ -132,8 +132,8 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) @@ -583,11 +583,11 @@ def testee(a: cases.VField) -> cases.VField: unstructured_case, testee, ref=lambda a: np.sum( - np.sum(a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0)[ - unstructured_case.offset_provider["V2E"].table + np.sum(a[unstructured_case.offset_provider["E2V"].ndarray], axis=1, initial=0)[ + 
unstructured_case.offset_provider["V2E"].ndarray ], axis=1, - where=unstructured_case.offset_provider["V2E"].table != common._DEFAULT_SKIP_VALUE, + where=unstructured_case.offset_provider["V2E"].ndarray != common._DEFAULT_SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), ) @@ -606,8 +606,8 @@ def testee(inp: cases.EField) -> cases.EField: unstructured_case, testee, ref=lambda inp: np.sum( - np.sum(inp[unstructured_case.offset_provider["V2E"].table], axis=1)[ - unstructured_case.offset_provider["E2V"].table + np.sum(inp[unstructured_case.offset_provider["V2E"].ndarray], axis=1)[ + unstructured_case.offset_provider["E2V"].ndarray ], axis=1, ), @@ -627,8 +627,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField unstructured_case, testee, ref=lambda a, b: [ - np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1), - np.sum(b[unstructured_case.offset_provider["V2E"].table], axis=1), + np.sum(a[unstructured_case.offset_provider["V2E"].ndarray], axis=1), + np.sum(b[unstructured_case.offset_provider["V2E"].ndarray], axis=1), ], comparison=lambda a, tmp: (np.all(a[0] == tmp[0]), np.all(a[1] == tmp[1])), ) @@ -649,11 +649,11 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: unstructured_case, reduce_tuple_element, ref=lambda e, v: np.sum( - e[v2e.table] + np.tile(v, (v2e.max_neighbors, 1)).T, + e[v2e.ndarray] + np.tile(v, (v2e.shape[1], 1)).T, axis=1, initial=0, - where=v2e.table != common._DEFAULT_SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table[:, 0]], + where=v2e.ndarray != common._DEFAULT_SKIP_VALUE, + )[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -780,7 +780,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) return tmp - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( 
unstructured_case, testee, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 37f4ee2cd1..33832fb5f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -33,11 +33,11 @@ def testee( ) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify( unstructured_case, testee, @@ -57,7 +57,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 return neighbor_sum(inp, axis=V2EDim) inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) cases.verify( @@ -65,7 +65,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 testee, inp, out=cases.allocate(unstructured_case, testee, cases.RETURN)(), - ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), + ref=np.sum(unstructured_case.offset_provider["V2E"].ndarray, axis=1), ) @@ -76,7 +76,7 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: return inp(V2E) out = unstructured_case.as_field( - [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table) + [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].ndarray) ) 
inp = cases.allocate(unstructured_case, testee, "inp")() cases.verify( @@ -84,5 +84,5 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: testee, inp, out=out, - ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table], + ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].ndarray], ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 29966c30ad..7648d34db7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -52,7 +52,7 @@ def testee(edge_f: cases.EField) -> cases.VField: inp = cases.allocate(unstructured_case, testee, "edge_f", strategy=strategy)() out = cases.allocate(unstructured_case, testee, cases.RETURN)() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray ref = np.max( inp.asnumpy()[v2e_table], axis=1, @@ -69,7 +69,7 @@ def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) return out - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, minover, @@ -106,7 +106,7 @@ def reduction_ke_field( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray edge_f = cases.allocate(unstructured_case, fop, "edge_f")() @@ -157,7 +157,7 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + 
v2e_table = unstructured_case.offset_provider["V2E"].ndarray field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() @@ -190,7 +190,7 @@ def reduce_expr(edge_f: cases.EField) -> cases.VField: def fencil(edge_f: cases.EField, out: cases.VField): reduce_expr(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, fencil, @@ -210,7 +210,7 @@ def test_reduction_with_common_expression(unstructured_case): def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, @@ -226,7 +226,7 @@ def test_reduction_expression_with_where(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, inp(V2E), inp(V2E)), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -255,7 +255,7 @@ def test_reduction_expression_with_where_and_tuples(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, (inp(V2E), inp(V2E)), (inp(V2E), inp(V2E)))[1], axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -284,7 +284,7 @@ def test_reduction_expression_with_where_and_scalar(unstructured_case): 
def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(inp(V2E) + where(mask, inp(V2E), 1), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 11e28de9e1..66c56c4827 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -90,7 +90,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].ndarray[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 3fc4ed9945..5e3a2fcd14 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -248,11 +248,14 @@ def test_can_deref(program_processor, stencil): program_processor, validate = program_processor Node = gtx.Dimension("Node") + NeighDim = gtx.Dimension("Neighbor", kind=gtx.DimensionKind.LOCAL) inp = gtx.as_field([Node], np.ones((1,), dtype=np.int32)) out = gtx.as_field([Node], 
np.asarray([0], dtype=inp.dtype)) - no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[-1]]), Node, Node, 1) + no_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[-1]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -264,7 +267,9 @@ def test_can_deref(program_processor, stencil): if validate: assert np.allclose(out.asnumpy(), -1.0) - a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) + a_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[0]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -277,37 +282,6 @@ def test_can_deref(program_processor, stencil): assert np.allclose(out.asnumpy(), 1.0) -# def test_can_deref_lifted(program_processor): -# program_processor, validate = program_processor - -# Neighbor = offset("Neighbor") -# Node = gtx.Dimension("Node") - -# @fundef -# def _can_deref(inp): -# shifted = shift(Neighbor, 0)(inp) -# return if_(can_deref(shifted), 1, -1) - -# inp = gtx.as_field([Node], np.zeros((1,))) -# out = gtx.as_field([Node], np.asarray([0])) - -# no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[None]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": no_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), -1.0) - -# a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": a_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), 1.0) - - @pytest.mark.parametrize( "input_value, dtype, np_dtype", [ diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py 
b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 69786b323b..7bde55bfd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor +from gt4py.next.iterator.embedded import StridedConnectivityField LocA = gtx.Dimension("LocA") @@ -21,8 +22,10 @@ LocB = gtx.Dimension("LocB") # unused LocA2LocAB = offset("O") -LocA2LocAB_offset_provider = gtx.StridedNeighborOffsetProvider( - origin_axis=LocA, neighbor_axis=LocAB, max_neighbors=2, has_skip_values=False +LocA2LocAB_offset_provider = StridedConnectivityField( + domain_dims=(LocA, gtx.Dimension("Dummy", kind=gtx.DimensionKind.LOCAL)), + codomain_dim=LocAB, + max_neighbors=2, ) @@ -41,7 +44,7 @@ def test_strided_offset_provider(program_processor): program_processor, validate = program_processor LocA_size = 2 - max_neighbors = LocA2LocAB_offset_provider.max_neighbors + max_neighbors = LocA2LocAB_offset_provider.__gt_type__().max_neighbors LocAB_size = LocA_size * max_neighbors rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index eb59c77201..6c6ca7e4bc 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -11,7 +11,6 @@ import numpy as np import pytest - pytest.importorskip("atlas4py") from gt4py import next as gtx @@ -22,20 +21,17 @@ exec_alloc_descriptor, ) from 
next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( + E2V, + V2E, + E2VDim, + Edge, + V2EDim, + Vertex, assert_close, nabla_setup, ) -Vertex = gtx.Dimension("Vertex") -Edge = gtx.Dimension("Edge") -V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) -E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) - -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) - - @gtx.field_operator def compute_zavgS( pp: gtx.Field[[Vertex], float], S_M: gtx.Field[[Edge], float] @@ -67,21 +63,19 @@ def pnabla( def test_ffront_compute_zavgS(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) + setup = nabla_setup(allocator=allocator) zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - - compute_zavgS.with_backend(exec_alloc_descriptor)( - pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + compute_zavgS.with_backend( + None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor + )( + setup.input_field, + setup.S_fields[0], + out=zavgS, + offset_provider={"E2V": setup.edges2node_connectivity}, ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) @@ -89,27 +83,23 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): def test_ffront_nabla(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - sign = 
gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator) - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator) + setup = nabla_setup(allocator=allocator) pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - v2e = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 - ) - - pnabla.with_backend(exec_alloc_descriptor)( - pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} + pnabla.with_backend(None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor)( + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + out=(pnabla_MXX, pnabla_MYY), + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) # TODO this check is not sensitive enough, need to implement a proper numpy reference! 
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py index 8d7324f438..6a5865134d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -20,6 +20,18 @@ functionspace, ) +from gt4py import next as gtx +from gt4py.next.iterator import atlas_utils + + +Vertex = gtx.Dimension("Vertex") +Edge = gtx.Dimension("Edge") +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) + +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) + def assert_close(expected, actual): assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) @@ -33,9 +45,10 @@ def _default_config(): config["angle"] = 20.0 return config - def __init__(self, *, grid=StructuredGrid("O32"), config=None): + def __init__(self, *, allocator, grid=StructuredGrid("O32"), config=None): if config is None: config = self._default_config() + self.allocator = allocator mesh = StructuredMeshGenerator(config).generate(grid) fs_edges = functionspace.EdgeColumns(mesh, halo=1) @@ -55,12 +68,22 @@ def __init__(self, *, grid=StructuredGrid("O32"), config=None): self.edges_per_node = edges_per_node @property - def edges2node_connectivity(self): - return self.mesh.edges.node_connectivity + def edges2node_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Edge: self.edges_size, E2VDim: 2}, + codomain=Vertex, + data=atlas_utils.AtlasTable(self.mesh.edges.node_connectivity).asnumpy(), + allocator=self.allocator, + ) @property - def nodes2edge_connectivity(self): - return self.mesh.nodes.edge_connectivity + def nodes2edge_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Vertex: 
self.nodes_size, V2EDim: self.edges_per_node}, + codomain=Edge, + data=atlas_utils.AtlasTable(self.mesh.nodes.edge_connectivity).asnumpy(), + allocator=self.allocator, + ) @property def nodes_size(self): @@ -75,16 +98,16 @@ def _is_pole_edge(e, edge_flags): return Topology.check(edge_flags[e], Topology.POLE) @property - def is_pole_edge_field(self): + def is_pole_edge_field(self) -> gtx.Field: edge_flags = np.array(self.mesh.edges.flags()) pole_edge_field = np.zeros((self.edges_size,), dtype=bool) for e in range(self.edges_size): pole_edge_field[e] = self._is_pole_edge(e, edge_flags) - return pole_edge_field + return gtx.as_field([Edge], pole_edge_field, allocator=self.allocator) @property - def sign_field(self): + def sign_field(self) -> gtx.Field: node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) edge_flags = np.array(self.mesh.edges.flags()) @@ -100,10 +123,10 @@ def sign_field(self): node2edge_sign[jnode, jedge] = -1.0 if self._is_pole_edge(iedge, edge_flags): node2edge_sign[jnode, jedge] = 1.0 - return node2edge_sign + return gtx.as_field([Vertex, V2EDim], node2edge_sign, allocator=self.allocator) @property - def S_fields(self): + def S_fields(self) -> tuple[gtx.Field, gtx.Field]: S = np.array(self.mesh.edges.field("dual_normals"), copy=False) S_MXX = np.zeros((self.edges_size)) S_MYY = np.zeros((self.edges_size)) @@ -124,10 +147,12 @@ def S_fields(self): assert math.isclose(min(S_MYY), -2001577.7946404363) assert math.isclose(max(S_MYY), 2001577.7946404363) - return S_MXX, S_MYY + return gtx.as_field([Edge], S_MXX, allocator=self.allocator), gtx.as_field( + [Edge], S_MYY, allocator=self.allocator + ) @property - def vol_field(self): + def vol_field(self) -> gtx.Field: rpi = 2.0 * math.asin(1.0) radius = 6371.22e03 deg2rad = 2.0 * rpi / 360.0 @@ -142,10 +167,10 @@ def vol_field(self): # VOL(min/max): 57510668192.214096 851856184496.32886 assert_close(57510668192.214096, min(vol)) assert_close(851856184496.32886, max(vol)) - return vol + return 
gtx.as_field([Vertex], vol, allocator=self.allocator) @property - def input_field(self): + def input_field(self) -> gtx.Field: klevel = 0 MXX = 0 MYY = 1 @@ -200,4 +225,5 @@ def input_field(self): assert_close(0.0000000000000000, min(rzs)) assert_close(1965.4980340735883, max(rzs)) - return rzs[:, klevel] + + return gtx.as_field([Vertex], rzs[:, klevel], allocator=self.allocator) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 3db4497910..4487681abf 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -111,25 +111,18 @@ def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): @pytest.mark.requires_atlas def test_compute_zavgS(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + setup = nabla_setup(allocator=None) zavgS = gtx.as_field([Edge], np.zeros((setup.edges_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS_fencil, program_processor, setup.edges_size, zavgS, - pp, - S_MXX, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[0], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -141,9 +134,9 @@ def test_compute_zavgS(program_processor): program_processor, setup.edges_size, zavgS, - pp, - S_MYY, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[1], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) @@ -158,29 +151,21 @@ def compute_zavgS2_fencil(n_edges, out, pp, S_M): 
@pytest.mark.requires_atlas def test_compute_zavgS2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - - S = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) + setup = nabla_setup(allocator=None) zavgS = ( gtx.as_field([Edge], np.zeros((setup.edges_size))), gtx.as_field([Edge], np.zeros((setup.edges_size))), ) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS2_fencil, program_processor, setup.edges_size, zavgS, - pp, - S, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields, + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -195,34 +180,27 @@ def test_compute_zavgS2(program_processor): def test_nabla(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, + setup.input_field, S_MXX, S_MYY, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -245,33 +223,24 @@ def nabla2(n_nodes, out, pp, S, sign, vol): 
@pytest.mark.requires_atlas def test_nabla2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_M = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) - vol = gtx.as_field([Vertex], setup.vol_field) + setup = nabla_setup(allocator=None) pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla2, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, - S_M, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -325,36 +294,29 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ def test_nabla_sign(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla_sign, 
program_processor, setup.nodes_size, pnabla_MXX, pnabla_MYY, - pp, + setup.input_field, S_MXX, S_MYY, - vol, + setup.vol_field, gtx.index_field(Vertex), - is_pole_edge, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.is_pole_edge_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fdc6a77a1..ac7ce9e544 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -38,9 +38,13 @@ V2VDim, Vertex, c2e_arr, + c2e_conn, e2v_arr, + e2v_conn, v2e_arr, + v2e_conn, v2v_arr, + v2v_conn, ) from next_tests.unit_tests.conftest import program_processor, run_processor @@ -89,7 +93,7 @@ def test_sum_edges_to_vertices(program_processor, stencil): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -111,7 +115,7 @@ def test_map_neighbors(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -134,7 +138,7 @@ def test_map_make_const_list(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -157,8 +161,8 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo inp, out=out, offset_provider={ - 
"E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), + "E2V": e2v_conn, + "C2E": c2e_conn, }, ) if validate: @@ -185,7 +189,7 @@ def test_sparse_input_field(program_processor): non_sparse, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: @@ -208,8 +212,8 @@ def test_sparse_input_field_v2v(program_processor): inp, out=out, offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "V2V": v2v_conn, + "V2E": v2e_conn, }, ) @@ -235,7 +239,7 @@ def test_slice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -259,7 +263,7 @@ def test_slice_twice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -284,7 +288,7 @@ def test_shift_sliced_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -309,7 +313,7 @@ def test_slice_shifted_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -337,7 +341,7 @@ def test_lift(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -360,7 +364,7 @@ def 
test_shift_sparse_input_field(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -393,8 +397,8 @@ def test_shift_sparse_input_field2(program_processor): out2 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "E2V": e2v_conn, + "V2E": v2e_conn, } domain = {Vertex: range(0, 9)} @@ -448,7 +452,7 @@ def test_sparse_shifted_stencil_reduce(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: diff --git a/tests/next_tests/toy_connectivity.py b/tests/next_tests/toy_connectivity.py index 82c91a5e74..50db24b880 100644 --- a/tests/next_tests/toy_connectivity.py +++ b/tests/next_tests/toy_connectivity.py @@ -49,6 +49,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +c2e_conn = gtx.as_connectivity(domain={Cell: 9, C2EDim: 4}, codomain=Edge, data=c2e_arr) + v2v_arr = np.array( [ [1, 3, 2, 6], @@ -64,6 +66,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +v2v_conn = gtx.as_connectivity(domain={Vertex: 9, V2VDim: 4}, codomain=Vertex, data=v2v_arr) + e2v_arr = np.array( [ [0, 1], @@ -88,6 +92,7 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +e2v_conn = gtx.as_connectivity(domain={Edge: 18, E2VDim: 2}, codomain=Vertex, data=e2v_arr) # order east, north, west, south (counter-clock wise) v2e_arr = np.array( @@ -104,3 +109,5 @@ ], dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) + +v2e_conn = gtx.as_connectivity(domain={Vertex: 9, V2EDim: 4}, codomain=Edge, data=v2e_arr) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index ca66b45d6d..f1269f1ed8 100644 --- 
a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -14,11 +14,11 @@ import pytest import gt4py.next as gtx -from gt4py.next import backend +from gt4py.next import backend, common +from gt4py.next.embedded import nd_array_field from gt4py.next.iterator import runtime from gt4py.next.program_processors import program_formatter - import next_tests @@ -97,12 +97,21 @@ def run_processor( @dataclasses.dataclass -class DummyConnectivity: +class DummyConnectivity(common.Connectivity): max_neighbors: int has_skip_values: int - origin_axis: gtx.Dimension = gtx.Dimension("dummy_origin") - neighbor_axis: gtx.Dimension = gtx.Dimension("dummy_neighbor") - index_type: type[int] = int + source_dim: gtx.Dimension = gtx.Dimension("dummy_origin") + codomain: gtx.Dimension = gtx.Dimension("dummy_neighbor") + + +def nd_array_implementation_params(): + for xp in nd_array_field._nd_array_implementations: + if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: + yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) + else: + yield pytest.param(xp, id=xp.__name__) + - def mapped_index(_, __) -> int: - return 0 +@pytest.fixture(params=nd_array_implementation_params()) +def nd_array_implementation(request): + yield request.param diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 063e79d92e..9dde5bb40a 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange, NamedIndex +from gt4py.next.common import Dimension, Domain, NamedIndex, NamedRange, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from 
gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -28,19 +28,6 @@ D2 = Dimension("D2") -def nd_array_implementation_params(): - for xp in nd_array_field._nd_array_implementations: - if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: - yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) - else: - yield pytest.param(xp, id=xp.__name__) - - -@pytest.fixture(params=nd_array_implementation_params()) -def nd_array_implementation(request): - yield request.param - - @pytest.fixture( params=[ operator.add, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py index dcc3a306f2..a91dbeb608 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -31,12 +31,10 @@ # 0 --0-- 1 --1-- 2 e2v_arr = np.array([[0, 1], [1, 2]]) -e2v_conn = gtx.NeighborTableOffsetProvider( - table=e2v_arr, - origin_axis=E, - neighbor_axis=V, - max_neighbors=2, - has_skip_values=False, +e2v_conn = gtx.as_connectivity( + domain={E: 2, E2VDim: 2}, + codomain=V, + data=e2v_arr, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 1f08362f4f..13e8637d1a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -10,18 +10,22 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.iterator.builtins import deref from gt4py.next.iterator.runtime import CartesianDomain, UnstructuredDomain, _deduce_domain, fundef -from next_tests.unit_tests.conftest import DummyConnectivity - @fundef def foo(inp): return deref(inp) -connectivity = DummyConnectivity(max_neighbors=0, 
has_skip_values=True) +connectivity = common.ConnectivityType( + domain=[gtx.Dimension("dummy_origin"), gtx.Dimension("dummy_neighbor")], + codomain=gtx.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE, + dtype=None, +) def test_deduce_domain(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7b6214fb1b..65a5b5888d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -218,11 +218,11 @@ def expression_test_cases(): @pytest.mark.parametrize("test_case", expression_test_cases()) def test_expression_type(test_case): mesh = simple_mesh() - offset_provider = {**mesh.offset_provider, "Ioff": IDim, "Joff": JDim, "Koff": KDim} + offset_provider_type = {**mesh.offset_provider_type, "Ioff": IDim, "Joff": JDim, "Koff": KDim} testee, expected_type = test_case result = itir_type_inference.infer( - testee, offset_provider=offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=offset_provider_type, allow_undeclared_symbols=True ) assert result.type == expected_type @@ -231,14 +231,16 @@ def test_adhoc_polymorphism(): func = im.lambda_("a")(im.lambda_("b")(im.make_tuple("a", "b"))) testee = im.call(im.call(func)(im.ref("a_", bool_type)))(im.ref("b_", int_type)) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.type == ts.TupleType(types=[bool_type, int_type]) def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) - result = itir_type_inference.infer(testee, offset_provider={}) + result = itir_type_inference.infer(testee, offset_provider_type={}) assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, 
kw_only_args={}, returns=int_type @@ -253,7 +255,7 @@ def test_late_offset_axis(): testee = im.call(func)(im.ensure_offset("V2E")) result = itir_type_inference.infer( - testee, offset_provider=mesh.offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=mesh.offset_provider_type, allow_undeclared_symbols=True ) assert result.type == it_on_e_of_e_type @@ -265,7 +267,9 @@ def test_cast_first_arg_inference(): testee = im.call("cast_")( im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64" ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.args[0].type == int_type assert result.type == float64_type @@ -291,7 +295,7 @@ def test_cartesian_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -336,7 +340,7 @@ def test_unstructured_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[Vertex, KDim]), @@ -384,7 +388,7 @@ def test_function_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -429,7 +433,7 @@ def test_fencil_with_nb_field_input(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) assert 
result.closures[0].stencil.expr.args[0].type == float64_list_type assert result.closures[0].stencil.type.returns == float64_type @@ -456,7 +460,7 @@ def test_program_tuple_setat_short_target(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.TupleType) @@ -487,7 +491,7 @@ def test_program_setat_without_domain(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.DeferredType) @@ -512,7 +516,9 @@ def test_if_stmt(): false_branch=[], ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field @@ -522,7 +528,7 @@ def test_as_fieldop_without_domain(): im.ref("inp", float_i_field) ) result = itir_type_inference.infer( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert result.type == ts.DeferredType(constraint=ts.FieldType) assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index e04856b75f..f4ea2d7fe1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -21,7 +21,7 @@ @pytest.fixture -def offset_provider(request): +def offset_provider_type(request): return {"I": common.Dimension("I", kind=common.DimensionKind.HORIZONTAL)} @@ 
-137,7 +137,7 @@ def common_expr(): assert actual == expected -def test_if_can_deref_no_extraction(offset_provider): +def test_if_can_deref_no_extraction(offset_provider_type): # Test that a subexpression only occurring in one branch of an `if_` is not moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -157,11 +157,11 @@ def test_if_can_deref_no_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_can_deref_eligible_extraction(offset_provider): +def test_if_can_deref_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in both branches of an `if_` is moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -178,11 +178,11 @@ def test_if_can_deref_eligible_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_eligible_extraction(offset_provider): +def test_if_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in the condition of an `if_` is moved outside the # if statement. 
@@ -191,7 +191,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 141091b450..817c06e8f0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -14,11 +14,12 @@ from gt4py import eve from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next import constructors from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import infer_domain from gt4py.next.iterator.ir_utils import domain_utils from gt4py.next.common import Dimension -from gt4py.next import common, NeighborTableOffsetProvider +from gt4py.next import common from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next import utils @@ -29,6 +30,7 @@ KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) Edge = common.Dimension(value="Edge", kind=common.DimensionKind.HORIZONTAL) +E2VDim = common.Dimension(value="E2V", kind=common.DimensionKind.LOCAL) @pytest.fixture @@ -39,11 +41,10 @@ def offset_provider(): @pytest.fixture def unstructured_offset_provider(): return { - "E2V": NeighborTableOffsetProvider( - np.array([[0, 1]], dtype=np.int32), - Edge, - Vertex, - 2, + "E2V": 
constructors.as_connectivity( + domain={Edge: 1, E2VDim: 2}, + codomain=Vertex, + data=np.array([[0, 1]], dtype=np.int32), ) } diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index b5b9a62009..168e9490e0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -13,6 +13,7 @@ from gt4py.next.iterator.transforms import fuse_as_fieldop from gt4py.next.type_system import type_specifications as ts + IDim = gtx.Dimension("IDim") field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) @@ -30,7 +31,7 @@ def test_trivial(): d, )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -40,7 +41,7 @@ def test_trivial_literal(): testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -65,7 +66,7 @@ def test_tuple_arg(): d, )() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -85,7 +86,7 @@ def test_symref_used_twice(): d, )("inp1", "inp2") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) 
assert actual == expected @@ -100,7 +101,7 @@ def test_no_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == testee @@ -132,6 +133,6 @@ def test_partial_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 23f62842c4..9d51dc4f33 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -52,7 +52,7 @@ def test_trivial(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -87,7 +87,7 @@ def test_trivial_let(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -128,7 +128,7 @@ def test_top_level_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, 
offset_provider=offset_provider) expected = program_factory( @@ -186,7 +186,7 @@ def test_nested_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 7c991fb9a8..77d3323fb4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -8,16 +8,16 @@ from gt4py import next as gtx from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.prune_casts import PruneCasts from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts def test_prune_casts_simple(): x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) @@ -32,7 +32,7 @@ def test_prune_casts_fieldop(): im.cast_as_fieldop("float64")(x_ref), im.cast_as_fieldop("float64")(y_ref), ) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = 
im.op_as_fieldop("plus")( im.cast_as_fieldop("float64")(x_ref), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 28bd88b853..0760247996 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -11,11 +11,20 @@ import pytest from gt4py.eve.utils import UIDs +from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags -from next_tests.unit_tests.conftest import DummyConnectivity + +def dummy_connectivity_type(max_neighbors: int, has_skip_values: bool): + return common.NeighborConnectivityType( + domain=[common.Dimension("dummy_origin"), common.Dimension("dummy_neighbor")], + codomain=common.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + dtype=None, + max_neighbors=max_neighbors, + ) @pytest.fixture(params=[True, False]) @@ -67,7 +76,7 @@ def reduction_if(): ], ) def test_get_partial_offsets(reduction, request): - offset_provider = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} + offset_provider_type = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} partial_offsets = _get_partial_offset_tags(request.getfixturevalue(reduction).args) assert set(partial_offsets) == {"Dim"} @@ -108,63 +117,73 @@ def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0): def test_basic(basic_reduction, has_skip_values): expected = _expected(basic_reduction, "Dim", 3, has_skip_values) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=3, has_skip_values=has_skip_values)} 
- actual = UnrollReduce.apply(basic_reduction, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply(basic_reduction, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, has_skip_values): expected = _expected(reduction_with_shift_on_second_arg, "Dim", 1, has_skip_values, 1) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=1, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(reduction_with_shift_on_second_arg, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=1, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply( + reduction_with_shift_on_second_arg, offset_provider_type=offset_provider_type + ) assert actual == expected def test_reduction_with_if(reduction_if): expected = _expected(reduction_if, "Dim", 2, False) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=2, has_skip_values=False)} - actual = UnrollReduce.apply(reduction_if, offset_provider=offset_provider) + offset_provider_type = {"Dim": dummy_connectivity_type(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "IrrelevantDim": DummyConnectivity( + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "IrrelevantDim": dummy_connectivity_type( max_neighbors=1, has_skip_values=True ), # different max_neighbors and skip value to trigger error } actual = UnrollReduce.apply( - 
reduction_with_irrelevant_full_shift, offset_provider=offset_provider + reduction_with_irrelevant_full_shift, offset_provider_type=offset_provider_type ) assert actual == expected @pytest.mark.parametrize( - "offset_provider", + "offset_provider_type", [ { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=3, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=3, has_skip_values=True), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=True), }, ], ) -def test_reduction_with_incompatible_shifts(reduction_with_incompatible_shifts, offset_provider): - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), +def test_reduction_with_incompatible_shifts( + reduction_with_incompatible_shifts, offset_provider_type +): + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), } with pytest.raises(RuntimeError, match="incompatible"): - UnrollReduce.apply(reduction_with_incompatible_shifts, offset_provider=offset_provider) + UnrollReduce.apply( + reduction_with_incompatible_shifts, offset_provider_type=offset_provider_type + ) diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 1a86f7b0f8..97591122e5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -21,7 +21,7 @@ def test_funcall_to_op(): ) actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual @@ -32,7 +32,7 @@ def test_unapplied_funcall_to_function_object(): expected = gtfn_ir.SymRef(id="plus") actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 329b2814d2..62d88d9f0a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -11,6 +11,7 @@ import ctypes import unittest import unittest.mock +from unittest.mock import patch import numpy as np import pytest @@ -20,19 +21,15 @@ from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import ( - E2V, - cartesian_case, - unstructured_case, -) +from next_tests.integration_tests.cases import E2V, cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils 
import ( exec_alloc_descriptor, mesh_descriptor, ) -from unittest.mock import patch from . import pytestmark + dace = pytest.importorskip("dace") @@ -151,14 +148,14 @@ def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): # check that test connectivities are allocated on host memory # this is an assumption to test that fast_call cannot be used for gpu tests - assert isinstance(connectivity_E2V.table, np.ndarray) + assert isinstance(connectivity_E2V.ndarray, np.ndarray) @gtx.field_operator def testee(a: cases.VField) -> cases.EField: return a(E2V[0]) (a,), kwfields = cases.get_default_data(unstructured_case, testee) - numpy_ref = lambda a: a[connectivity_E2V.table[:, 0]] + numpy_ref = lambda a: a[connectivity_E2V.ndarray[:, 0]] mock_fast_call, mock_construct_args = make_mocks(monkeypatch) @@ -194,12 +191,11 @@ def verify_testee(offset_provider): # Here we copy the connectivity to gpu memory, and resuse the same cupy array # on multiple program calls, in order to ensure that fast_call is used. 
offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider( - table=cp.asarray(connectivity_E2V.table), - origin_axis=connectivity_E2V.origin_axis, - neighbor_axis=connectivity_E2V.neighbor_axis, - max_neighbors=connectivity_E2V.max_neighbors, - has_skip_values=connectivity_E2V.has_skip_values, + "E2V": gtx.as_connectivity( + domain=connectivity_E2V.domain, + codomain=connectivity_E2V.codomain, + data=cp.asarray(connectivity_E2V.ndarray), + skip_value=connectivity_E2V.skip_value, ) } diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index e0c0c3fa4e..9c52ea81c3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, constructors from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -50,13 +50,7 @@ "IDim": IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() -SIMPLE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SIMPLE_MESH.offset_provider | CARTESIAN_OFFSETS -) SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() -SKIP_VALUE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SKIP_VALUE_MESH.offset_provider | CARTESIAN_OFFSETS -) SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( __w_size_0=N, @@ -83,20 +77,20 @@ def make_mesh_symbols(mesh: MeshDescriptor): __vertices_size_0=mesh.num_vertices, __vertices_stride_0=1, __connectivity_C2E_size_0=mesh.num_cells, - 
__connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_size_1=mesh.offset_provider_type["C2E"].max_neighbors, + __connectivity_C2E_stride_0=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_1=1, __connectivity_C2V_size_0=mesh.num_cells, - __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_size_1=mesh.offset_provider_type["C2V"].max_neighbors, + __connectivity_C2V_stride_0=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_1=1, __connectivity_E2V_size_0=mesh.num_edges, - __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_size_1=mesh.offset_provider_type["E2V"].max_neighbors, + __connectivity_E2V_stride_0=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_1=1, __connectivity_V2E_size_0=mesh.num_vertices, - __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_size_1=mesh.offset_provider_type["V2E"].max_neighbors, + __connectivity_V2E_stride_0=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_1=1, ) @@ -1018,14 +1012,14 @@ def test_gtir_connectivity_shift(): CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=SIZE_TYPE) EDGE_OFFSET_FTYPE = ts.FieldType(dims=[Edge], dtype=SIZE_TYPE) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) ev = 
np.random.rand(SIMPLE_MESH.num_edges, SIMPLE_MESH.num_vertices) - ref = ev[connectivity_C2E.table[:, C2E_neighbor_idx], :][ - :, connectivity_E2V.table[:, E2V_neighbor_idx] + ref = ev[connectivity_C2E.ndarray[:, C2E_neighbor_idx], :][ + :, connectivity_E2V.ndarray[:, E2V_neighbor_idx] ] for i, stencil in enumerate( @@ -1053,7 +1047,7 @@ def test_gtir_connectivity_shift(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) ce = np.empty([SIMPLE_MESH.num_cells, SIMPLE_MESH.num_edges]) @@ -1062,8 +1056,8 @@ def test_gtir_connectivity_shift(): ev, c2e_offset=np.full(SIMPLE_MESH.num_cells, C2E_neighbor_idx, dtype=np.int32), e2v_offset=np.full(SIMPLE_MESH.num_edges, E2V_neighbor_idx, dtype=np.int32), - connectivity_C2E=connectivity_C2E.table, - connectivity_E2V=connectivity_E2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_E2V=connectivity_E2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __ce_field_size_0=SIMPLE_MESH.num_cells, @@ -1114,15 +1108,17 @@ def test_gtir_connectivity_shift_chain(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) - ref = e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]] + ref = e[ + connectivity_V2E.ndarray[connectivity_E2V.ndarray[:, E2V_neighbor_idx], V2E_neighbor_idx] + ] # new empty output field e_out = np.empty_like(e) @@ -1130,8 +1126,8 @@ def 
test_gtir_connectivity_shift_chain(): sdfg( e, e_out, - connectivity_E2V=connectivity_E2V.table, - connectivity_V2E=connectivity_V2E.table, + connectivity_E2V=connectivity_E2V.ndarray, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __edges_out_size_0=SIMPLE_MESH.num_edges, @@ -1174,30 +1170,30 @@ def test_gtir_neighbors_as_input(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.shape[1]) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=v2e_field.dtype) v_ref = [ functools.reduce(lambda x, y: x + y, v2e_values + e[v2e_neighbors], init_value) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1210,7 +1206,7 @@ def test_gtir_neighbors_as_output(): gtx_common.GridType.UNSTRUCTURED, ranges={ Vertex: (0, "nvertices"), - V2EDim: (0, SIMPLE_MESH_OFFSET_PROVIDER["V2E"].max_neighbors), + V2EDim: (0, SIMPLE_MESH.offset_provider_type["V2E"].max_neighbors), }, ) 
vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) @@ -1232,9 +1228,9 @@ def test_gtir_neighbors_as_output(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) @@ -1243,7 +1239,7 @@ def test_gtir_neighbors_as_output(): sdfg( e, v2e_field, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, @@ -1251,7 +1247,7 @@ def test_gtir_neighbors_as_output(): __v2e_field_stride_0=connectivity_V2E.max_neighbors, __v2e_field_stride_1=1, ) - assert np.allclose(v2e_field, e[connectivity_V2E.table]) + assert np.allclose(v2e_field, e[connectivity_V2E.ndarray]) def test_gtir_reduce(): @@ -1278,13 +1274,13 @@ def test_gtir_reduce(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) v_ref = [ functools.reduce(lambda x, y: x + y, e[v2e_neighbors], init_value) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1305,7 +1301,7 @@ def test_gtir_reduce(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) # new empty output field v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) @@ -1313,7 +1309,7 @@ def test_gtir_reduce(): sdfg( e, v, 
- connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) @@ -1344,7 +1340,7 @@ def test_gtir_reduce_with_skip_values(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SKIP_VALUE_MESH.num_edges) @@ -1354,7 +1350,7 @@ def test_gtir_reduce_with_skip_values(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1375,7 +1371,7 @@ def test_gtir_reduce_with_skip_values(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) # new empty output field v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) @@ -1383,7 +1379,7 @@ def test_gtir_reduce_with_skip_values(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), ) @@ -1394,10 +1390,10 @@ def test_gtir_reduce_dot_product(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1409,7 +1405,7 @@ def 
test_gtir_reduce_dot_product(): ), init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field) ] testee = gtir.Program( @@ -1448,17 +1444,17 @@ def test_gtir_reduce_dot_product(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1500,14 +1496,14 @@ def test_gtir_reduce_with_cond_neighbors(): ], ) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) for use_sparse in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1525,19 +1521,19 @@ def test_gtir_reduce_with_cond_neighbors(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( 
np.bool_(use_sparse), v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1631,9 +1627,9 @@ def test_gtir_let_lambda_with_connectivity(): C2V_neighbor_idx = 2 cell_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Cell: (0, "ncells")}) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_C2V = SIMPLE_MESH_OFFSET_PROVIDER["C2V"] + connectivity_C2V = SIMPLE_MESH.offset_provider["C2V"] assert isinstance(connectivity_C2V, gtx_common.NeighborTable) testee = gtir.Program( @@ -1669,22 +1665,22 @@ def test_gtir_let_lambda_with_connectivity(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.random.rand(SIMPLE_MESH.num_vertices) c = np.empty(SIMPLE_MESH.num_cells) ref = ( - e[connectivity_C2E.table[:, C2E_neighbor_idx]] - + v[connectivity_C2V.table[:, C2V_neighbor_idx]] + e[connectivity_C2E.ndarray[:, C2E_neighbor_idx]] + + v[connectivity_C2V.ndarray[:, C2V_neighbor_idx]] ) sdfg( cells=c, edges=e, vertices=v, - connectivity_C2E=connectivity_C2E.table, - connectivity_C2V=connectivity_C2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_C2V=connectivity_C2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 
6e9dfa3d64..0998ab8eab 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -11,10 +11,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators, common, float32 -from gt4py.next.program_processors.runners import roundtrip - -from next_tests.integration_tests import cases +from gt4py.next import allocators as next_allocators, common I = gtx.Dimension("I") @@ -154,3 +151,12 @@ def test_field_wrong_origin(): @pytest.mark.xfail(reason="aligned_index not supported yet") def test_aligned_index(): gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), aligned_index=[I, 0]) + + +@pytest.mark.parametrize( + "data, skip_value", + [([0, 1, 2], None), ([0, 1, common._DEFAULT_SKIP_VALUE], common._DEFAULT_SKIP_VALUE)], +) +def test_as_connectivity(nd_array_implementation, data, skip_value): + testee = gtx.as_connectivity([I], J, nd_array_implementation.array(data)) + assert testee.skip_value is skip_value From 3fb206e46ceecf07b7ef6c668239d62d79028503 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 26 Nov 2024 10:53:19 +0100 Subject: [PATCH 15/43] feat[next][dace]: Symbolic domain without dace array offsets (#1735) Add support for field operator domain with symbolic shape, with dimension extent in non zero-based range. 
--- .../runners/dace_common/utility.py | 10 +- .../gtir_builtin_translators.py | 127 ++++++++++----- .../runners/dace_fieldview/gtir_dataflow.py | 100 +++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 148 +++++++++++++----- .../runners/dace_fieldview/utility.py | 11 +- .../dace_tests/test_gtir_to_sdfg.py | 123 +++++++++++++-- 6 files changed, 367 insertions(+), 152 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 29395a30c1..3e96ef3cec 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Optional, Sequence +from typing import Final, Literal, Optional, Sequence import dace @@ -51,12 +51,16 @@ def connectivity_identifier(name: str) -> str: return f"connectivity_{name}" +def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str: + return f"__{field_name}_{sym}_{axis}" + + def field_size_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_size_{axis}" + return field_symbol_name(field_name, axis, "size") def field_stride_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_stride_{axis}" + return field_symbol_name(field_name, axis, "stride") def is_field_symbol(name: str) -> bool: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 69aedf44d6..60dcd8ddc9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, 
TypeAlias +from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace import dace.subsets as sbs @@ -33,6 +33,34 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg +def _get_domain_indices( + dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None +) -> sbs.Indices: + """ + Helper function to construct the list of indices for a field domain, applying + an optional offset in each dimension as start index. + + Args: + dims: The field dimensions. + offsets: The range start index in each dimension. + + Returns: + A list of indices for field access in dace arrays. As this list is returned + as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before + being used in memlet subset because ranges are better supported throughout DaCe. + """ + index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] + if offsets is None: + return sbs.Indices(index_variables) + else: + return sbs.Indices( + [ + index - offset if offset != 0 else index + for index, offset in zip(index_variables, offsets, strict=True) + ] + ) + + @dataclasses.dataclass(frozen=True) class FieldopData: """ @@ -45,42 +73,59 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. gt_type: GT4Py type definition, which includes the field domain information. + offset: List of index offsets, in each dimension, when the dimension range + does not start from zero; assume zero offset, if not set. 
""" dc_node: dace.nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType + offset: Optional[list[dace.symbolic.SymExpr]] + + def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData: + """Create a copy of this data descriptor with a different access node.""" + assert data_node != self.dc_node + return FieldopData(data_node, self.gt_type, self.offset) def get_local_view( self, domain: FieldopDomain ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: - """Helper method to access a field in local view, given a field operator domain.""" + """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { - dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) - for dim, _, _ in domain + domain_dims = [dim for dim, _, _ in domain] + domain_indices = _get_domain_indices(domain_dims) + it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) + for dim, index in zip(domain_dims, domain_indices) } + field_domain = [ + (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) + for i, dim in enumerate(self.gt_type.dims) + ] local_dims = [ dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL ] - if len(local_dims) == 0: return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices + self.dc_node, self.gt_type.dtype, field_domain, it_indices ) elif len(local_dims) == 1: field_dtype = itir_ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) - field_dims = [ - dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL + field_domain = [ + (dim, offset) + for dim, 
offset in field_domain + if dim.kind != gtx_common.DimensionKind.LOCAL ] - return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices) + return gtir_dataflow.IteratorExpr( + self.dc_node, field_dtype, field_domain, it_indices + ) else: raise ValueError( @@ -155,9 +200,9 @@ def _parse_fieldop_arg( return arg.get_local_view(domain) -def _get_field_shape( +def _get_field_layout( domain: FieldopDomain, -) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]: +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: """ Parse the field operator domain and generates the shape of the result field. @@ -174,11 +219,14 @@ def _get_field_shape( domain: The field operator domain. Returns: - A tuple of two lists: the list of field dimensions and the list of dace - array sizes in each dimension. + A tuple of three lists containing: + - the domain dimensions + - the domain offset in each dimension + - the domain size in each dimension """ - domain_dims, _, domain_ubs = zip(*domain) - return list(domain_dims), list(domain_ubs) + domain_dims, domain_lbs, domain_ubs = zip(*domain) + domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] + return list(domain_dims), list(domain_lbs), domain_sizes def _create_temporary_field( @@ -189,7 +237,7 @@ def _create_temporary_field( dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: """Helper method to allocate a temporary field where to write the output of a field operator.""" - field_dims, field_shape = _get_field_shape(domain) + field_dims, field_offset, field_shape = _get_field_layout(domain) output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): @@ -197,6 +245,7 @@ def _create_temporary_field( assert isinstance(node_type.dtype.element_type, ts.ScalarType) assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by 
the field operator (e.g. `neighbors`) + field_offset.extend(output_desc.offset) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) @@ -215,7 +264,11 @@ def _create_temporary_field( assert dataflow_output.result.gt_dtype.offset_type is not None field_dims.append(dataflow_output.result.gt_dtype.offset_type) - return FieldopData(field_node, ts.FieldType(field_dims, field_dtype)) + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -285,7 +338,8 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + domain_dims, domain_offsets, _ = zip(*domain) + domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] @@ -350,10 +404,8 @@ def translate_broadcast_scalar( assert cpm.is_ref_to(stencil_expr, "deref") domain = extract_domain(domain_expr) - field_dims, field_shape = _get_field_shape(domain) - field_subset = sbs.Range.from_string( - ",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims) - ) + output_dims, output_offset, output_shape = _get_field_layout(domain) + output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) assert len(node.args) == 1 scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) @@ -369,26 +421,15 @@ def translate_broadcast_scalar( assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) if len(node.args[0].type.dims) == 0: # zero-dimensional field input_subset = "0" - elif all( - isinstance(scalar_expr.indices[dim], 
gtir_dataflow.SymbolExpr) - for dim in scalar_expr.dimensions - if dim not in field_dims - ): - input_subset = ",".join( - dace_gtir_utils.get_map_variable(dim) - if dim in field_dims - else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above - for dim in scalar_expr.dimensions - ) else: - raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.") + input_subset = scalar_expr.get_memlet_subset(sdfg) input_node = scalar_expr.field gt_dtype = node.args[0].type.dtype else: raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype) + output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) output_node = state.add_access(output) sdfg_builder.add_mapped_tasklet( @@ -400,13 +441,13 @@ def translate_broadcast_scalar( }, inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)}, + outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, input_nodes={input_node.data: input_node}, output_nodes={output_node.data: output_node}, external_edges=True, ) - return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype)) + return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) def translate_if( @@ -467,7 +508,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData: outer, _ = sdfg.add_temp_transient_like(inner_desc) outer_node = state.add_access(outer) - return FieldopData(outer_node, inner_data.gt_type) + return inner_data.make_copy(outer_node) result_temps = gtx_utils.tree_map(construct_output)(true_br_args) @@ -513,7 +554,7 @@ def _get_data_nodes( ) -> FieldopResult: if isinstance(data_type, ts.FieldType): data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return 
sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.ScalarType): if data_name in sdfg.symbols: @@ -522,7 +563,7 @@ def _get_data_nodes( ) else: data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) @@ -579,7 +620,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type) + return FieldopData(data_node, data_type, offset=None) def translate_make_tuple( @@ -708,7 +749,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type) + return FieldopData(temp_node, node.type, offset=None) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 74142dec66..cfba4d61e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -90,17 +90,42 @@ class IteratorExpr: Args: field: Access node to the field this iterator operates on. gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. - dimensions: Field domain represented as a sorted list of dimensions, needed - to order the map index variables and dereference an element in the field. + field_domain: Field domain represented as a sorted list of dimensions and offset values, + used to find the position of a map index variable in the memlet subset. 
The offset + value is either the start index of dimension range or the compile-time value of + a shift expression, or a composition of both, and it must be subtracted to the index + variable when constructing the memlet subset range. indices: Maps each dimension to an index value, which could be either a symbolic value or the result of a tasklet computation like neighbors connectivity or dynamic offset. """ field: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - dimensions: list[gtx_common.Dimension] + field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] + def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): + raise ValueError(f"Cannot deref iterator {self}.") + + field_desc = self.field.desc(sdfg) + if isinstance(self.gt_dtype, itir_ts.ListType): + assert len(field_desc.shape) == len(self.field_domain) + 1 + assert self.gt_dtype.offset_type is not None + field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] + else: + assert len(field_desc.shape) == len(self.field_domain) + field_domain = self.field_domain + + return sbs.Range.from_string( + ",".join( + str(self.indices[dim].value - offset) # type: ignore[union-attr] + if dim in self.indices + else f"0:{size}" + for (dim, offset), size in zip(field_domain, field_desc.shape, strict=True) + ) + ) + class DataflowInputEdge(Protocol): """ @@ -271,8 +296,17 @@ def _add_input_data_edge( src_subset: sbs.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, + src_offset: Optional[list[dace.symbolic.SymExpr]] = None, ) -> None: - edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) + input_subset = ( + src_subset + if src_offset is None + else sbs.Range( + (start - off, stop - off, step) + for (start, stop, step), off in zip(src_subset, src_offset, strict=True) + ) + ) + edge = MemletInputEdge(self.state, 
src, input_subset, dst_node, dst_conn) self.input_edges.append(edge) def _add_edge( @@ -440,34 +474,21 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: field_desc = arg_expr.field.desc(self.sdfg) if isinstance(field_desc, dace.data.Scalar): # deref a zero-dimensional field - assert len(arg_expr.dimensions) == 0 + assert len(arg_expr.field_domain) == 0 assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") # default case: deref a field with one or more dimensions if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): - # when all indices are symblic expressions, we can perform direct field access through a memlet - if isinstance(arg_expr.gt_dtype, itir_ts.ListType): - assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 - assert arg_expr.gt_dtype.offset_type is not None - field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] - else: - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_dims = arg_expr.dimensions - - field_subset = sbs.Range( - (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] - if dim in arg_expr.indices - else (0, size - 1, 1) - for dim, size in zip(field_dims, field_desc.shape) - ) + # when all indices are symbolic expressions, we can perform direct field access through a memlet + field_subset = arg_expr.get_memlet_subset(self.sdfg) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, # either indirection through connectivity table or dynamic cartesian offset. 
- assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) + assert len(field_desc.shape) == len(arg_expr.field_domain) + field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] index_connectors = [ IndexConnectorFmt.format(dim=dim.value) for dim, index in field_indices @@ -494,6 +515,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: sbs.Range.from_array(field_desc), deref_node, "field", + src_offset=[offset for (_, offset) in arg_expr.field_domain], ) for dim, index_expr in field_indices: @@ -532,7 +554,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.codomain in it.dimensions + assert any(dim == offset_provider.codomain for dim, _ in it.field_domain) assert offset_provider.source_dim in it.indices origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) @@ -560,10 +582,12 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=node.type, subset=sbs.Range.from_string( ",".join( - it.indices[dim].value # type: ignore[union-attr] + str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + for (dim, offset), size in zip( + it.field_domain, field_desc.shape, strict=True + ) ) ), ) @@ -971,14 +995,13 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert offset_dim in it.dimensions + assert any(dim == offset_dim for dim, _ in it.field_domain) new_index: SymbolExpr | ValueExpr - assert 
offset_dim in it.indices index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( - dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value, + index_expr.value + offset_expr.value, index_expr.dc_dtype, ) else: @@ -1032,15 +1055,10 @@ def _make_cartesian_shift( ) # a new iterator with a shifted index along one dimension - return IteratorExpr( - field=it.field, - gt_dtype=it.gt_dtype, - dimensions=it.dimensions, - indices={ - dim: (new_index if dim == offset_dim else index) - for dim, index in it.indices.items() - }, - ) + shifted_indices = { + dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() + } + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _make_dynamic_neighbor_offset( self, @@ -1094,7 +1112,7 @@ def _make_unstructured_shift( offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.codomain in it.dimensions + assert any(dim == connectivity.codomain for dim, _ in it.field_domain) neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices @@ -1117,9 +1135,7 @@ def _make_unstructured_shift( offset_expr, offset_table_node, origin_index ) - return IteratorExpr( - field=it.field, gt_dtype=it.gt_dtype, dimensions=it.dimensions, indices=shifted_indices - ) + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 52284edfac..f15287e64c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ 
b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -16,6 +16,7 @@ import abc import dataclasses +import functools import itertools import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union @@ -98,9 +99,16 @@ def add_mapped_tasklet( class SDFGBuilder(DataflowBuilder, Protocol): """Visitor interface available to GTIR-primitive translators.""" + @abc.abstractmethod + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + """Retrieve the field data descriptor including the domain offset information.""" + ... + @abc.abstractmethod def get_symbol_type(self, symbol_name: str) -> ts.DataType: - """Retrieve the GT4Py type of a symbol used in the program.""" + """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... @abc.abstractmethod @@ -141,6 +149,15 @@ def _collect_symbols_in_domain_expressions( ) +def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + """ + return ts.TupleType( + types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. 
@@ -157,6 +174,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( + default_factory=lambda: {} + ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) @@ -167,6 +187,15 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: return self.offset_provider_type[offset] + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + if isinstance(data_type, ts.FieldType): + domain_offset = self.field_offsets.get(data_node.data, None) + else: + domain_offset = None + return gtir_builtin_translators.FieldopData(data_node, data_type, domain_offset) + def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -248,12 +277,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( - name, gt_type, flatten=True - ): + for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): tuple_fields.extend( self._add_storage( - sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name + sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name ) ) return tuple_fields @@ -275,7 +302,6 @@ def _add_storage( tuple_name, gt_type.dims ) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) - return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): @@ -344,7 +370,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return 
gtir_builtin_translators.FieldopData(temp_node, field.gt_type) + return field.make_copy(temp_node) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -405,6 +431,10 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + # Since program field arguments are passed to the SDFG as full-shape arrays, + # there is no offset that needs to be compensated. + assert len(self.field_offsets) == 0 + sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) @@ -459,7 +489,7 @@ def visit_SetAt( The SDFG head state, eventually updated if the target write requires a new state. """ - temp_fields = self._visit_expression(stmt.expr, sdfg, state) + source_fields = self._visit_expression(stmt.expr, sdfg, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field @@ -482,17 +512,26 @@ def visit_SetAt( } target_state: Optional[dace.SDFGState] = None - for temp, target in zip(temp_fields, target_fields, strict=True): + for source, target in zip(source_fields, target_fields, strict=True): target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient if isinstance(target.gt_type, ts.FieldType): - subset = ",".join( + target_subset = ",".join( f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims ) + source_subset = ( + target_subset + if source.offset is None + else ",".join( + f"{domain[dim][0] - offset}:{domain[dim][1] - offset}" + for dim, offset in zip(target.gt_type.dims, source.offset, strict=True) + ) + ) else: assert len(domain) == 0 - subset = "0" + target_subset = "0" + source_subset = "0" if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state @@ -501,17 +540,21 @@ def visit_SetAt( target_state = 
sdfg.add_state_after(state, f"post_{state.label}") # create new access nodes in the target state target_state.add_nedge( - target_state.add_access(temp.dc_node.data), + target_state.add_access(source.dc_node.data), target_state.add_access(target.dc_node.data), - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) # remove isolated access node state.remove_node(target.dc_node) else: state.add_nedge( - temp.dc_node, + source.dc_node, target.dc_node, - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) return target_state or state @@ -574,17 +617,65 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] + def flatten_tuples( + name: str, + arg: gtir_builtin_translators.FieldopResult, + ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: + if isinstance(arg, tuple): + tuple_type = _get_tuple_type(arg) + tuple_field_names = [ + arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) + ] + tuple_args = zip(tuple_field_names, arg, strict=True) + return list( + itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) + ) + else: + return [(name, arg)] + + lambda_arg_nodes = dict( + itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + ) + # inherit symbols from parent scope but eventually override with local symbols lambda_symbols = { sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type for pname, arg in lambda_args_mapping } + def get_field_domain_offset( + p_name: str, p_type: 
ts.DataType + ) -> dict[str, Optional[list[dace.symbolic.SymExpr]]]: + if isinstance(p_type, ts.FieldType): + if p_name in lambda_arg_nodes: + arg = lambda_arg_nodes[p_name] + assert isinstance(arg, gtir_builtin_translators.FieldopData) + return {p_name: arg.offset} + elif field_domain_offset := self.field_offsets.get(p_name, None): + return {p_name: field_domain_offset} + elif isinstance(p_type, ts.TupleType): + p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + return functools.reduce( + lambda field_offsets, field: ( + field_offsets | get_field_domain_offset(field[0], field[1]) + ), + p_fields, + {}, + ) + return {} + + # populate mapping from field name to domain offset + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} + for p_name, p_type in lambda_symbols.items(): + lambda_field_offsets |= get_field_domain_offset(p_name, p_type) + # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) + lambda_translator = GTIRToSDFG( + self.offset_provider_type, lambda_symbols, lambda_field_offsets + ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -603,30 +694,11 @@ def visit_Lambda( head_state=nstate, ) - def _flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = dace_gtir_utils.get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[_flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - # Process lambda inputs # # All input arguments are passed as parameters to the nested SDFG, therefore # we they are stored as non-transient array and scalar objects. 
# - lambda_arg_nodes = dict( - itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) - ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) @@ -739,7 +811,7 @@ def construct_output_for_nested_sdfg( head_state.add_edge( nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned @@ -748,7 +820,7 @@ def construct_output_for_nested_sdfg( outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index caec6cd87e..118f0449c8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import itertools -from typing import Any, Dict, TypeVar +from typing import Dict, TypeVar import dace @@ -58,15 +58,6 @@ def get_tuple_fields( return fields -def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. 
- """ - return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: """ Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 9c52ea81c3..f5191fbaaa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -47,7 +47,7 @@ VFTYPE = ts.FieldType(dims=[Vertex], dtype=FLOAT_TYPE) V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { - "IDim": IDim, + IDim.value: IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() @@ -735,13 +735,13 @@ def test_gtir_cartesian_shift_left(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -749,13 +749,15 @@ def test_gtir_cartesian_shift_left(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + 
im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -764,14 +766,14 @@ def test_gtir_cartesian_shift_left(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -828,13 +830,13 @@ def test_gtir_cartesian_shift_right(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", -OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, -OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", -OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -842,13 +844,15 @@ def test_gtir_cartesian_shift_right(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), 
DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -857,14 +861,14 @@ def test_gtir_cartesian_shift_right(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -1539,6 +1543,91 @@ def test_gtir_reduce_with_cond_neighbors(): assert np.allclose(v, v_ref) +def test_gtir_symbolic_domain(): + MARGIN = 2 + assert MARGIN < N + OFFSET = 1000 * 1000 * 1000 + domain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + left_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, + ) + right_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, + ) + shift_left_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))) + shift_right_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))) + testee = gtir.Program( + id="symbolic_domain", + function_definitions=[], + params=[ + 
gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let( + "xᐞ1", + im.op_as_fieldop("multiplies", left_domain)( + 4.0, + im.as_fieldop( + shift_left_stencil, + left_domain, + )("x"), + ), + )( + im.let( + "xᐞ2", + im.op_as_fieldop("multiplies", right_domain)( + 3.0, + im.as_fieldop( + shift_right_stencil, + right_domain, + )("x"), + ), + )( + im.let( + "xᐞ3", + im.as_fieldop( + shift_right_stencil, + domain, + )("xᐞ1"), + )( + im.let( + "xᐞ4", + im.as_fieldop( + shift_left_stencil, + domain, + )("xᐞ2"), + )( + im.let("xᐞ5", im.op_as_fieldop("plus", domain)("xᐞ3", "xᐞ4"))( + im.op_as_fieldop("plus", domain)("xᐞ5", "x") + ) + ) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + sdfg(a, b, **FSYMBOLS) + assert np.allclose(b, ref) + + def test_gtir_let_lambda(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) subdomain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) @@ -1722,7 +1811,7 @@ def test_gtir_let_lambda_with_cond(): def test_gtir_let_lambda_with_tuple1(): - domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) testee = gtir.Program( id="let_lambda_with_tuple1", function_definitions=[], @@ -1753,10 +1842,12 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) + a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) + b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) sdfg(a, b, *z_fields, 
**FSYMBOLS) - assert np.allclose(z_fields[0], a) - assert np.allclose(z_fields[1], b) + assert np.allclose(z_fields[0], a_ref) + assert np.allclose(z_fields[1], b_ref) def test_gtir_let_lambda_with_tuple2(): From f6c219bd989e3c5325da1173bade4bff2ac9e650 Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 26 Nov 2024 15:59:58 +0100 Subject: [PATCH 16/43] bug[next]: Fix SetAt type inference for ts.DeferredType (#1747) Fix to correctly handle tuples of ts.DeferredType. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/type_system/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 987eb0f308..249019769b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,10 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # the target can have fewer elements than the expr in which case the output from the # expression is simply discarded. expr_type = functools.reduce( - lambda tuple_type, i: tuple_type.types[i], # type: ignore[attr-defined] # format ensured by primitive_constituents + lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents + # `ts.DeferredType` only occurs for scans returning a tuple + if not isinstance(tuple_type, ts.DeferredType) + else ts.DeferredType(constraint=None), path, node.expr.type, ) From f6c0498dbffd85a80a32281e5a53bfb35e00e745 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 27 Nov 2024 09:55:46 +0100 Subject: [PATCH 17/43] feat[next][dace]: Lowering to SDFG of index builtin (#1751) Implements the lowering to SDFG of the GTIR index builtin. 
--- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 ++++ .../gtir_builtin_translators.py | 83 ++++++++++++++++--- .../runners/dace_fieldview/gtir_sdfg.py | 2 + tests/next_tests/definitions.py | 1 - .../dace_tests/test_gtir_to_sdfg.py | 50 ++++++++++- 5 files changed, 134 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 2864c7f727..a4e111e785 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -519,6 +519,20 @@ def _impl(it: itir.Expr) -> itir.FunCall: return _impl +def index(dim: common.Dimension) -> itir.FunCall: + """ + Create a call to the `index` builtin, shorthand for `call("index")(axis)`, + after converting the given dimension to `itir.AxisLiteral`. + + Args: + dim: the dimension corresponding to the index axis. + + Returns: + A function that constructs a Field of indices in the given dimension. + """ + return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 60dcd8ddc9..94ab3a6f76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from 
gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -277,20 +277,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain: the corresponding lower and upper bounds. The returned lower bound is inclusive, the upper bound is exclusive: [lower_bound, upper_bound[ """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - lower_bound, upper_bound = ( - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ) - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, lower_bound, upper_bound)) + + def parse_range_boundary(expr: gtir.Expr) -> str: + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + elif isinstance(node, domain_utils.SymbolicDomain): + assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + for dim, drange in node.ranges.items(): + domain.append( + (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ) + + else: + raise ValueError(f"Invalid domain {node}.") return domain @@ -545,6 +556,51 @@ def construct_output(inner_data: FieldopData) -> FieldopData: return result_temps +def translate_index( + node: gtir.Node, + sdfg: dace.SDFG, + state: 
dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers the `index` builtin function to a mapped tasklet that writes the dimension + index values to a transient array. The extent of the index range is taken from + the domain information that should be present in the node annex. + """ + assert "domain" in node.annex + domain = extract_domain(node.annex.domain) + assert len(domain) == 1 + dim, lower_bound, upper_bound = domain[0] + dim_index = dace_gtir_utils.get_map_variable(dim) + + field_dims, field_offset, field_shape = _get_field_layout(domain) + field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) + + output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) + output_node = state.add_access(output) + + sdfg_builder.add_mapped_tasklet( + "index", + state, + map_ranges={ + dim_index: f"{lower_bound}:{upper_bound}", + }, + inputs={}, + code=f"__val = {dim_index}", + outputs={ + "__val": dace.Memlet( + data=output_node.data, + subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), + ) + }, + input_nodes={}, + output_nodes={output_node.data: output_node}, + external_edges=True, + ) + + return FieldopData(output_node, field_type, field_offset) + + def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, @@ -777,6 +833,7 @@ def translate_symbol_ref( translate_as_fieldop, translate_broadcast_scalar, translate_if, + translate_index, translate_literal, translate_make_tuple, translate_tuple_get, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f15287e64c..6b5e164458 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -568,6 +568,8 @@ def visit_FunCall( # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): return 
gtir_builtin_translators.translate_if(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "index"): + return gtir_builtin_translators.translate_index(node, sdfg, head_state, self) elif cpm.is_call_to(node, "make_tuple"): return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) elif cpm.is_call_to(node, "tuple_get"): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 01fd18897d..349d3e9f70 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index f5191fbaaa..c7466b853f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -12,15 +12,15 @@ Note: this test module covers the fieldview flavour of ITIR. 
""" -import copy import functools import numpy as np import pytest -from gt4py.next import common as gtx_common, constructors +from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -1973,3 +1973,49 @@ def test_gtir_if_values(): sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) + + +def test_gtir_index(): + MARGIN = 2 + assert MARGIN < N + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + subdomain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + + testee = gtir.Program( + id="gtir_cast", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("i", im.index(IDim))( + im.op_as_fieldop("plus", domain)( + "i", + im.as_fieldop( + im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain + )("i"), + ) + ), + domain=subdomain, + target=gtir.SymRef(id="x"), + ) + ], + ) + + v = np.empty(N, dtype=np.int32) + + # we need to run domain inference in order to add the domain annex information to the index node. 
+ testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + ref = np.concatenate( + (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) + ) + + sdfg(v, **FSYMBOLS) + np.allclose(v, ref) From 3ece412f0d78f32893d8f01ed0e74c8b38388854 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 28 Nov 2024 13:13:55 -0500 Subject: [PATCH 18/43] fix[cartesian]: Deactivate K offset write in `gt:gpu` (#1755) Following the issue logged as https://github.com/GridTools/gt4py/issues/1754 we are deactivating the K-offset write feature until we can figure out why it's failing. I will monitor any activity on the ticket if users are hit by this. --------- Co-authored-by: Hannes Vogt --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 7 +++++++ .../multi_feature_tests/test_code_generation.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index ade05921ef..f155ea6209 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1460,6 +1460,13 @@ def visit_Assign(self, node: ast.Assign) -> list: loc=nodes.Location.from_ast_node(t), ) + if self.backend_name in ["gt:gpu"]: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains." 
+ "Please refer to https://github.com/GridTools/gt4py/issues/1754.", + loc=nodes.Location.from_ast_node(t), + ) + if not self._is_known(name): if name in self.temp_decls: field_decl = self.temp_decls[name] diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index c4d07d7337..7c4956b3ef 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -667,6 +667,10 @@ def test_K_offset_write_conditional(backend): pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) From 886058496c1ebcb90ba530a796213d1fec7c7095 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 08:46:06 +0100 Subject: [PATCH 19/43] refact[next][dace]: Helper function for field operator constructor (#1743) Includes refactoring of the code for construction of field operators, in order to make it usable by the three lowering functions that construct fields: `translate_as_fieldop()`, `translate_broadcast_scalar()`, and `translate_index()`. 
--- .../gtir_builtin_translators.py | 242 +++++++----------- 1 file changed, 94 insertions(+), 148 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 94ab3a6f76..ff011c4193 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,11 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -229,40 +233,75 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_temporary_field( +def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, node_type: ts.FieldType, - dataflow_output: gtir_dataflow.DataflowOutputEdge, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Sequence[gtir_dataflow.DataflowInputEdge], + output_edge: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: - """Helper method to allocate a temporary field where to write the output of a field operator.""" + """ + Helper method to allocate a temporary field to store the output of a field operator. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. 
+ node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edge: Edge representing the dataflow output data. + + Returns: + The field data descriptor, which includes the field access node in the given `state` + and the field domain offset. + """ field_dims, field_offset, field_shape = _get_field_layout(domain) + field_indices = _get_domain_indices(field_dims, field_offset) + + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - output_desc = dataflow_output.result.dc_node.desc(sdfg) - if isinstance(output_desc, dace.data.Array): + field_subset = sbs.Range.from_indices(field_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == node_type.dtype + assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + field_dtype = output_edge.result.gt_dtype + else: assert isinstance(node_type.dtype, itir_ts.ListType) - assert isinstance(node_type.dtype.element_type, ts.ScalarType) - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) + assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type # extend the array with the local dimensions added by the field operator (e.g. 
`neighbors`) - field_offset.extend(output_desc.offset) - field_shape.extend(output_desc.shape) - elif isinstance(output_desc, dace.data.Scalar): - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - else: - raise ValueError(f"Cannot create field for dace type {output_desc}.") + assert output_edge.result.gt_dtype.offset_type is not None + field_dims.append(output_edge.result.gt_dtype.offset_type) + field_shape.extend(dataflow_output_desc.shape) + field_offset.extend(dataflow_output_desc.offset) + field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) - field_node = state.add_access(temp_name) + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) - if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): - field_dtype = dataflow_output.result.gt_dtype - else: - assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = dataflow_output.result.gt_dtype.element_type - assert dataflow_output.result.gt_dtype.offset_type is not None - field_dims.append(dataflow_output.result.gt_dtype.offset_type) + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(me) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) return FieldopData( field_node, @@ -341,7 +380,8 @@ def translate_as_fieldop( # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. 
It results in broadcasting the scalar value # over the field domain. - return translate_broadcast_scalar(node, sdfg, state, sdfg_builder) + stencil_expr = im.lambda_("a")(im.deref("a")) + stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] else: raise NotImplementedError( f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." @@ -349,117 +389,18 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_dims, domain_offsets, _ = zip(*domain) - domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.result.dc_node.desc(sdfg) - - if isinstance(node.type.dtype, itir_ts.ListType): - assert isinstance(output_desc, dace.data.Array) - # additional local dimension for neighbors - # TODO(phimuell): Investigate if we should swap the two. 
- output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) - else: - assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_indices) - - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - ) - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output) - - # here we setup the edges from the map entry node - for edge in input_edges: - edge.connect(me) - - # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.dc_node, output_subset) - - return result_field - + input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) -def translate_broadcast_scalar( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, -) -> FieldopResult: - """ - Generates the dataflow subgraph for the 'as_fieldop' builtin function for the - special case where the argument to 'as_fieldop' is a 'deref' scalar expression, - rather than a lambda function. This case corresponds to broadcasting the scalar - value over the field domain. Therefore, it is lowered to a mapped tasklet that - just writes the scalar value out to all elements of the result field. 
- """ - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - assert cpm.is_ref_to(stencil_expr, "deref") - - domain = extract_domain(domain_expr) - output_dims, output_offset, output_shape = _get_field_layout(domain) - output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) - - assert len(node.args) == 1 - scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) - - if isinstance(node.args[0].type, ts.ScalarType): - assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr)) - input_subset = ( - str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0" - ) - input_node = scalar_expr.dc_node - gt_dtype = node.args[0].type - elif isinstance(node.args[0].type, ts.FieldType): - assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) - if len(node.args[0].type.dims) == 0: # zero-dimensional field - input_subset = "0" - else: - input_subset = scalar_expr.get_memlet_subset(sdfg) - - input_node = scalar_expr.field - gt_dtype = node.args[0].type.dtype - else: - raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - - output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( - "broadcast", - state, - map_ranges={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, - code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, - input_nodes={input_node.data: input_node}, - output_nodes={output_node.data: output_node}, - external_edges=True, + return _create_field_operator( + sdfg, 
state, domain, node.type, sdfg_builder, input_edges, output_edge ) - return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) - def translate_if( node: gtir.Node, @@ -567,38 +508,44 @@ def translate_index( index values to a transient array. The extent of the index range is taken from the domain information that should be present in the node annex. """ + assert cpm.is_call_to(node, "index") + assert isinstance(node.type, ts.FieldType) + assert "domain" in node.annex domain = extract_domain(node.annex.domain) assert len(domain) == 1 - dim, lower_bound, upper_bound = domain[0] + dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) - - output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( + index_data = sdfg.temp_data_name() + sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_node = state.add_access(index_data) + index_value = gtir_dataflow.ValueExpr( + dc_node=index_node, + gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE), + ) + index_write_tasklet = sdfg_builder.add_tasklet( "index", state, - map_ranges={ - dim_index: f"{lower_bound}:{upper_bound}", - }, inputs={}, + outputs={"__val"}, code=f"__val = {dim_index}", - outputs={ - "__val": dace.Memlet( - data=output_node.data, - subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), - ) - }, - input_nodes={}, - output_nodes={output_node.data: output_node}, - external_edges=True, + ) + state.add_edge( + index_write_tasklet, + "__val", + index_node, + None, + dace.Memlet(data=index_data, subset="0"), ) - return FieldopData(output_node, field_type, field_offset) + input_edges = [ + gtir_dataflow.EmptyInputEdge(state, index_write_tasklet), + ] + output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) + 
return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + ) def _get_data_nodes( @@ -831,7 +778,6 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, - translate_broadcast_scalar, translate_if, translate_index, translate_literal, From d9b38f476ee5df1995d27b7497037f3f19c9b6e6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 29 Nov 2024 02:50:43 -0500 Subject: [PATCH 20/43] hotfix[cartesian]: Fixing k offset write utest deactivate (#1757) Missed a utest in #1755 --- .../multi_feature_tests/test_code_generation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 7c4956b3ef..e51b3ef09d 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -582,13 +582,17 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -660,7 +664,7 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def 
test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: From 791f67d031127872fc6375819267f59faeaf85ba Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 10:02:34 +0100 Subject: [PATCH 21/43] test[next]: Fix flaky failure in GTIR to SDFG tests (#1759) The SDFG name has to be unique to avoid issues with parallel build in CI tests. --- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index c7466b853f..b1ba4ccf22 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1984,7 +1984,7 @@ def test_gtir_index(): ) testee = gtir.Program( - id="gtir_cast", + id="gtir_index", function_definitions=[], params=[ gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), From 04513ba859d5ed55ea99999f6fd826a2a542a627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 29 Nov 2024 13:57:10 +0100 Subject: [PATCH 22/43] fix[next]: use current working directory as default cache folder root (#1744) Change the root folder of the gt4py cache directory from the system temp folder to the current working directory, which is more visible and also avoids polluting shared filesystems in hpc clusters. 
--------- Co-authored-by: Hannes Vogt --- src/gt4py/next/config.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..7a19f3eb9d 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -11,7 +11,6 @@ import enum import os import pathlib -import tempfile from typing import Final @@ -51,25 +50,22 @@ def env_flag_to_bool(name: str, default: bool) -> bool: ) -_PREFIX: Final[str] = "GT4PY" - #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) #: Verbose flag for DSL compilation errors VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False + "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False ) #: Where generated code projects should be persisted. #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) - / "gt4py_cache" + pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" ) @@ -77,11 +73,11 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() + os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() ] #: Build type to be used when CMake is used to compile generated code. 
#: Might have no effect when CMake is not used as part of the toolchain. CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() + os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() ] From d581060e5c6e8b6f64b72cce041d539956ca4727 Mon Sep 17 00:00:00 2001 From: SF-N Date: Sat, 30 Nov 2024 09:39:26 +0100 Subject: [PATCH 23/43] bug[next]: ConstantFolding after create_global_tmps (#1756) Do `ConstantFolding` within `domain_union` to avoid nested minima and maxima by `create_global_tmps` --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f5625b509c..4a023f7535 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) From a26d91f409ea5d67f168bbbc4a2157df2ed1080b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 21:31:13 +0100 Subject: [PATCH 24/43] 
fix[next]: Fix annex & type preservation in inline_lambdas (#1760) Co-authored-by: SF-N --- src/gt4py/next/iterator/transforms/inline_lambdas.py | 11 +++++------ src/gt4py/next/iterator/transforms/remap_symbols.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 7 +++++-- .../transforms_tests/test_inline_lambdas.py | 7 +++++++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 5ec9ec5d0b..9053214b39 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -97,7 +97,6 @@ def new_name(name): if all(eligible_params): new_expr.location = node.location - return new_expr else: new_expr = ir.FunCall( fun=ir.Lambda( @@ -111,11 +110,11 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) - for attr in ("type", "recorded_shifts", "domain"): - if hasattr(node.annex, attr): - setattr(new_expr.annex, attr, getattr(node.annex, attr)) - itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) - return new_expr + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 08d896121d..fb909dc5d0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,6 +10,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as type_inference class RemapSymbolRefs(PreserveLocationVisitor, 
NodeTranslator): @@ -46,7 +47,9 @@ def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True) + return new_ref return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 249019769b..ffca6cc7a7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. 
""" assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) # type: ignore[arg-type] + if from_.type is None: + assert allow_untyped + return + _set_node_type(to, from_.type) def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 2e0a83d33b..c10d48ad06 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -84,3 +84,10 @@ def test_inline_lambda_args(): ) inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True) assert inlined == expected + + +def test_type_preservation(): + testee = im.let("a", "b")("a") + testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + inlined = InlineLambdas.apply(testee) + assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) From 99c53004663b0b58c7ce8335bcc30e347d3686b5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 22:08:39 +0100 Subject: [PATCH 25/43] refactor[next]: Use `set_at` & `as_fieldop` instead of `closure` in iterator tests (#1691) --- .../test_cartesian_offset_provider.py | 12 +++--- .../iterator_tests/test_conditional.py | 2 +- .../test_strided_offset_provider.py | 7 ++-- .../iterator_tests/test_trivial.py | 10 ++--- .../iterator_tests/test_tuple.py | 28 +++++-------- .../iterator_tests/test_anton_toy.py | 21 +++++----- .../iterator_tests/test_fvm_nabla.py | 40 ++++++++----------- .../iterator_tests/test_hdiff.py | 10 ++--- 8 files changed, 55 insertions(+), 75 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py 
b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 2ebcd0c033..fedfd83fd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -10,7 +10,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import double_roundtrip, roundtrip @@ -27,16 +27,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) def test_cartesian_offset_provider(): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 551c567e61..eae66d425b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from 
next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bde55bfd2..68e5f9d532 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -10,8 +10,8 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor from gt4py.next.iterator.embedded import StridedConnectivityField @@ -36,7 +36,8 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) + domain = unstructured_domain(named_range(LocA, 0, size)) + set_at(as_fieldop(foo, domain)(inp), domain, out) @pytest.mark.uses_strided_neighbor_offset diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 5f1c70a6b3..fe89fe7c9d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from 
gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -94,12 +94,8 @@ def test_shifted_arg_to_lift(program_processor): @fendef def fen_direct_deref(i_size, j_size, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), - deref, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)) + set_at(as_fieldop(deref, domain)(inp), domain, out) def test_direct_deref(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 2d84439c93..39d0bd69c3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil): @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, - make_tuple(out1, out2), - [inp1, inp2], + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) ) + set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2)) shape = [5, 7, 9] rng = np.random.default_rng() @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3): @fendef def fencil(size0, size1, size2, inp1, 
inp2, inp3, out1, out2, out3): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) + ) + set_at( + as_fieldop(stencil, domain)(inp1, inp2, inp3), + domain, make_tuple(make_tuple(out1, out2), out3), - [inp1, inp2, inp3], ) shape = [5, 7, 9] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 3ce9d6b470..d0a1601816 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -10,8 +10,15 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import ( + cartesian_domain, + deref, + lift, + named_range, + shift, + as_fieldop, +) +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import program_processor, run_processor @@ -85,14 +92,10 @@ def test_anton_toy(stencil, program_processor): @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): - closure( - cartesian_domain( - named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) - ), - stencil, - out, - [inp], + domain = cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) ) + set_at(as_fieldop(stencil, domain)(inp), domain, out) shape = [5, 7, 9] rng = np.random.default_rng() diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 4487681abf..22b4d8b3c5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -28,8 +28,9 @@ reduce, tuple_get, unstructured_domain, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -55,7 +56,8 @@ def compute_zavgS(pp, S_M): @fendef def compute_zavgS_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out) @fundef @@ -100,12 +102,8 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - pnabla, - out, - [pp, S_MXX, S_MYY, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -145,7 +143,8 @@ def test_compute_zavgS(program_processor): @fendef def compute_zavgS2_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out) @pytest.mark.requires_atlas @@ -212,12 +211,8 @@ def test_nabla(program_processor): @fendef def nabla2(n_nodes, out, pp, S, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, 
n_nodes)), - compute_pnabla2, - out, - [pp, S, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -276,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): # TODO replace by single stencil which returns tuple - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), + domain, out_MXX, - [pp, S_MXX, vol, node_index, is_pole_edge], ) - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge), + domain, out_MYY, - [pp, S_MYY, vol, node_index, is_pole_edge], ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 45793b1d3e..e44e92013f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -57,12 +57,8 @@ def hdiff_sten(inp, coeff): @fendef(offset_provider={"I": IDim, "J": JDim}) def hdiff(inp, coeff, out, x, y): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), - 
hdiff_sten, - out, - [inp, coeff], - ) + domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)) + set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) @pytest.mark.uses_origin From 6f49699f00ceb9e466fa4448bab779bc061df047 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 2 Dec 2024 13:09:47 +0100 Subject: [PATCH 26/43] style[eve]: remove unused imports and fix typos (#1748) Small cleanup PR in the eve framework: - Removes a stale `.gitignore` file. As far as I understood from the git history, earlier versions of this codebase had many `.gitignore` files in many places. Looks like this one is a leftover from a previous time. - Remove a couple of stale includes. The language server marked them as unused and since tests still pass, I guess we really don't need them anymore. - Fixed a couple of typos in comments - Fixed two typos in the github PR template --- .github/pull_request_template.md | 4 ++-- src/gt4py/eve/.gitignore | 1 - src/gt4py/eve/__init__.py | 14 ++------------ src/gt4py/eve/codegen.py | 6 +++--- src/gt4py/eve/datamodels/__init__.py | 4 ++-- src/gt4py/eve/datamodels/core.py | 16 ++++++++-------- src/gt4py/eve/extended_typing.py | 4 ---- src/gt4py/eve/trees.py | 8 -------- src/gt4py/eve/type_validation.py | 2 +- src/gt4py/eve/utils.py | 2 +- src/gt4py/next/ffront/decorator.py | 2 +- 11 files changed, 20 insertions(+), 43 deletions(-) delete mode 100644 src/gt4py/eve/.gitignore diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7284a7df04..83304a9c62 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -15,7 +15,7 @@ Delete this comment and add a proper description of the changes contained in thi - test: Adding missing tests or correcting existing tests : cartesian | eve | next | storage - # ONLY if changes are limited to a specific subsytem + # ONLY if changes are limited to a specific subsystem - PR Description: @@ -27,7 +27,7 @@ Delete this 
comment and add a proper description of the changes contained in thi ## Requirements - [ ] All fixes and/or new features come with corresponding tests. -- [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. +- [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. If this PR contains code authored by new contributors please make sure: diff --git a/src/gt4py/eve/.gitignore b/src/gt4py/eve/.gitignore deleted file mode 100644 index 050cda3ca5..0000000000 --- a/src/gt4py/eve/.gitignore +++ /dev/null @@ -1 +0,0 @@ -_version.py diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 0b8cfa7d62..5adac47da3 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -24,8 +24,7 @@ """ -from __future__ import annotations # isort:skip - +from __future__ import annotations from .concepts import ( AnnexManager, @@ -89,15 +88,6 @@ "SymbolRef", "VType", "register_annex_user", - "# datamodels" "Coerced", - "DataModel", - "FrozenModel", - "GenericDataModel", - "Unchecked", - "concretize", - "datamodel", - "field", - "frozenmodel", # datamodels "Coerced", "DataModel", @@ -122,7 +112,7 @@ "pre_walk_values", "walk_items", "walk_values", - "# type_definition", + # type_definitions "NOTHING", "ConstrainedStr", "Enum", diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 15fda4f3b4..3869ff313b 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -347,7 +347,7 @@ def __str__(self) -> str: class Template(Protocol): """Protocol (abstract base class) defining the Template interface. - Direct subclassess of this base class only need to implement the + Direct subclasses of this base class only need to implement the abstract methods to adapt different template engines to this interface. 
@@ -654,8 +654,8 @@ def apply( # redefinition of symbol Args: root: An IR node. - node_templates (optiona): see :class:`NodeDumper`. - dump_function (optiona): see :class:`NodeDumper`. + node_templates (optional): see :class:`NodeDumper`. + dump_function (optional): see :class:`NodeDumper`. ``**kwargs`` (optional): custom extra parameters forwarded to `visit_NODE_TYPE_NAME()`. Returns: diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 68ddea2510..6fd9c7bb21 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -11,7 +11,7 @@ Data Models can be considered as enhanced `attrs `_ / `dataclasses `_ providing additional features like automatic run-time type validation. Values assigned to fields -at initialization can be validated with automatic type checkings using the +at initialization can be validated with automatic type checking using the field type definition. Custom field validation methods can also be added with the :func:`validator` decorator, and global instance validation methods with :func:`root_validator`. @@ -33,7 +33,7 @@ 1. ``__init__()``. a. If a custom ``__init__`` already exists in the class, it will not be overwritten. - It is your responsability to call ``__auto_init__`` from there to obtain + It is your responsibility to call ``__auto_init__`` from there to obtain the described behavior. b. If there is not custom ``__init__``, the one generated by datamodels will be called first. 
diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index d596f59cfb..1b0e995156 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -24,7 +24,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz @@ -270,7 +270,7 @@ def datamodel( @overload -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Type[_T], /, *, @@ -289,7 +289,7 @@ def datamodel( # redefinion of unused symbol # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Optional[Type[_T]] = None, /, *, @@ -867,7 +867,7 @@ def _substitute_typevars( def _make_counting_attr_from_attribute( field_attrib: Attribute, *, include_type: bool = False, **kwargs: Any -) -> Any: # attr.s lies a bit in some typing definitons +) -> Any: # attr.s lies a bit in some typing definitions args = [ "default", "validator", @@ -965,7 +965,7 @@ def _type_converter(value: Any) -> _T: return value if isinstance(value, type_annotation) else type_annotation(value) except Exception as error: raise TypeError( - f"Error during coertion of given value '{value}' for field '{name}'." + f"Error during coercion of given value '{value}' for field '{name}'." ) from error return _type_converter @@ -996,7 +996,7 @@ def _type_converter(value: Any) -> _T: return _make_type_converter(origin_type, name) raise exceptions.EveTypeError( - f"Automatic type coertion for {type_annotation} types is not supported." + f"Automatic type coercion for {type_annotation} types is not supported." 
) @@ -1085,7 +1085,7 @@ def _make_datamodel( ) else: - # Create field converter if automatic coertion is enabled + # Create field converter if automatic coercion is enabled converter: TypeConverter = cast( TypeConverter, _make_type_converter(type_hint, qualified_field_name) if coerce_field else None, @@ -1099,7 +1099,7 @@ def _make_datamodel( if isinstance(attr_value_in_cls, _KNOWN_MUTABLE_TYPES): warnings.warn( f"'{attr_value_in_cls.__class__.__name__}' value used as default in '{cls.__name__}.{key}'.\n" - "Mutable types should not defbe normally used as field defaults (use 'default_factory' instead).", + "Mutable types should not be used as field defaults (use 'default_factory' instead).", stacklevel=_stacklevel_offset + 2, ) setattr( diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index e276f3bccf..bf44824b49 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -14,12 +14,8 @@ from __future__ import annotations -import abc as _abc import array as _array -import collections.abc as _collections_abc -import ctypes as _ctypes import dataclasses as _dataclasses -import enum as _enum import functools as _functools import inspect as _inspect import mmap as _mmap diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index c8e8658413..8a3cc30f4b 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -31,14 +31,6 @@ from .type_definitions import Enum -try: - # For performance reasons, try to use cytoolz when possible (using cython) - import cytoolz as toolz -except ModuleNotFoundError: - # Fall back to pure Python toolz - import toolz # noqa: F401 [unused-import] - - TreeKey = Union[int, str] diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..e150832295 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -311,7 +311,7 @@ def __call__( # ... 
# # Since this can be an arbitrary type (not something regular like a collection) there is - # no way to check if the type parameter is verifed in the actual instance. + # no way to check if the type parameter is verified in the actual instance. # The only check can be done at run-time is to verify that the value is an instance of # the original type, completely ignoring the annotation. Ideally, the static type checker # can do a better job to try figure out if the type parameter is ok ... diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8cb68845d7..2c66d39290 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -69,7 +69,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9ce07d01bb..61756f30c9 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -230,7 +230,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend." + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." 
), stacklevel=2, ) From f57d6e916e17ee2ff574ba6096ccc21911d27533 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 2 Dec 2024 20:02:44 +0100 Subject: [PATCH 27/43] fix[next]: Guard diskcache creation by file lock (#1745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The disk cache used to cache compilation in the gtfn backend has a race condition manifesting itself in `sqlite3.OperationalError: database is locked` errors if multiple python processes try to initialize the `diskcache.Cache` object concurrently. This PR fixes this by guarding the object creation by a file-based lock in the same directory as the database. While this issue occurred frequently and was observed to be fixed on distributed file systems, the lock does not guarantee correct behavior in particular for accesses to the cache (beyond opening) since the underlying SQLite database is unreliable when stored on an NFS based file system. It does however ensure correctness of concurrent cache accesses on a local file system. 
See more information here: https://grantjenks.com/docs/diskcache/tutorial.html#settings https://www.sqlite.org/faq.html#q5 https://github.com/tox-dev/filelock/issues/73 NFS safe locking: https://gitlab.com/warsaw/flufl.lock [Barry Warsaw / FLUFL Lock · GitLab](https://gitlab.com/warsaw/flufl.lock) --- .pre-commit-config.yaml | 1 + constraints.txt | 8 ++--- min-extra-requirements-test.txt | 1 + min-requirements-test.txt | 1 + pyproject.toml | 1 + requirements-dev.txt | 8 ++--- .../next/program_processors/runners/gtfn.py | 32 ++++++++++++++++--- 7 files changed, 40 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c3b6e693f..7e1870c67f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -102,6 +102,7 @@ repos: - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 + - filelock==3.16.1 - frozendict==2.4.6 - gridtools-cpp==2.3.8 - importlib-resources==6.4.5 diff --git a/constraints.txt b/constraints.txt index b4b8bc00d4..f039fa2125 100644 --- a/constraints.txt +++ b/constraints.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat -filelock==3.16.1 # via tox, virtualenv +filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via matplotlib fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.9.2 # via bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via pydantic +pydantic==2.10.0 # via bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via pydantic pydantic-settings==2.6.1 # via bump-my-version pydot==3.0.2 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, 
ipython, nbmake, rich, sphinx @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.3 # via -r requirements-dev.in +tach==0.14.4 # via -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 57c0d3969d..d7679a1f0f 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,6 +67,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 81a1c2dea3..cf505e88d6 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,6 +63,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/pyproject.toml b/pyproject.toml index 02d301957c..1e24094fa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ 'devtools>=0.6', 'diskcache>=5.6.3', 'factory-boy>=3.3.0', + 'filelock>=3.0.0', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", diff --git a/requirements-dev.txt b/requirements-dev.txt index 9f95779fd5..6542be36f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat -filelock==3.16.1 # via -c constraints.txt, tox, virtualenv 
+filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via -c constraints.txt, matplotlib fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.9.2 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via -c constraints.txt, pydantic +pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via -c constraints.txt, pydantic pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version pydot==3.0.2 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 1f3778f227..55f479c665 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,11 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause import functools 
+import pathlib +import tempfile import warnings from typing import Any, Optional import diskcache import factory +import filelock import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators @@ -139,13 +142,34 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: class FileCache(diskcache.Cache): """ - This class extends `diskcache.Cache` to ensure the cache is closed upon deletion, - i.e. it ensures that any resources associated with the cache are properly - released when the instance is garbage collected. + This class extends `diskcache.Cache` to ensure the cache is properly + - opened when accessed by multiple processes using a file lock. This guards the creating of the + cache object, which has been reported to cause `sqlite3.OperationalError: database is locked` + errors and slow startup times when multiple processes access the cache concurrently. While this + issue occurred frequently and was observed to be fixed on distributed file systems, the lock + does not guarantee correct behavior in particular for accesses to the cache (beyond opening) + since the underlying SQLite database is unreliable when stored on an NFS based file system. + It does however ensure correctness of concurrent cache accesses on a local file system. See + #1745 for more details. + - closed upon deletion, i.e. it ensures that any resources associated with the cache are + properly released when the instance is garbage collected. 
""" + def __init__(self, directory: Optional[str | pathlib.Path] = None, **settings: Any) -> None: + if directory: + lock_dir = pathlib.Path(directory).parent + else: + lock_dir = pathlib.Path(tempfile.gettempdir()) + + lock = filelock.FileLock(lock_dir / "file_cache.lock") + with lock: + super().__init__(directory=directory, **settings) + + self._init_complete = True + def __del__(self) -> None: - self.close() + if getattr(self, "_init_complete", False): # skip if `__init__` didn't finished + self.close() class GTFNCompileWorkflowFactory(factory.Factory): From e5abcd20839e35c5480b512e1c2ef9b6f01c60e4 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:55:53 +0100 Subject: [PATCH 28/43] bug[next]: Fix codegen in gtfn for unused vertical offset provider (#1746) Providing an offest provider for a vertical dimension without using that dimension in a program, e.g. no arguments are fields defined on K, resulted in erroneous C++ code. --- .../codegens/gtfn/itir_to_gtfn_ir.py | 3 +++ tests/next_tests/integration_tests/cases.py | 10 +++++++++- .../ffront_tests/test_execution.py | 15 +++++++++++++++ .../ffront_tests/test_gt4py_builtins.py | 17 ++++++++++------- 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 129d81d6f9..dc0012b041 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -198,6 +198,9 @@ def _collect_offset_definitions( "Mapping an offset to a horizontal dimension in unstructured is not allowed." 
) # create alias from vertical offset to vertical dimension + offset_definitions[dim.value] = TagDefinition( + name=Sym(id=dim.value), alias=_vertical_dimension + ) offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 9fb7850666..759cd1cf1f 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -499,13 +499,21 @@ def unstructured_case( Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, ) +@pytest.fixture +def unstructured_case_3d(unstructured_case): + return dataclasses.replace( + unstructured_case, + default_sizes={**unstructured_case.default_sizes, KDim: 10}, + offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + ) + + def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1a51e3667d..0d994d1b22 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -41,6 +41,7 @@ Edge, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -93,6 +94,20 @@ def testee(a: cases.VField) -> cases.EField: ) +def test_horizontal_only_with_3d_mesh(unstructured_case_3d): + # test field operator operating only on horizontal fields while using an offset provider + # including a vertical dimension. 
+ @gtx.field_operator + def testee(a: cases.VField) -> cases.VField: + return a + + cases.verify_with_default_data( + unstructured_case_3d, + testee, + ref=lambda a: a, + ) + + @pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 7648d34db7..ab1c625fef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -29,6 +29,7 @@ Vertex, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -105,10 +106,10 @@ def reduction_ke_field( @pytest.mark.parametrize( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) -def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].ndarray +def test_neighbor_sum(unstructured_case_3d, fop): + v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray - edge_f = cases.allocate(unstructured_case, fop, "edge_f")() + edge_f = cases.allocate(unstructured_case_3d, fop, "edge_f")() local_dim_idx = edge_f.domain.dims.index(Edge) + 1 adv_indexing = tuple( @@ -131,10 +132,10 @@ def test_neighbor_sum(unstructured_case, fop): where=broadcasted_table != common._DEFAULT_SKIP_VALUE, ) cases.verify( - unstructured_case, + unstructured_case_3d, fop, edge_f, - out=cases.allocate(unstructured_case, fop, cases.RETURN)(), + out=cases.allocate(unstructured_case_3d, fop, cases.RETURN)(), ref=ref, ) @@ -463,11 +464,13 @@ def conditional_program( ) -def test_promotion(unstructured_case): +def test_promotion(unstructured_case_3d): @gtx.field_operator def promotion( inp1: 
gtx.Field[[Edge, KDim], float64], inp2: gtx.Field[[KDim], float64] ) -> gtx.Field[[Edge, KDim], float64]: return inp1 / inp2 - cases.verify_with_default_data(unstructured_case, promotion, ref=lambda inp1, inp2: inp1 / inp2) + cases.verify_with_default_data( + unstructured_case_3d, promotion, ref=lambda inp1, inp2: inp1 / inp2 + ) From a2551acc0cf832ed9628b2930264e1d3998cebbf Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 3 Dec 2024 21:51:05 +0100 Subject: [PATCH 29/43] feat[next]: Remove dace_iterator backend and pass_manager_legacy (#1753) The dace orchestration tests are temporarily skipped until #1742 is merged. The dace backend with SDFG optimization is temporarily disabled in unit tests until #1639 is merged. A second PR will reorganize the files in dace backend module. --- .../transforms/pass_manager_legacy.py | 181 -- .../next/program_processors/runners/dace.py | 62 +- .../runners/dace_common/dace_backend.py | 30 +- .../runners/dace_common/utility.py | 9 +- .../runners/dace_common/workflow.py | 2 +- .../runners/dace_iterator/__init__.py | 377 ---- .../runners/dace_iterator/itir_to_sdfg.py | 809 --------- .../runners/dace_iterator/itir_to_tasklet.py | 1564 ----------------- .../runners/dace_iterator/utility.py | 149 -- .../runners/dace_iterator/workflow.py | 150 -- tests/next_tests/definitions.py | 41 +- .../feature_tests/dace/test_orchestration.py | 37 +- .../ffront_tests/ffront_test_utils.py | 4 +- 13 files changed, 74 insertions(+), 3341 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/pass_manager_legacy.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/__init__.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/utility.py delete mode 100644 
src/gt4py/next/program_processors/runners/dace_iterator/workflow.py diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py deleted file mode 100644 index 94c962e92d..0000000000 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ /dev/null @@ -1,181 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause -# FIXME[#1582](tehrengruber): file should be removed after refactoring to GTIR -import enum -from typing import Callable, Optional - -from gt4py.eve import utils as eve_utils -from gt4py.next import common -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs -from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet -from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple -from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan -from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts -from gt4py.next.iterator.transforms.merge_let import MergeLet -from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction -from gt4py.next.iterator.transforms.unroll_reduce import 
UnrollReduce - - -@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - ).visit(ir) - else: - raise ValueError() - - return ir - - -def _inline_into_scan(ir, *, max_iter=10): - for _ in range(10): - # in case there are multiple levels of lambdas around the scan we have to do multiple iterations - inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") - return ir - - -def apply_common_transforms( - ir: itir.Node, - *, - lift_mode=None, - offset_provider=None, - unroll_reduce=False, - common_subexpression_elimination=True, - force_inline_lambda_args=False, - unconditionally_collapse_tuples=False, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - offset_provider_type: Optional[common.OffsetProviderType] = None, -) -> itir.Program: - assert isinstance(ir, itir.FencilDefinition) - # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps - if offset_provider_type is None: - offset_provider_type = common.offset_provider_to_type(offset_provider) - - ir = fencil_to_program.FencilToProgram().apply(ir) - icdlv_uids = eve_utils.UIDGenerator() - - if lift_mode is None: - lift_mode = LiftMode.FORCE_INLINE - assert isinstance(lift_mode, LiftMode) - ir = MergeLet().visit(ir) - ir = inline_fundefs.InlineFundefs().visit(ir) - - ir = 
inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program - ir = PropagateDeref.apply(ir) - ir = NormalizeShifts().visit(ir) - - for _ in range(10): - inlined = ir - - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil - inlined = _inline_lifts(inlined, lift_mode) - - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, - ) - inlined = ConstantFolding.apply(inlined) - # This pass is required to be in the loop such that when an `if_` call with tuple arguments - # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( - inlined, - offset_provider_type=offset_provider_type, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - # This pass is required such that a deref outside of a - # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the - # `tuple_get` is removed by the `CollapseTuple` pass. - inlined = PropagateDeref.apply(inlined) - - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - - if lift_mode != LiftMode.FORCE_INLINE: - raise NotImplementedError() - - # Since `CollapseTuple` relies on the type inference which does not support returning tuples - # larger than the number of closure outputs as given by the unconditional collapse, we can - # only run the unconditional version here instead of in the loop above. 
- if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( - ir, - ignore_tuple_size=True, - offset_provider_type=offset_provider_type, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) - - ir = NormalizeShifts().visit(ir) - - ir = FuseMaps().visit(ir) - ir = CollapseListGet().visit(ir) - - if unroll_reduce: - for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) - if unrolled == ir: - break - ir = unrolled - ir = CollapseListGet().visit(ir) - ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) - ir = NormalizeShifts().visit(ir) - else: - raise RuntimeError("Reduction unrolling failed.") - - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - - if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program - ir = MergeLet().visit(ir) - - ir = InlineLambdas.apply( - ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args - ) - - assert isinstance(ir, itir.Program) - return ir diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 95186e0b5d..1b3b930818 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -8,45 +8,34 @@ import factory +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators from gt4py.next import backend +from gt4py.next.otf import workflow from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow -from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn 
import GTFNBackendFactory -class DaCeIteratorBackendFactory(GTFNBackendFactory): +class DaCeFieldviewBackendFactory(GTFNBackendFactory): + class Meta: + model = backend.Backend + class Params: - otf_workflow = factory.SubFactory( - dace_iterator_workflow.DaCeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), + name_device = "cpu" + name_cached = "" + name_postfix = "" + gpu = factory.Trait( + allocator=next_allocators.StandardGPUFieldBufferAllocator(), + device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, + name_device="gpu", ) - auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, name_postfix="_opt" + cached = factory.Trait( + executor=factory.LazyAttribute( + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + ), + name_cached="_cached", ) - use_field_canonical_representation: bool = False - - name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" - ) - - transforms = backend.LEGACY_TRANSFORMS - - -run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) -run_dace_cpu_noopt = DaCeIteratorBackendFactory(cached=True, auto_optimize=False) - -run_dace_gpu = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=True) -run_dace_gpu_noopt = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=False) - -itir_cpu = run_dace_cpu -itir_gpu = run_dace_gpu - - -class DaCeFieldviewBackendFactory(GTFNBackendFactory): - class Params: + device_type = core_defs.DeviceType.CPU otf_workflow = factory.SubFactory( dace_fieldview_workflow.DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), @@ -55,11 +44,16 @@ class Params: auto_optimize = factory.Trait(name_postfix="_opt") name = factory.LazyAttribute( - lambda o: 
f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" + lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}" ) + executor = factory.LazyAttribute(lambda o: o.otf_workflow) + allocator = next_allocators.StandardCPUFieldBufferAllocator() transforms = backend.DEFAULT_TRANSFORMS -gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) -gtir_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) +run_dace_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=True) +run_dace_cpu_noopt = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) + +run_dace_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=True) +run_dace_gpu_noopt = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 56ba08015b..90e7e07ad5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -24,7 +24,7 @@ cp = None -def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any: +def _convert_arg(arg: Any, sdfg_param: str) -> Any: if not isinstance(arg, gtx_common.Field): return arg if len(arg.domain.dims) == 0: @@ -41,26 +41,14 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: raise RuntimeError( f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}." 
) - if not use_field_canonical_representation: - return arg.ndarray - # the canonical representation requires alphabetical ordering of the dimensions in field domain definition - sorted_dims = dace_utils.get_sorted_dims(arg.domain.dims) - ndim = len(sorted_dims) - dim_indices = [dim_index for dim_index, _ in sorted_dims] - if isinstance(arg.ndarray, np.ndarray): - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) - else: - assert cp is not None and isinstance(arg.ndarray, cp.ndarray) - return cp.moveaxis(arg.ndarray, range(ndim), dim_indices) - - -def _get_args( - sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool -) -> dict[str, Any]: + return arg.ndarray + + +def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) return { - sdfg_param: _convert_arg(arg, sdfg_param, use_field_canonical_representation) + sdfg_param: _convert_arg(arg, sdfg_param) for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True) } @@ -154,10 +142,10 @@ def get_sdfg_conn_args( def get_sdfg_args( sdfg: dace.SDFG, + offset_provider: gtx_common.OffsetProvider, *args: Any, check_args: bool = False, on_gpu: bool = False, - use_field_canonical_representation: bool = True, **kwargs: Any, ) -> dict[str, Any]: """Extracts the arguments needed to call the SDFG. @@ -166,10 +154,10 @@ def get_sdfg_args( Args: sdfg: The SDFG for which we want to get the arguments. + offset_provider: Offset provider. 
""" - offset_provider = kwargs["offset_provider"] - dace_args = _get_args(sdfg, args, use_field_canonical_representation) + dace_args = _get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 3e96ef3cec..ac15bc1cbf 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Literal, Optional, Sequence +from typing import Final, Literal, Optional import dace @@ -96,10 +96,3 @@ def filter_connectivity_types( for offset, conn in offset_provider_type.items() if isinstance(conn, gtx_common.NeighborConnectivityType) } - - -def get_sorted_dims( - dims: Sequence[gtx_common.Dimension], -) -> Sequence[tuple[int, gtx_common.Dimension]]: - """Sort list of dimensions in alphabetical order.""" - return sorted(enumerate(dims), key=lambda v: v[1].value) diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 91e83dba9d..5d9ac863c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -150,9 +150,9 @@ def decorated_program( sdfg_args = dace_backend.get_sdfg_args( sdfg, + offset_provider, *args, check_args=False, - offset_provider=offset_provider, on_gpu=on_gpu, use_field_canonical_representation=use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py deleted file mode 
100644 index ef09cf51cd..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ /dev/null @@ -1,377 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import dataclasses -import warnings -from collections import OrderedDict -from collections.abc import Callable, Sequence -from dataclasses import field -from inspect import currentframe, getframeinfo -from pathlib import Path -from typing import Any, ClassVar, Optional - -import dace -import numpy as np -from dace.sdfg import utils as sdutils -from dace.transformation.auto import auto_optimize as autoopt - -import gt4py.next.iterator.ir as itir -from gt4py.next import common -from gt4py.next.ffront import decorator -from gt4py.next.iterator import transforms as itir_transforms -from gt4py.next.iterator.ir import SymRef -from gt4py.next.iterator.transforms import ( - pass_manager_legacy as legacy_itir_transforms, - program_to_fencil, -) -from gt4py.next.iterator.type_system import inference as itir_type_inference -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .itir_to_sdfg import ItirToSDFG - - -def preprocess_program( - program: itir.FencilDefinition, - offset_provider_type: common.OffsetProviderType, - lift_mode: legacy_itir_transforms.LiftMode, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - unroll_reduce: bool = False, -): - node = legacy_itir_transforms.apply_common_transforms( - program, - common_subexpression_elimination=False, - force_inline_lambda_args=True, - lift_mode=lift_mode, - offset_provider_type=offset_provider_type, - symbolic_domain_sizes=symbolic_domain_sizes, - 
temporary_extraction_heuristics=temporary_extraction_heuristics, - unroll_reduce=unroll_reduce, - ) - - node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) - - if isinstance(node, itir.Program): - fencil_definition = program_to_fencil.program_to_fencil(node) - tmps = node.declarations - assert all(isinstance(tmp, itir.Temporary) for tmp in tmps) - else: - raise TypeError(f"Expected 'Program', got '{type(node).__name__}'.") - - return fencil_definition, tmps - - -def build_sdfg_from_itir( - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - auto_optimize: bool = False, - on_gpu: bool = False, - column_axis: Optional[common.Dimension] = None, - lift_mode: legacy_itir_transforms.LiftMode = legacy_itir_transforms.LiftMode.FORCE_INLINE, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - load_sdfg_from_file: bool = False, - save_sdfg: bool = True, - use_field_canonical_representation: bool = True, -) -> dace.SDFG: - """Translate a Fencil into an SDFG. - - Args: - program: The Fencil that should be translated. - arg_types: Types of the arguments passed to the fencil. - offset_provider: The set of offset providers that should be used. - auto_optimize: Apply DaCe's `auto_optimize` heuristic. - on_gpu: Performs the translation for GPU, defaults to `False`. - column_axis: The column axis to be used, defaults to `None`. - lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. - symbolic_domain_sizes: Used for generation of liskov bindings when temporaries are enabled. - load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only. - save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`. 
- use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically. - """ - - sdfg_filename = f"_dacegraphs/gt4py/{program.id}.sdfg" - if load_sdfg_from_file and Path(sdfg_filename).exists(): - sdfg: dace.SDFG = dace.SDFG.from_file(sdfg_filename) - sdfg.validate() - return sdfg - - # visit ITIR and generate SDFG - program, tmps = preprocess_program( - program, - offset_provider_type, - lift_mode, - symbolic_domain_sizes, - temporary_extraction_heuristics, - ) - sdfg_genenerator = ItirToSDFG( - list(arg_types), - offset_provider_type, - tmps, - use_field_canonical_representation, - column_axis, - ) - sdfg = sdfg_genenerator.visit(program) - if sdfg is None: - raise RuntimeError(f"Visit failed for program {program.id}.") - - for nested_sdfg in sdfg.all_sdfgs_recursive(): - if not nested_sdfg.debuginfo: - _, frameinfo = ( - warnings.warn( - f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg.", - stacklevel=2, - ), - getframeinfo(currentframe()), # type: ignore[arg-type] - ) - nested_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename - ) - - # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct - sdutils.inline_loop_blocks(sdfg) - - # run DaCe transformations to simplify the SDFG - sdfg.simplify() - - # run DaCe auto-optimization heuristics - if auto_optimize: - # TODO: Investigate performance improvement from SDFG specialization with constant symbols, - # for array shape and strides, although this would imply JIT compilation. - symbols: dict[str, int] = {} - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) - elif on_gpu: - autoopt.apply_gpu_storage(sdfg) - - if on_gpu: - sdfg.apply_gpu_transformations() - - # Store the sdfg such that we can later reuse it. 
- if save_sdfg: - sdfg.save(sdfg_filename) - - return sdfg - - -@dataclasses.dataclass(frozen=True) -class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible): - """Extension of GT4Py Program implementing the SDFGConvertible interface.""" - - sdfg_closure_vars: dict[str, Any] = field(default_factory=dict) - - # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs, - # there is no name mangling of the connectivity tables used across the nested SDFGs - # since they share the same memory address. - connectivity_tables_data_descriptors: ClassVar[ - dict[str, dace.data.Array] - ] = {} # symbolically defined - - def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: - if "dace" not in self.backend.name.lower(): # type: ignore[union-attr] - raise ValueError("The SDFG can be generated only for the DaCe backend.") - - params = {str(p.id): p.type for p in self.itir.params} - fields = {str(p.id): p.type for p in self.itir.params if hasattr(p.type, "dims")} - arg_types = [*params.values()] - - dace_parsed_args = [*args, *kwargs.values()] - gt4py_program_args = [*params.values()] - _crosscheck_dace_parsing(dace_parsed_args, gt4py_program_args) - - if self.connectivities is None: - raise ValueError( - "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." - ) - offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} - - sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] - self.itir, - arg_types, - offset_provider_type=offset_provider_type, - column_axis=kwargs.get("column_axis", None), - ) - self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ - - # Halo exchange related metadata, i.e. 
gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field - # Add them as dynamic properties to the SDFG - - assert all( - isinstance(in_field, SymRef) - for closure in self.itir.closures - for in_field in closure.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_fields = [ - str(in_field.id) # type: ignore[union-attr] # ensured by assert - for closure in self.itir.closures - for in_field in closure.inputs - if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert - ] - sdfg.gt4py_program_input_fields = { - in_field: dim - for in_field in input_fields - for dim in fields[in_field].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - output_fields = [] - for closure in self.itir.closures: - output = closure.output - if isinstance(output, itir.SymRef): - if str(output.id) in fields: - output_fields.append(str(output.id)) - else: - for arg in output.args: - if str(arg.id) in fields: # type: ignore[attr-defined] - output_fields.append(str(arg.id)) # type: ignore[attr-defined] - sdfg.gt4py_program_output_fields = { - output: dim - for output in output_fields - for dim in fields[output].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - sdfg.offset_providers_per_input_field = {} - itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider_type=offset_provider_type - ) - itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) - for closure in itir_tmp_fencil.closures: - params_shifts = itir_transforms.trace_shifts.trace_stencil( - closure.stencil, num_args=len(closure.inputs) - ) - for param, shifts in zip(closure.inputs, params_shifts): - assert isinstance( - param, SymRef - ) # backend only supports SymRef inputs, not `index` calls - if not isinstance(param.id, str): - continue - if param.id not in sdfg.gt4py_program_input_fields: - continue - 
sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) - - return sdfg - - def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: - """ - Returns the closure arrays of the SDFG represented by this object - as a mapping between array name and the corresponding value. - - The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. - The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that - the offset providers are not part of GT4Py Program's arguments. - Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. - """ - offset_provider_type = self.connectivities - - # Define DaCe symbols - connectivity_table_size_symbols = { - dace_utils.field_size_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - connectivity_table_stride_symbols = { - dace_utils.field_stride_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - symbols = {**connectivity_table_size_symbols, **connectivity_table_stride_symbols} - - # Define the storage location (e.g. 
CPU, GPU) of the connectivity tables - if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - if not isinstance(v, common.NeighborConnectivityType): - continue - if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: - Program.connectivity_tables_data_descriptors["storage"] = ( - self.sdfg_closure_vars[ - "sdfg.arrays" - ][dace_utils.connectivity_identifier(k)].storage - ) - break - - # Build the closure dictionary - closure_dict = {} - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - conn_id = dace_utils.connectivity_identifier(k) - if ( - isinstance(v, common.NeighborConnectivityType) - and conn_id in self.sdfg_closure_vars["sdfg.arrays"] - ): - if conn_id not in Program.connectivity_tables_data_descriptors: - Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, - shape=[ - symbols[dace_utils.field_size_symbol_name(conn_id, 0)], - symbols[dace_utils.field_size_symbol_name(conn_id, 1)], - ], - strides=[ - symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], - symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], - ], - storage=Program.connectivity_tables_data_descriptors["storage"], - ) - closure_dict[conn_id] = Program.connectivity_tables_data_descriptors[conn_id] - - return closure_dict - - def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: - args = [] - for arg in self.past_stage.past_node.params: - args.append(arg.id) - return (args, []) - - -def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> bool: - for dace_parsed_arg, gt4py_program_arg in zip(dace_parsed_args, gt4py_program_args): - if isinstance(dace_parsed_arg, dace.data.Scalar): - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) - elif isinstance( - dace_parsed_arg, (bool, int, float, str, 
np.bool_, np.integer, np.floating, np.str_) - ): # compile-time constant scalar - assert isinstance(gt4py_program_arg, ts.ScalarType) - if isinstance(dace_parsed_arg, (bool, np.bool_)): - assert gt4py_program_arg.kind == ts.ScalarKind.BOOL - elif isinstance(dace_parsed_arg, (int, np.integer)): - assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] - elif isinstance(dace_parsed_arg, (float, np.floating)): - assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] - elif isinstance(dace_parsed_arg, (str, np.str_)): - assert gt4py_program_arg.kind == ts.ScalarKind.STRING - elif isinstance(dace_parsed_arg, dace.data.Array): - assert isinstance(gt4py_program_arg, ts.FieldType) - assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) - elif isinstance( - dace_parsed_arg, (dace.data.Structure, dict, OrderedDict) - ): # offset_provider - continue - else: - raise ValueError(f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}") - - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py deleted file mode 100644 index 823943cfd5..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ /dev/null @@ -1,809 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -import warnings -from typing import Optional, Sequence, cast - -import dace -from dace.sdfg.state import LoopRegion - -import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind, common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt - -from .itir_to_tasklet import ( - Context, - GatherOutputSymbolsPass, - PythonTaskletCodegen, - SymbolExpr, - TaskletExpr, - ValueExpr, - closure_to_tasklet_sdfg, - is_scan, -) -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_var_name, -) - - -def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: - """ - Parse stencil expression to extract the scan arguments. - - Returns - ------- - tuple(is_forward, init_carry) - The output tuple fields verify the following semantics: - - is_forward: forward boolean flag - - init_carry: carry initial value - """ - stencil_fobj = cast(FunCall, stencil) - is_forward = stencil_fobj.args[1] - assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type) - init_carry = stencil_fobj.args[2] - assert isinstance(init_carry, Literal) - return is_forward.value == "True", init_carry - - -def _get_scan_dim( - column_axis: Dimension, - storage_types: dict[str, ts.TypeSpec], - output: SymRef, - use_field_canonical_representation: bool, -) -> tuple[str, int, ts.ScalarType]: - """ - Extract information about the scan dimension. 
- - Returns - ------- - tuple(scan_dim_name, scan_dim_index, scan_dim_dtype) - The output tuple fields verify the following semantics: - - scan_dim_name: name of the scan dimension - - scan_dim_index: domain index of the scan dimension - - scan_dim_dtype: data type along the scan dimension - """ - output_type = storage_types[output.id] - assert isinstance(output_type, ts.FieldType) - sorted_dims = [ - dim - for _, dim in ( - dace_utils.get_sorted_dims(output_type.dims) - if use_field_canonical_representation - else enumerate(output_type.dims) - ) - ] - return (column_axis.value, sorted_dims.index(column_axis), output_type.dtype) - - -def _make_array_shape_and_strides( - name: str, - dims: Sequence[Dimension], - offset_provider_type: common.OffsetProviderType, - sort_dims: bool, -) -> tuple[list[dace.symbol], list[dace.symbol]]: - """ - Parse field dimensions and allocate symbols for array shape and strides. - - For local dimensions, the size is known at compile-time and therefore - the corresponding array shape dimension is set to an integer literal value. - - Returns - ------- - tuple(shape, strides) - The output tuple fields are arrays of dace symbolic expressions. 
- """ - dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) - sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type) - shape = [ - ( - connectivity_types[dim.value].max_neighbors - if dim.kind == DimensionKind.LOCAL - # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain - else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) - ) - for i, dim in sorted_dims - ] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i, _ in sorted_dims - ] - return shape, strides - - -def _check_no_lifts(node: itir.StencilClosure): - """ - Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions. - - Returns - ------- - True if lifts do not appear in the ITIR exception lift expressions in neighbor reductions. False otherwise. - """ - neighbors_call_count = 0 - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): - if getattr(fun, "id", "") == "neighbors": - neighbors_call_count = 3 - elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: - return False - neighbors_call_count = max(0, neighbors_call_count - 1) - return True - - -class ItirToSDFG(eve.NodeVisitor): - param_types: list[ts.TypeSpec] - storage_types: dict[str, ts.TypeSpec] - column_axis: Optional[Dimension] - offset_provider_type: common.OffsetProviderType - unique_id: int - use_field_canonical_representation: bool - - def __init__( - self, - param_types: list[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - tmps: list[itir.Temporary], - use_field_canonical_representation: bool, - column_axis: Optional[Dimension] = None, - ): - self.param_types = param_types - self.column_axis = column_axis - self.offset_provider_type = offset_provider_type - self.storage_types = {} - self.tmps = tmps - 
self.use_field_canonical_representation = use_field_canonical_representation - - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): - if isinstance(type_, ts.FieldType): - shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider_type, sort_dimensions - ) - dtype = dace_utils.as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - elif isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype - else: - sdfg.add_symbol(name, dtype) - - else: - raise NotImplementedError() - self.storage_types[name] = type_ - - def add_storage_for_temporaries( - self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG - ) -> dict[str, str]: - symbol_map: dict[str, TaskletExpr] = {} - # The shape of temporary arrays might be defined based on scalar values passed as program arguments. - # Here we collect these values in a symbol map. - for sym in node_params: - if isinstance(sym.type, ts.ScalarType): - name_ = str(sym.id) - symbol_map[name_] = SymbolExpr(name_, dace_utils.as_dace_type(sym.type)) - - tmp_symbols: dict[str, str] = {} - for tmp in self.tmps: - tmp_name = str(tmp.id) - - # We visit the domain of the temporary field, passing the set of available symbols. - assert isinstance(tmp.domain, itir.FunCall) - domain_ctx = Context(program_sdfg, defs_state, symbol_map) - tmp_domain = self._visit_domain(tmp.domain, domain_ctx) - - if isinstance(tmp.type, ts.TupleType): - raise NotImplementedError("Temporaries of tuples are not supported.") - assert isinstance(tmp.type, ts.FieldType) and isinstance(tmp.dtype, ts.ScalarType) - - # We store the FieldType for this temporary array. - self.storage_types[tmp_name] = tmp.type - - # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. 
- # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics - # to assign optimal stride values. - tmp_shape, _ = new_array_symbols(tmp_name, len(tmp.type.dims)) - _, tmp_array = program_sdfg.add_array( - tmp_name, tmp_shape, dace_utils.as_dace_type(tmp.dtype), transient=True - ) - - # Loop through all dimensions to visit the symbolic expressions for array shape and offset. - # These expressions are later mapped to interstate symbols. - for (_, (begin, end)), shape_sym in zip(tmp_domain, tmp_array.shape): - # The temporary field has a dimension range defined by `begin` and `end` values. - # Therefore, the actual size is given by the difference `end.value - begin.value`. - # Instead of allocating the actual size, we allocate space to enable indexing from 0 - # because we want to avoid using dace array offsets (which will be deprecated soon). - # The result should still be valid, but the stencil will be using only a subset - # of the array. - if not (isinstance(begin, SymbolExpr) and begin.value == "0"): - warnings.warn( - f"Domain start offset for temporary {tmp_name} is ignored.", stacklevel=2 - ) - tmp_symbols[str(shape_sym)] = end.value - - return tmp_symbols - - def create_memlet_at(self, field_name: str, index: dict[str, str]): - field_type = self.storage_types[field_name] - assert isinstance(field_type, ts.FieldType) - if self.use_field_canonical_representation: - field_index = [ - index[dim.value] for _, dim in dace_utils.get_sorted_dims(field_type.dims) - ] - else: - field_index = [index[dim.value] for dim in field_type.dims] - subset = ", ".join(field_index) - return dace.Memlet(data=field_name, subset=subset) - - def get_output_nodes( - self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dict[str, dace.nodes.AccessNode]: - # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes - output_symbols_pass = GatherOutputSymbolsPass(sdfg, state) - 
output_symbols_pass.visit(closure.output) - # Visit output node again to generate the corresponding tasklet - context = Context(sdfg, state, output_symbols_pass.symbol_refs) - translator = PythonTaskletCodegen( - self.offset_provider_type, context, self.use_field_canonical_representation - ) - output_nodes = flatten_list(translator.visit(closure.output)) - return {node.value.data: node.value for node in output_nodes} - - def visit_FencilDefinition(self, node: itir.FencilDefinition): - program_sdfg = dace.SDFG(name=node.id) - program_sdfg.debuginfo = dace_utils.debug_info(node) - entry_state = program_sdfg.add_state("program_entry", is_start_block=True) - - # Filter neighbor tables from offset providers. - connectivity_types = get_used_connectivities(node, self.offset_provider_type) - - # Add program parameters as SDFG storages. - for param, type_ in zip(node.params, self.param_types): - self.add_storage( - program_sdfg, str(param.id), type_, self.use_field_canonical_representation - ) - - if self.tmps: - tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg) - # on the first interstate edge define symbols for shape and offsets of temporary arrays - last_state = program_sdfg.add_state("init_symbols_for_temporaries") - program_sdfg.add_edge( - entry_state, last_state, dace.InterstateEdge(assignments=tmp_symbols) - ) - else: - last_state = entry_state - - # Add connectivities as SDFG storages. - for offset, connectivity_type in connectivity_types.items(): - scalar_type = tt.from_dtype(connectivity_type.dtype) - type_ = ts.FieldType( - [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type - ) - self.add_storage( - program_sdfg, - dace_utils.connectivity_identifier(offset), - type_, - sort_dimensions=False, - ) - - # Create a nested SDFG for all stencil closures. - for closure in node.closures: - # Translate the closure and its stencil's body to an SDFG. 
- closure_sdfg, input_names, output_names = self.visit( - closure, array_table=program_sdfg.arrays - ) - - # Create a new state for the closure. - last_state = program_sdfg.add_state_after(last_state) - - # Create memlets to transfer the program parameters - input_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in input_names - } - output_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in output_names - } - - symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping) - - # Insert the closure's SDFG as a nested SDFG of the program. - nsdfg_node = last_state.add_nested_sdfg( - sdfg=closure_sdfg, - parent=program_sdfg, - inputs=set(input_names), - outputs=set(output_names), - symbol_mapping=symbol_mapping, - debuginfo=closure_sdfg.debuginfo, - ) - - # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. - for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - - for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) - - # Create the call signature for the SDFG. - # Only the arguments requiered by the Fencil, i.e. `node.params` are added as positional arguments. - # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. 
- program_sdfg.arg_names = [str(a) for a in node.params] - - program_sdfg.validate() - return program_sdfg - - def visit_StencilClosure( - self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] - ) -> tuple[dace.SDFG, list[str], list[str]]: - assert _check_no_lifts(node) - - # Create the closure's nested SDFG and single state. - closure_sdfg = dace.SDFG(name="closure") - closure_sdfg.debuginfo = dace_utils.debug_info(node) - closure_state = closure_sdfg.add_state("closure_entry") - closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) - - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) - output_names = [k for k, _ in output_nodes.items()] - - # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. - input_transients_mapping = {} - for name in [*input_names, *connectivity_names, *output_names]: - if name in closure_sdfg.arrays: - assert name in input_names and name in output_names - # In case of closures with in/out fields, there is risk of race condition - # between read/write access nodes in the (asynchronous) map tasklet. 
- transient_name = unique_var_name() - closure_sdfg.add_array( - transient_name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - transient=True, - ) - closure_init_state.add_nedge( - closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), - closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), - dace.Memlet.from_array(name, closure_sdfg.arrays[name]), - ) - input_transients_mapping[name] = transient_name - elif isinstance(self.storage_types[name], ts.FieldType): - closure_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - assert isinstance(self.storage_types[name], ts.ScalarType) - - input_field_names = [ - input_name - for input_name in input_names - if isinstance(self.storage_types[input_name], ts.FieldType) - ] - - # Closure outputs should all be fields - assert all( - isinstance(self.storage_types[output_name], ts.FieldType) - for output_name in output_names - ) - - # Update symbol table and get output domain of the closure - program_arg_syms: dict[str, TaskletExpr] = {} - for name, type_ in self.storage_types.items(): - if isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in input_names: - out_name = unique_var_name() - closure_sdfg.add_scalar(out_name, dtype, transient=True) - out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", - {}, - {"__result"}, - f"__result = {name}", - debuginfo=closure_sdfg.debuginfo, - ) - access = closure_init_state.add_access( - out_name, debuginfo=closure_sdfg.debuginfo - ) - value = ValueExpr(access, dtype) - memlet = dace.Memlet(data=out_name, subset="0") - closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) - program_arg_syms[name] = value - else: - program_arg_syms[name] = SymbolExpr(name, dtype) - else: - assert isinstance(type_, ts.FieldType) - # make shape symbols 
(corresponding to field size) available as arguments to domain visitor - if name in input_names or name in output_names: - field_symbols = [ - val - for val in closure_sdfg.arrays[name].shape - if isinstance(val, dace.symbol) and str(val) not in input_names - ] - for sym in field_symbols: - sym_name = str(sym) - program_arg_syms[sym_name] = SymbolExpr(sym, sym.dtype) - closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, closure_ctx) - - # Map SDFG tasklet arguments to parameters - input_local_names = [ - ( - input_transients_mapping[input_name] - if input_name in input_transients_mapping - else ( - input_name - if input_name in input_field_names - else cast(ValueExpr, program_arg_syms[input_name]).value.data - ) - ) - for input_name in input_names - ] - input_memlets = [ - dace.Memlet.from_array(name, closure_sdfg.arrays[name]) - for name in [*input_local_names, *connectivity_names] - ] - - # create and write to transient that is then copied back to actual output array to avoid aliasing of - # same memory in nested SDFG with different names - output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} - # scan operator should always be the first function call in a closure - if is_scan(node.stencil): - assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" - transient_name, output_name = next(iter(output_connectors_mapping.items())) - - nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( - node, closure_sdfg.arrays, closure_domain, transient_name - ) - results = [transient_name] - - _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] - output_subset = f"{scan_lb.value}:{scan_ub.value}" - - domain_subset = { - dim: ( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" - ) - for dim, _ in closure_domain - } - output_memlets = 
[self.create_memlet_at(output_name, domain_subset)] - else: - nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( - node, closure_sdfg.arrays, closure_domain - ) - - output_subset = "0" - - output_memlets = [ - self.create_memlet_at(output_name, {dim: f"i_{dim}" for dim, _ in closure_domain}) - for output_name in output_connectors_mapping.values() - ] - - input_mapping = { - param: arg for param, arg in zip([*input_names, *connectivity_names], input_memlets) - } - output_mapping = {param: memlet for param, memlet in zip(results, output_memlets)} - - symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, input_mapping) - - nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( - closure_state, - sdfg=nsdfg, - map_ranges=map_ranges or {"__dummy": "0"}, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - output_nodes=output_nodes, - debuginfo=nsdfg.debuginfo, - ) - access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} - for edge in closure_state.in_edges(map_exit): - memlet = edge.data - if memlet.data not in output_connectors_mapping: - continue - transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) - closure_state.add_edge( - nsdfg_node, - edge.src_conn, - transient_access, - None, - dace.Memlet(data=memlet.data, subset=output_subset, debuginfo=nsdfg.debuginfo), - ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset - ) - closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) - closure_state.remove_edge(edge) - access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] - - return closure_sdfg, input_field_names + connectivity_names, output_names - - def _visit_scan_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], 
... - ], - output_name: str, - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: - # extract scan arguments - is_forward, init_carry_value = _get_scan_args(node.stencil) - # select the scan dimension based on program argument for column axis - assert self.column_axis - assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( - self.column_axis, - self.storage_types, - node.output, - self.use_field_canonical_representation, - ) - - assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - if not dim == scan_dim: - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - else: - scan_lb_str = lb_str - scan_ub_str = ub_str - - # the scan operator is implemented as an SDFG to be nested in the closure SDFG - scan_sdfg = dace.SDFG(name="scan") - scan_sdfg.debuginfo = dace_utils.debug_info(node) - - # the carry value of the scan operator exists only in the scope of the scan sdfg - scan_carry_name = unique_var_name() - scan_sdfg.add_scalar( - scan_carry_name, dtype=dace_utils.as_dace_type(scan_dtype), transient=True - ) - - # create a loop region for lambda call over the scan dimension - scan_loop_var = f"i_{scan_dim}" - if is_forward: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} < {scan_ub_str}", - loop_var=scan_loop_var, - 
initialize_expr=f"{scan_loop_var} = {scan_lb_str}", - update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", - inverted=False, - ) - else: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} >= {scan_lb_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", - update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", - inverted=False, - ) - scan_sdfg.add_node(scan_loop) - compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) - update_state = scan_loop.add_state("lambda_update") - scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) - - start_state = scan_sdfg.add_state("start", is_start_block=True) - scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) - - # tasklet for initialization of carry - carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", - {}, - {"__result"}, - f"__result = {init_carry_value}", - debuginfo=scan_sdfg.debuginfo, - ) - start_state.add_edge( - carry_init_tasklet, - "__result", - start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), - None, - dace.Memlet(data=scan_carry_name, subset="0"), - ) - - # add storage to scan SDFG for inputs - for name in [*input_names, *connectivity_names]: - assert name not in scan_sdfg.arrays - if isinstance(self.storage_types[name], ts.FieldType): - scan_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - scan_sdfg.add_scalar( - name, - dtype=dace_utils.as_dace_type(cast(ts.ScalarType, self.storage_types[name])), - ) - # add storage to scan SDFG for output - scan_sdfg.add_array( - output_name, - shape=(array_table[node.output.id].shape[scan_dim_index],), - strides=(array_table[node.output.id].strides[scan_dim_index],), - dtype=array_table[node.output.id].dtype, - ) - - # implement the lambda function as a nested SDFG that computes a single item in the scan dimension - lambda_domain = 
{dim: f"i_{dim}" for dim, _ in closure_domain} - input_arrays = [(scan_carry_name, scan_dtype)] + [ - (name, self.storage_types[name]) for name in input_names - ] - connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] - lambda_context, lambda_outputs = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - lambda_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - lambda_input_names = [name for name, _ in input_arrays] - lambda_output_names = [connector.value.data for connector in lambda_outputs] - - input_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in lambda_input_names - ] - connectivity_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in connectivity_names - ] - input_mapping = {param: arg for param, arg in zip(lambda_input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - - scan_inner_node = compute_state.add_nested_sdfg( - lambda_context.body, - parent=scan_sdfg, - inputs=set(lambda_input_names) | set(connectivity_names), - outputs=set(lambda_output_names), - symbol_mapping=symbol_mapping, - debuginfo=lambda_context.body.debuginfo, - ) - - # connect scan SDFG to lambda inputs - for name, memlet in array_mapping.items(): - access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) - - output_names = [output_name] - assert len(lambda_output_names) == 1 - # connect lambda output to scan SDFG - for name, connector in zip(output_names, lambda_output_names): - compute_state.add_edge( - scan_inner_node, - connector, - compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), - 
None, - dace.Memlet(data=name, subset=scan_loop_var), - ) - - update_state.add_nedge( - update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), - dace.Memlet(data=output_name, subset=scan_loop_var, other_subset="0"), - ) - - return scan_sdfg, map_ranges, scan_dim_index - - def _visit_parallel_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... - ], - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - - # Create an SDFG for the tasklet that computes a single item of the output domain. 
- index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - - input_arrays = [(name, self.storage_types[name]) for name in input_names] - connectivity_arrays = [(array_table[name], name) for name in connectivity_names] - - context, results = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - index_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - return context.body, map_ranges, [r.value.data for r in results] - - def _visit_domain( - self, node: itir.FunCall, context: Context - ) -> tuple[tuple[str, tuple[SymbolExpr | ValueExpr, SymbolExpr | ValueExpr]], ...]: - assert isinstance(node.fun, itir.SymRef) - assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" - - bounds: list[tuple[str, tuple[ValueExpr, ValueExpr]]] = [] - - for named_range in node.args: - assert isinstance(named_range, itir.FunCall) - assert isinstance(named_range.fun, itir.SymRef) - assert len(named_range.args) == 3 - dimension = named_range.args[0] - assert isinstance(dimension, itir.AxisLiteral) - lower_bound = named_range.args[1] - upper_bound = named_range.args[2] - translator = PythonTaskletCodegen( - self.offset_provider_type, - context, - self.use_field_canonical_representation, - ) - lb = translator.visit(lower_bound)[0] - ub = translator.visit(upper_bound)[0] - bounds.append((dimension.value, (lb, ub))) - - return tuple(bounds) - - @staticmethod - def _check_shift_offsets_are_literals(node: itir.StencilClosure): - fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) - shifts = [nd for nd in fun_calls if getattr(nd.fun, "id", "") == "shift"] - for shift in shifts: - if not all(isinstance(arg, (itir.Literal, itir.OffsetLiteral)) for arg in shift.args): - return False - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py deleted file mode 100644 index 
2b2669187a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ /dev/null @@ -1,1564 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import copy -import dataclasses -import itertools -from collections.abc import Sequence -from typing import Any, Callable, Optional, TypeAlias, cast - -import dace -import numpy as np - -import gt4py.eve.codegen -from gt4py import eve -from gt4py.next import common -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import FunCall, Lambda -from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_name, - unique_var_name, -) - - -_TYPE_MAPPING = { - "float": dace.float64, - "float32": dace.float32, - "float64": dace.float64, - "int": dace.int32 if np.dtype(int).itemsize == 4 else dace.int64, - "int32": dace.int32, - "int64": dace.int64, - "bool": dace.bool_, -} - - -def itir_type_as_dace_type(type_: ts.TypeSpec): - # TODO(tehrengruber): this function just converts the scalar type of whatever it is given, - # let it be a field, iterator, or directly a scalar. The caller should take care of the - # extraction. 
- dtype: ts.TypeSpec - if isinstance(type_, ts.FieldType): - dtype = type_.dtype - elif isinstance(type_, it_ts.IteratorType): - dtype = type_.element_type - else: - dtype = type_ - assert isinstance(dtype, ts.ScalarType) - return _TYPE_MAPPING[dtype.kind.name.lower()] - - -def get_reduce_identity_value(op_name_: str, type_: Any): - if op_name_ == "plus": - init_value = type_(0) - elif op_name_ == "multiplies": - init_value = type_(1) - elif op_name_ == "minimum": - init_value = type_("inf") - elif op_name_ == "maximum": - init_value = type_("-inf") - else: - raise NotImplementedError() - - return init_value - - -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} 
% {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -# Define type of variables used for field indexing -_INDEX_DTYPE = _TYPE_MAPPING["int64"] - - -@dataclasses.dataclass -class SymbolExpr: - value: dace.symbolic.SymbolicType - dtype: dace.typeclass - - -@dataclasses.dataclass -class ValueExpr: - value: dace.nodes.AccessNode - dtype: dace.typeclass - - -@dataclasses.dataclass -class IteratorExpr: - field: dace.nodes.AccessNode - indices: dict[str, dace.nodes.AccessNode] - dtype: dace.typeclass - dimensions: list[str] - - -# Union of possible expression types -TaskletExpr: TypeAlias = IteratorExpr | SymbolExpr | ValueExpr - - -@dataclasses.dataclass -class Context: - body: dace.SDFG - state: dace.SDFGState - symbol_map: dict[str, TaskletExpr] - # if we encounter a reduction node, the reduction state needs to be pushed to child nodes - reduce_identity: Optional[SymbolExpr] - - def __init__( - self, - body: dace.SDFG, - state: dace.SDFGState, - symbol_map: dict[str, TaskletExpr], - reduce_identity: Optional[SymbolExpr] = None, - ): - self.body = body - self.state = state - self.symbol_map = symbol_map - self.reduce_identity = reduce_identity - - -def _visit_lift_in_neighbors_reduction( - transformer: PythonTaskletCodegen, - node: itir.FunCall, - node_args: Sequence[IteratorExpr | list[ValueExpr]], - connectivity_type: common.NeighborConnectivityType, - map_entry: dace.nodes.MapEntry, - map_exit: dace.nodes.MapExit, - neighbor_index_node: dace.nodes.AccessNode, - neighbor_value_node: dace.nodes.AccessNode, -) -> list[ValueExpr]: - assert transformer.context.reduce_identity is not None - neighbor_dim = connectivity_type.codomain.value - origin_dim = connectivity_type.source_dim.value - - lifted_args: list[IteratorExpr | ValueExpr] = [] - for arg in node_args: - if isinstance(arg, IteratorExpr): - if origin_dim in arg.indices: - lifted_indices = arg.indices.copy() - lifted_indices.pop(origin_dim) - lifted_indices[neighbor_dim] = neighbor_index_node - 
lifted_args.append( - IteratorExpr(arg.field, lifted_indices, arg.dtype, arg.dimensions) - ) - else: - lifted_args.append(arg) - else: - lifted_args.append(arg[0]) - - lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) - assert len(inner_outputs) == 1 - inner_out_connector = inner_outputs[0].value.data - - input_nodes = {} - iterator_index_nodes = {} - lifted_index_connectors = [] - - for x, y in inner_inputs: - if isinstance(y, IteratorExpr): - field_connector, inner_index_table = x - input_nodes[field_connector] = y.field - for dim, connector in inner_index_table.items(): - if dim == neighbor_dim: - lifted_index_connectors.append(connector) - iterator_index_nodes[connector] = y.indices[dim] - else: - assert isinstance(y, ValueExpr) - input_nodes[x] = y.value - - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - parent_sdfg = transformer.context.body - parent_state = transformer.context.state - - input_mapping = { - connector: dace.Memlet.from_array(node.data, node.desc(parent_sdfg)) - for connector, node in input_nodes.items() - } - connectivity_mapping = { - name: dace.Memlet.from_array(name, parent_sdfg.arrays[name]) for name in connectivity_names - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) - - nested_sdfg_node = parent_state.add_nested_sdfg( - lift_context.body, - parent_sdfg, - inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, - outputs={inner_out_connector}, - symbol_mapping=symbol_mapping, - debuginfo=lift_context.body.debuginfo, - ) - - for connectivity_connector, memlet in connectivity_mapping.items(): - parent_state.add_memlet_path( - parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), - map_entry, - 
nested_sdfg_node, - dst_conn=connectivity_connector, - memlet=memlet, - ) - - for inner_connector, access_node in input_nodes.items(): - parent_state.add_memlet_path( - access_node, - map_entry, - nested_sdfg_node, - dst_conn=inner_connector, - memlet=input_mapping[inner_connector], - ) - - for inner_connector, access_node in iterator_index_nodes.items(): - memlet = dace.Memlet(data=access_node.data, subset="0") - if inner_connector in lifted_index_connectors: - parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) - else: - parent_state.add_memlet_path( - access_node, map_entry, nested_sdfg_node, dst_conn=inner_connector, memlet=memlet - ) - - parent_state.add_memlet_path( - nested_sdfg_node, - map_exit, - neighbor_value_node, - src_conn=inner_out_connector, - memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), - ) - - if connectivity_type.has_skip_values: - # check neighbor validity on if/else inter-state edge - # use one branch for connectivity case - start_state = lift_context.body.add_state_before( - lift_context.body.start_state, - "start", - condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}", - ) - # use the other branch for skip value case - skip_neighbor_state = lift_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), - None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), - ) - lift_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), - ) - - return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] - - -def builtin_neighbors( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - sdfg: dace.SDFG = 
transformer.context.body - state: dace.SDFGState = transformer.context.state - - di = dace_utils.debug_info(node, default=sdfg.debuginfo) - offset_literal, data = node_args - assert isinstance(offset_literal, itir.OffsetLiteral) - offset_dim = offset_literal.value - assert isinstance(offset_dim, str) - connectivity_type = transformer.offset_provider_type[offset_dim] - if not isinstance(connectivity_type, common.NeighborConnectivityType): - raise NotImplementedError( - "Neighbor reduction only implemented for connectivity based on neighbor tables." - ) - - lift_node = None - if isinstance(data, FunCall): - assert isinstance(data.fun, itir.FunCall) - fun_node = data.fun - if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": - lift_node = fun_node - lift_args = transformer.visit(data.args) - iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) - if lift_node is None: - iterator = transformer.visit(data) - assert isinstance(iterator, IteratorExpr) - field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[connectivity_type.source_dim.value] - - assert transformer.context.reduce_identity is not None - assert transformer.context.reduce_identity.dtype == iterator.dtype - - # gather the neighbors in a result array dimensioned for `max_neighbors` - neighbor_value_var = unique_var_name() - sdfg.add_array( - neighbor_value_var, - dtype=iterator.dtype, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) - - # allocate scalar to store index for direct addressing of neighbor field - neighbor_index_var = unique_var_name() - sdfg.add_scalar(neighbor_index_var, _INDEX_DTYPE, transient=True) - neighbor_index_node = state.add_access(neighbor_index_var, debuginfo=di) - - # generate unique map index name to avoid conflict with other maps inside same state - neighbor_map_index = 
unique_name(f"{offset_dim}_neighbor_map_idx") - me, mx = state.add_map( - f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, - debuginfo=di, - ) - - table_name = dace_utils.connectivity_identifier(offset_dim) - shift_tasklet = state.add_tasklet( - "shift", - code=f"__result = __table[__idx, {neighbor_map_index}]", - inputs={"__table", "__idx"}, - outputs={"__result"}, - debuginfo=di, - ) - state.add_memlet_path( - state.add_access(table_name, debuginfo=di), - me, - shift_tasklet, - memlet=dace.Memlet.from_array(table_name, sdfg.arrays[table_name]), - dst_conn="__table", - ) - state.add_memlet_path( - origin_index_node, - me, - shift_tasklet, - memlet=dace.Memlet(data=origin_index_node.data, subset="0"), - dst_conn="__idx", - ) - state.add_edge( - shift_tasklet, - "__result", - neighbor_index_node, - None, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - - if lift_node is not None: - _visit_lift_in_neighbors_reduction( - transformer, - lift_node, - lift_args, - connectivity_type, - me, - mx, - neighbor_index_node, - neighbor_value_node, - ) - else: - sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) - data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" - data_access_tasklet = state.add_tasklet( - "data_access", - code=f"__data = __field[{data_access_index}] " - + ( - f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if connectivity_type.has_skip_values - else "" - ), - inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, - outputs={"__data"}, - debuginfo=di, - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=dace.Memlet.from_array(iterator.field.data, field_desc), - dst_conn="__field", - ) - for dim in iterator.dimensions: - connector = f"{dim}_v" - if dim == connectivity_type.codomain.value: - 
state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - connector, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - else: - state.add_memlet_path( - iterator.indices[dim], - me, - data_access_tasklet, - dst_conn=connector, - memlet=dace.Memlet(data=iterator.indices[dim].data, subset="0"), - ) - - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index), - src_conn="__data", - ) - - if not connectivity_type.has_skip_values: - return [ValueExpr(neighbor_value_node, iterator.dtype)] - else: - """ - In case of neighbor tables with skip values, in addition to the array of neighbor values this function also - returns an array of booleans to indicate if the neighbor value is present or not. This node is only used - for neighbor reductions with lambda functions, a very specific case. For single input neighbor reductions, - the regular case, this node will be removed by the simplify pass. 
- """ - neighbor_valid_var = unique_var_name() - sdfg.add_array( - neighbor_valid_var, - dtype=dace.dtypes.bool, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) - - neighbor_valid_tasklet = state.add_tasklet( - f"check_valid_neighbor_{offset_dim}", - {"__idx"}, - {"__valid"}, - f"__valid = True if __idx != {neighbor_skip_value} else False", - debuginfo=di, - ) - state.add_edge( - neighbor_index_node, - None, - neighbor_valid_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - neighbor_valid_tasklet, - mx, - neighbor_valid_node, - memlet=dace.Memlet(data=neighbor_valid_var, subset=neighbor_map_index), - src_conn="__valid", - ) - return [ - ValueExpr(neighbor_value_node, iterator.dtype), - ValueExpr(neighbor_valid_node, dace.dtypes.bool), - ] - - -def builtin_can_deref( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - # first visit shift, to get set of indices for deref - can_deref_callable = node_args[0] - assert isinstance(can_deref_callable, itir.FunCall) - shift_callable = can_deref_callable.fun - assert isinstance(shift_callable, itir.FunCall) - assert isinstance(shift_callable.fun, itir.SymRef) - assert shift_callable.fun.id == "shift" - iterator = transformer._visit_shift(can_deref_callable) - - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it - if not isinstance(iterator, IteratorExpr): - assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr) - # We can always deref a value expression, therefore hard-code `can_deref` to True. - # Returning a SymbolExpr would be preferable, but it requires update to type-checking. 
- result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name, debuginfo=di) - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di - ), - "_out", - result_node, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_node, dace.dtypes.bool)] - - # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] - internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", dace_debuginfo=di - ) - - -def builtin_if( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - assert len(node_args) == 3 - sdfg = transformer.context.body - current_state = transformer.context.state - is_start_state = sdfg.start_block == current_state - - # build an empty state to join true and false branches - join_state = sdfg.add_state_before(current_state, "join") - - def build_if_state(arg, state): - symbol_map = copy.deepcopy(transformer.context.symbol_map) - node_context = Context(sdfg, state, symbol_map) - node_taskgen = PythonTaskletCodegen( - transformer.offset_provider_type, - node_context, - transformer.use_field_canonical_representation, - ) - return node_taskgen.visit(arg) - - # represent the if-statement condition as a tasklet inside an `if_statement` state preceding `join` state - stmt_state = sdfg.add_state_before(join_state, "if_statement", is_start_state) - stmt_node = build_if_state(node_args[0], stmt_state)[0] - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - assert 
sdfg.arrays[stmt_node.value.data].shape == (1,) - - # visit true and false branches (here called `tbr` and `fbr`) as separate states, following `if_statement` state - tbr_state = sdfg.add_state("true_branch") - sdfg.add_edge( - stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True") - ) - sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge()) - tbr_values = flatten_list(build_if_state(node_args[1], tbr_state)) - # - fbr_state = sdfg.add_state("false_branch") - sdfg.add_edge( - stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False") - ) - sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge()) - fbr_values = flatten_list(build_if_state(node_args[2], fbr_state)) - - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - # make the result of the if-statement evaluation available inside current state - ctx_stmt_node = ValueExpr(current_state.add_access(stmt_node.value.data), stmt_node.dtype) - - # we distinguish between select if-statements, where both true and false branches are symbolic expressions, - # and therefore do not require exclusive branch execution, and regular if-statements where at least one branch - # is a value expression, which has to be evaluated at runtime with conditional state transition - result_values = [] - assert len(tbr_values) == len(fbr_values) - for tbr_value, fbr_value in zip(tbr_values, fbr_values): - assert isinstance(tbr_value, (SymbolExpr, ValueExpr)) - assert isinstance(fbr_value, (SymbolExpr, ValueExpr)) - assert tbr_value.dtype == fbr_value.dtype - - if all(isinstance(x, SymbolExpr) for x in (tbr_value, fbr_value)): - # both branches return symbolic expressions, therefore the if-node can be translated - # to a select-tasklet inside current state - # TODO: use select-memlet when it becomes available in dace - code = f"{tbr_value.value} if _cond else {fbr_value.value}" - if_expr = transformer.add_expr_tasklet( - [(ctx_stmt_node, 
"_cond")], code, tbr_value.dtype, "if_select" - )[0] - result_values.append(if_expr) - else: - # at least one of the two branches contains a value expression, which should be evaluated - # only if the corresponding true/false condition is satisfied - desc = sdfg.arrays[ - tbr_value.value.data if isinstance(tbr_value, ValueExpr) else fbr_value.value.data - ] - var = unique_var_name() - if isinstance(desc, dace.data.Scalar): - sdfg.add_scalar(var, desc.dtype, transient=True) - else: - sdfg.add_array(var, desc.shape, desc.dtype, transient=True) - - # write result to transient data container and access it in the original state - for state, expr in [(tbr_state, tbr_value), (fbr_state, fbr_value)]: - val_node = state.add_access(var) - if isinstance(expr, ValueExpr): - state.add_nedge( - expr.value, val_node, dace.Memlet.from_array(expr.value.data, desc) - ) - else: - assert desc.shape == (1,) - state.add_edge( - state.add_tasklet("write_symbol", {}, {"_out"}, f"_out = {expr.value}"), - "_out", - val_node, - None, - dace.Memlet(var, "0"), - ) - result_values.append(ValueExpr(current_state.add_access(var), desc.dtype)) - - if tbr_state.is_empty() and fbr_state.is_empty(): - # if all branches are symbolic expressions, the true/false and join states can be removed - # as well as the conditional state transition - sdfg.remove_nodes_from([join_state, tbr_state, fbr_state]) - sdfg.add_edge(stmt_state, current_state, dace.InterstateEdge()) - elif tbr_state.is_empty(): - # use direct edge from if-statement to join state for true branch - tbr_condition = sdfg.edges_between(stmt_state, tbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = tbr_condition - sdfg.remove_node(tbr_state) - elif fbr_state.is_empty(): - # use direct edge from if-statement to join state for false branch - fbr_condition = sdfg.edges_between(stmt_state, fbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = fbr_condition - sdfg.remove_node(fbr_state) 
- else: - # remove direct edge from if-statement to join state - sdfg.remove_edge(sdfg.edges_between(stmt_state, join_state)[0]) - # the if-statement condition is not used in current state - current_state.remove_node(ctx_stmt_node.value) - - return result_values - - -def builtin_list_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = list(itertools.chain(*transformer.visit(node_args))) - assert len(args) == 2 - # index node - if isinstance(args[0], SymbolExpr): - index_value = args[0].value - result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True) - result_node = transformer.context.state.add_access(result_name) - transformer.context.state.add_nedge( - args[1].value, result_node, dace.Memlet(data=args[1].value.data, subset=index_value) - ) - return [ValueExpr(result_node, args[1].dtype)] - - else: - expr_args = [(arg, f"{arg.value.data}_v") for arg in args] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet( - expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di - ) - - -def builtin_cast( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = transformer.visit(node_args[0]) - internals = [f"{arg.value.data}_v" for arg in args] - target_type = node_args[1] - assert isinstance(target_type, itir.SymRef) - expr = _MATH_BUILTINS_MAPPING[target_type.id].format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr, type_, "cast", dace_debuginfo=di - ) - - -def builtin_make_const_list( - transformer: PythonTaskletCodegen, 
node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = [transformer.visit(arg)[0] for arg in node_args] - assert all(isinstance(x, (SymbolExpr, ValueExpr)) for x in args) - args_dtype = [x.dtype for x in args] - assert len(set(args_dtype)) == 1 - dtype = args_dtype[0] - - var_name = unique_var_name() - transformer.context.body.add_array(var_name, (len(args),), dtype, transient=True) - var_node = transformer.context.state.add_access(var_name, debuginfo=di) - - for i, arg in enumerate(args): - if isinstance(arg, SymbolExpr): - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - f"get_arg{i}", {}, {"val"}, f"val = {arg.value}" - ), - "val", - var_node, - None, - dace.Memlet(data=var_name, subset=f"{i}"), - ) - else: - assert arg.value.desc(transformer.context.body).shape == (1,) - transformer.context.state.add_nedge( - arg.value, - var_node, - dace.Memlet(data=arg.value.data, subset="0", other_subset=f"{i}"), - ) - - return [ValueExpr(var_node, dtype)] - - -def builtin_make_tuple( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - args = [transformer.visit(arg) for arg in node_args] - return args - - -def builtin_tuple_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - elements = transformer.visit(node_args[1]) - index = node_args[0] - if isinstance(index, itir.Literal): - return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants.") - - -_GENERAL_BUILTIN_MAPPING: dict[ - str, Callable[[PythonTaskletCodegen, itir.Expr, list[itir.Expr]], list[ValueExpr]] -] = { - "can_deref": builtin_can_deref, - "cast_": builtin_cast, - "if_": builtin_if, - "list_get": builtin_list_get, - "make_const_list": builtin_make_const_list, - "make_tuple": builtin_make_tuple, - "neighbors": 
builtin_neighbors, - "tuple_get": builtin_tuple_get, -} - - -class GatherLambdaSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]] - _parent_symbol_map: dict[str, TaskletExpr] - - def __init__(self, sdfg, state, parent_symbol_map): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - self._parent_symbol_map = parent_symbol_map - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the lambda expression.""" - return self._symbol_map - - def _add_symbol(self, param, arg): - if isinstance(arg, ValueExpr): - # create storage in lambda sdfg - self._sdfg.add_scalar(param, dtype=arg.dtype) - # update table of lambda symbols - self._symbol_map[param] = ValueExpr( - self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype - ) - elif isinstance(arg, IteratorExpr): - # create storage in lambda sdfg - ndims = len(arg.dimensions) - shape, strides = new_array_symbols(param, ndims) - self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) - index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} - for _, index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) - # update table of lambda symbols - field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - indices = { - dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) - for dim, index_arg in index_names.items() - } - self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) - else: - assert isinstance(arg, SymbolExpr) - self._symbol_map[param] = arg - - def _add_tuple(self, param, args): - nodes = [] - # create storage in lambda sdfg for each tuple element - for arg in args: - var = unique_var_name() - self._sdfg.add_scalar(var, dtype=arg.dtype) - arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo) - nodes.append(ValueExpr(arg_node, arg.dtype)) - # 
update table of lambda symbols - self._symbol_map[param] = tuple(nodes) - - def visit_SymRef(self, node: itir.SymRef): - name = str(node.id) - if name in self._parent_symbol_map and name not in self._symbol_map: - arg = self._parent_symbol_map[name] - self._add_symbol(name, arg) - - def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): - if args is not None: - if len(node.params) == len(args): - for param, arg in zip(node.params, args): - self._add_symbol(str(param.id), arg) - else: - # implicitly make tuple - assert len(node.params) == 1 - self._add_tuple(str(node.params[0].id), args) - self.visit(node.expr) - - -class GatherOutputSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr] - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the output expression.""" - return self._symbol_map - - def __init__(self, sdfg, state): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - - def visit_SymRef(self, node: itir.SymRef): - param = str(node.id) - if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: - access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - self._symbol_map[param] = ValueExpr( - access_node, - dtype=itir_type_as_dace_type(node.type), # type: ignore[arg-type] # ensure by type inference - ) - - -@dataclasses.dataclass -class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider_type: common.OffsetProviderType - context: Context - use_field_canonical_representation: bool - - def get_sorted_field_dimensions(self, dims: Sequence[str]): - return sorted(dims) if self.use_field_canonical_representation else dims - - def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): - raise NotImplementedError() - - def visit_Lambda( - self, node: itir.Lambda, args: Sequence[TaskletExpr], use_neighbor_tables: bool = True - ) -> tuple[ - Context, - 
list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], - list[ValueExpr], - ]: - func_name = f"lambda_{abs(hash(node)):x}" - neighbor_tables = ( - get_used_connectivities(node, self.offset_provider_type) if use_neighbor_tables else {} - ) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # Create the SDFG for the lambda's body - lambda_sdfg = dace.SDFG(func_name) - lambda_sdfg.debuginfo = dace_utils.debug_info(node, default=self.context.body.debuginfo) - lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True) - - lambda_symbols_pass = GatherLambdaSymbolsPass( - lambda_sdfg, lambda_state, self.context.symbol_map - ) - lambda_symbols_pass.visit(node, args=args) - - # Add for input nodes for lambda symbols - inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = [] - for sym, input_node in lambda_symbols_pass.symbol_refs.items(): - params = [str(p.id) for p in node.params] - try: - param_index = params.index(sym) - except ValueError: - param_index = -1 - if param_index >= 0: - outer_node = args[param_index] - else: - # the symbol is not found among lambda arguments, then it is inherited from parent scope - outer_node = self.context.symbol_map[sym] - if isinstance(input_node, IteratorExpr): - assert isinstance(outer_node, IteratorExpr) - index_params = { - dim: index_node.data for dim, index_node in input_node.indices.items() - } - inputs.append(((sym, index_params), outer_node)) - elif isinstance(input_node, ValueExpr): - assert isinstance(outer_node, ValueExpr) - inputs.append((sym, outer_node)) - elif isinstance(input_node, tuple): - assert param_index >= 0 - for i, input_node_i in enumerate(input_node): - arg_i = args[param_index + i] - assert isinstance(arg_i, ValueExpr) - assert isinstance(input_node_i, ValueExpr) - inputs.append((input_node_i.value.data, arg_i)) - - # Add connectivities as arrays - for name in connectivity_names: - shape, 
strides = new_array_symbols(name, ndim=2) - dtype = self.context.body.arrays[name].dtype - lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - # Translate the lambda's body in its own context - lambda_context = Context( - lambda_sdfg, - lambda_state, - lambda_symbols_pass.symbol_refs, - reduce_identity=self.context.reduce_identity, - ) - lambda_taskgen = PythonTaskletCodegen( - self.offset_provider_type, - lambda_context, - self.use_field_canonical_representation, - ) - - results: list[ValueExpr] = [] - # We are flattening the returned list of value expressions because the multiple outputs of a lambda - # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. - node.expr.location = node.location - for expr in flatten_list(lambda_taskgen.visit(node.expr)): - if isinstance(expr, ValueExpr): - result_name = unique_var_name() - lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access( - result_name, debuginfo=lambda_sdfg.debuginfo - ) - lambda_state.add_nedge( - expr.value, result_access, dace.Memlet(data=result_access.data, subset="0") - ) - result = ValueExpr(value=result_access, dtype=expr.dtype) - else: - # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet( - [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo - )[0] - lambda_sdfg.arrays[result.value.data].transient = False - results.append(result) - - # remove isolated access nodes for connectivity arrays not consumed by lambda - for sub_node in lambda_state.nodes(): - if isinstance(sub_node, dace.nodes.AccessNode): - if lambda_state.out_degree(sub_node) == 0 and lambda_state.in_degree(sub_node) == 0: - lambda_state.remove_node(sub_node) - - return lambda_context, inputs, results - - def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: - param = 
str(node.id) - value = self.context.symbol_map[param] - if isinstance(value, (ValueExpr, SymbolExpr)): - return [value] - return value - - def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: - return [SymbolExpr(node.value, itir_type_as_dace_type(node.type))] - - def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: - node.fun.location = node.location - if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": - return self._visit_deref(node) - if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): - if node.fun.fun.id == "shift": - return self._visit_shift(node) - elif node.fun.fun.id == "reduce": - return self._visit_reduce(node) - - if isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - elif builtin_name in _GENERAL_BUILTIN_MAPPING: - return self._visit_general_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - return self._visit_call(node) - - def _visit_call(self, node: itir.FunCall): - args = self.visit(node.args) - args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] - args = list(itertools.chain(*args)) - node.fun.location = node.location - func_context, func_inputs, results = self.visit(node.fun, args=args) - - nsdfg_inputs = {} - for name, value in func_inputs: - if isinstance(value, ValueExpr): - nsdfg_inputs[name] = dace.Memlet.from_array( - value.value.data, self.context.body.arrays[value.value.data] - ) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - nsdfg_inputs[field] = dace.Memlet.from_array( - value.field.data, self.context.body.arrays[value.field.data] - ) - for dim, var in indices.items(): - store = value.indices[dim].data - nsdfg_inputs[var] = dace.Memlet.from_array( - store, self.context.body.arrays[store] - ) - - neighbor_tables = get_used_connectivities(node.fun, 
self.offset_provider_type) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) - - symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) - - nsdfg_node = self.context.state.add_nested_sdfg( - func_context.body, - None, - inputs=set(nsdfg_inputs.keys()), - outputs=set(r.value.data for r in results), - symbol_mapping=symbol_mapping, - debuginfo=dace_utils.debug_info(node, default=func_context.body.debuginfo), - ) - - for name, value in func_inputs: - if isinstance(value, ValueExpr): - value_memlet = nsdfg_inputs[name] - self.context.state.add_edge(value.value, None, nsdfg_node, name, value_memlet) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - field_memlet = nsdfg_inputs[field] - self.context.state.add_edge(value.field, None, nsdfg_node, field, field_memlet) - for dim, var in indices.items(): - store = value.indices[dim] - idx_memlet = nsdfg_inputs[var] - self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) - self.context.state.add_edge(access, None, nsdfg_node, var, memlet) - - result_exprs = [] - for result in results: - name = unique_var_name() - self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) - result_exprs.append(ValueExpr(result_access, result.dtype)) - memlet = dace.Memlet.from_array(name, self.context.body.arrays[name]) - self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) - - return result_exprs - - def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: - di = dace_utils.debug_info(node, 
default=self.context.body.debuginfo) - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # already a list of ValueExpr - return iterator - - sorted_dims = self.get_sorted_field_dimensions(iterator.dimensions) - if all([dim in iterator.indices for dim in iterator.dimensions]): - # The deref iterator has index values on all dimensions: the result will be a scalar - args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet( - list(zip(args, internals)), expr, iterator.dtype, "deref", dace_debuginfo=di - ) - - else: - dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] - assert len(dims_not_indexed) == 1 - offset = dims_not_indexed[0] - offset_provider_type = self.offset_provider_type[offset] - assert isinstance(offset_provider_type, common.NeighborConnectivityType) - neighbor_dim = offset_provider_type.codomain.value - - result_name = unique_var_name() - self.context.body.add_array( - result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True - ) - result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name, debuginfo=di) - - deref_connectors = ["_inp"] + [ - f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices - ] - deref_nodes = [iterator.field] + [ - iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices - ] - deref_memlets = [ - dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body)) - ] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]] - - # we create a mapped tasklet for array slicing - index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} - src_subset = ",".join( - 
[f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] - ) - self.context.state.add_mapped_tasklet( - "deref", - map_ranges, - inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, - outputs={"_out": dace.Memlet.from_array(result_name, result_array)}, - code=f"_out[{index_name}] = _inp[{src_subset}]", - external_edges=True, - input_nodes={node.data: node for node in deref_nodes}, - output_nodes={result_name: result_node}, - debuginfo=di, - ) - return [ValueExpr(result_node, iterator.dtype)] - - def _split_shift_args( - self, args: list[itir.Expr] - ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: - pairs = [args[i : i + 2] for i in range(0, len(args), 2)] - assert len(pairs) >= 1 - assert all(len(pair) == 2 for pair in pairs) - return pairs[-1], list(itertools.chain(*pairs[0:-1])) if len(pairs) > 1 else None - - def _make_shift_for_rest(self, rest, iterator): - return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), - args=[iterator], - location=iterator.location, - ) - - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - shift = node.fun - assert isinstance(shift, itir.FunCall) - tail, rest = self._split_shift_args(shift.args) - if rest: - iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) - else: - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # shift cannot be applied because the argument is not iterable - # TODO: remove this special case when ITIR pass is able to catch it - assert isinstance(iterator, list) and len(iterator) == 1 - assert isinstance(iterator[0], ValueExpr) - return iterator - - assert isinstance(tail[0], itir.OffsetLiteral) - offset_dim = tail[0].value - assert isinstance(offset_dim, str) - offset_node = self.visit(tail[1])[0] - assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - - if 
isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): - offset_provider_type = cast( - common.NeighborConnectivityType, self.offset_provider_type[offset_dim] - ) # ensured by condition - connectivity = self.context.state.add_access( - dace_utils.connectivity_identifier(offset_dim), debuginfo=di - ) - - shifted_dim_tag = offset_provider_type.source_dim.value - target_dim_tag = offset_provider_type.codomain.value - args = [ - ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), - offset_node, - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" - else: - shifted_dim = self.offset_provider_type[offset_dim] - assert isinstance(shifted_dim, common.Dimension) - - shifted_dim_tag = shifted_dim.value - target_dim_tag = shifted_dim_tag - args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]} + {internals[1]}" - - shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di - )[0].value - - shifted_index = {dim: value for dim, value in iterator.indices.items()} - del shifted_index[shifted_dim_tag] - shifted_index[target_dim_tag] = shifted_value - - return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) - - def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - offset = node.value - assert isinstance(offset, int) - offset_var = unique_var_name() - self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) - offset_node = self.context.state.add_access(offset_var, debuginfo=di) - tasklet_node = self.context.state.add_tasklet( - "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di - ) - self.context.state.add_edge( - 
tasklet_node, "__out", offset_node, None, dace.Memlet(data=offset_var, subset="0") - ) - return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] - - def _visit_reduce(self, node: itir.FunCall): - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - reduce_dtype = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - - if len(node.args) == 1: - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - reduce_identity = node.fun.args[1] - assert isinstance(reduce_identity, itir.Literal) - - # set reduction state - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args[0]) - - assert 1 <= len(args) <= 2 - reduce_input_node = args[0].value - - else: - assert isinstance(node.fun, itir.FunCall) - assert isinstance(node.fun.args[0], itir.Lambda) - fun_node = node.fun.args[0] - assert isinstance(fun_node.expr, itir.FunCall) - - op_name = fun_node.expr.fun - assert isinstance(op_name, itir.SymRef) - reduce_identity = get_reduce_identity_value(op_name.id, reduce_dtype) - - # set reduction state in visit context - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args) - - # clear context - self.context.reduce_identity = None - - # check that all neighbor expressions have the same shape - args_shape = [ - arg[0].value.desc(self.context.body).shape - for arg in args - if arg[0].value.desc(self.context.body).shape != (1,) - ] - assert len(set(args_shape)) == 1 - nreduce_shape = args_shape[0] - - input_args = [arg[0] for arg in args] - input_valid_args = [arg[1] for arg in args if len(arg) == 2] - - assert len(nreduce_shape) == 1 - nreduce_index = unique_name("_i") - nreduce_domain = {nreduce_index: 
f"0:{nreduce_shape[0]}"} - - reduce_input_name = unique_var_name() - self.context.body.add_array( - reduce_input_name, nreduce_shape, reduce_dtype, transient=True - ) - - lambda_node = itir.Lambda( - expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location - ) - lambda_context, inner_inputs, inner_outputs = self.visit( - lambda_node, args=input_args, use_neighbor_tables=False - ) - - input_mapping = { - param: ( - dace.Memlet(data=arg.value.data, subset="0") - if arg.value.desc(self.context.body).shape == (1,) - else dace.Memlet(data=arg.value.data, subset=nreduce_index) - ) - for (param, _), arg in zip(inner_inputs, input_args) - } - output_mapping = { - inner_outputs[0].value.data: dace.Memlet( - data=reduce_input_name, subset=nreduce_index - ) - } - symbol_mapping = map_nested_sdfg_symbols( - self.context.body, lambda_context.body, input_mapping - ) - - if input_valid_args: - """ - The neighbors builtin returns an array of booleans in case the connectivity table contains skip values. - These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select - the result of field access or the identity value, respectively. - If the neighbor table has full connectivity (no skip values by type definition), the input_valid node - is not built, and the construction of the select tasklet below is also skipped. 
- """ - input_args.append(input_valid_args[0]) - input_valid_node = input_valid_args[0].value - lambda_output_node = inner_outputs[0].value - # add input connector to nested sdfg - lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool) - input_mapping["_valid_neighbor"] = dace.Memlet( - data=input_valid_node.data, subset=nreduce_index - ) - # add select tasklet before writing to output node - # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API - output_edge = lambda_context.state.in_edges(lambda_output_node)[0] - assert isinstance( - lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar - ) - select_tasklet = lambda_context.state.add_tasklet( - "neighbor_select", - {"_inp", "_valid"}, - {"_out"}, - f"_out = _inp if _valid else {reduce_identity}", - ) - lambda_context.state.add_edge( - output_edge.src, - None, - select_tasklet, - "_inp", - dace.Memlet(data=output_edge.src.data, subset="0"), - ) - lambda_context.state.add_edge( - lambda_context.state.add_access("_valid_neighbor"), - None, - select_tasklet, - "_valid", - dace.Memlet(data="_valid_neighbor", subset="0"), - ) - lambda_context.state.add_edge( - select_tasklet, - "_out", - lambda_output_node, - None, - dace.Memlet(data=lambda_output_node.data, subset="0"), - ) - lambda_context.state.remove_edge(output_edge) - - reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) - - nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( - self.context.state, - sdfg=lambda_context.body, - map_ranges=nreduce_domain, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - input_nodes={arg.value.data: arg.value for arg in input_args}, - output_nodes={reduce_input_name: reduce_input_node}, - debuginfo=di, - ) - - reduce_input_desc = reduce_input_node.desc(self.context.body) - - result_name = unique_var_name() - # we allocate an array instead of a scalar because the reduce library node is generic and expects an 
array node - self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") - reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) - self.context.state.add_nedge( - reduce_input_node, - reduce_node, - dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc), - ) - self.context.state.add_nedge( - reduce_node, result_access, dace.Memlet(data=result_name, subset="0") - ) - - return [ValueExpr(result_access, reduce_dtype)] - - def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = flatten_list(self.visit(node.args)) - expr_args = [ - (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr) - ] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args - ] - expr = fmt.format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return self.add_expr_tasklet( - expr_args, - expr, - type_, - "numeric", - dace_debuginfo=dace_utils.debug_info(node, default=self.context.body.debuginfo), - ) - - def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - expr_func = _GENERAL_BUILTIN_MAPPING[str(node.fun.id)] - return expr_func(self, node, node.args) - - def add_expr_tasklet( - self, - args: list[tuple[ValueExpr, str]], - expr: str, - result_type: Any, - name: str, - dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, - ) -> list[ValueExpr]: - di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = 
self.context.state.add_access(result_name, debuginfo=di) - - expr_tasklet = self.context.state.add_tasklet( - name=name, - inputs={internal for _, internal in args}, - outputs={"__result"}, - code=f"__result = {expr}", - debuginfo=di, - ) - - for arg, internal in args: - edges = self.context.state.in_edges(expr_tasklet) - used = False - for edge in edges: - if edge.dst_conn == internal: - used = True - break - if used: - continue - elif not isinstance(arg, SymbolExpr): - memlet = dace.Memlet.from_array( - arg.value.data, self.context.body.arrays[arg.value.data] - ) - self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - - memlet = dace.Memlet(data=result_access.data, subset="0") - self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) - - return [ValueExpr(result_access, result_type)] - - -def is_scan(node: itir.Node) -> bool: - return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") - - -def closure_to_tasklet_sdfg( - node: itir.StencilClosure, - offset_provider_type: common.OffsetProviderType, - domain: dict[str, str], - inputs: Sequence[tuple[str, ts.TypeSpec]], - connectivities: Sequence[tuple[dace.ndarray, str]], - use_field_canonical_representation: bool, -) -> tuple[Context, Sequence[ValueExpr]]: - body = dace.SDFG("tasklet_toplevel") - body.debuginfo = dace_utils.debug_info(node) - state = body.add_state("tasklet_toplevel_entry", True) - symbol_map: dict[str, TaskletExpr] = {} - - idx_accesses = {} - for dim, idx in domain.items(): - name = f"{idx}_value" - body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet( - f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo - ) - access = state.add_access(name, debuginfo=body.debuginfo) - idx_accesses[dim] = access - state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0")) - for name, ty in inputs: - if isinstance(ty, ts.FieldType): - ndim = len(ty.dims) - shape, 
strides = new_array_symbols(name, ndim) - dims = [dim.value for dim in ty.dims] - dtype = dace_utils.as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name, debuginfo=body.debuginfo) - indices = {dim: idx_accesses[dim] for dim in domain.keys()} - symbol_map[name] = IteratorExpr(field, indices, dtype, dims) - else: - assert isinstance(ty, ts.ScalarType) - dtype = dace_utils.as_dace_type(ty) - body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) - for arr, name in connectivities: - shape, strides = new_array_symbols(name, ndim=2) - body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) - - context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen( - offset_provider_type, context, use_field_canonical_representation - ) - - args = [itir.SymRef(id=name) for name, _ in inputs] - if is_scan(node.stencil): - stencil = cast(FunCall, node.stencil) - assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda( - expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location - ) - fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) - else: - fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) - - results = translator.visit(fun_node) - for r in results: - context.body.arrays[r.value.data].transient = False - - return context, results diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py deleted file mode 100644 index 72bb32f003..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ /dev/null @@ -1,149 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -import itertools -from typing import Any - -import dace - -import gt4py.next.iterator.ir as itir -from gt4py import eve -from gt4py.next import common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils - - -def get_used_connectivities( - node: itir.Node, offset_provider_type: common.OffsetProviderType -) -> dict[str, common.NeighborConnectivityType]: - connectivities = dace_utils.filter_connectivity_types(offset_provider_type) - offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) - return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} - - -def map_nested_sdfg_symbols( - parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] -) -> dict[str, str]: - symbol_mapping: dict[str, str] = {} - for param, arg in array_mapping.items(): - arg_array = parent_sdfg.arrays[arg.data] - param_array = nested_sdfg.arrays[param] - if not isinstance(param_array, dace.data.Scalar): - assert len(arg.subset.size()) == len(param_array.shape) - for arg_shape, param_shape in zip(arg.subset.size(), param_array.shape): - if isinstance(param_shape, dace.symbol): - symbol_mapping[str(param_shape)] = str(arg_shape) - assert len(arg_array.strides) == len(param_array.strides) - for arg_stride, param_stride in zip(arg_array.strides, param_array.strides): - if isinstance(param_stride, dace.symbol): - symbol_mapping[str(param_stride)] = str(arg_stride) - else: - assert arg.subset.num_elements() == 1 - for sym in nested_sdfg.free_symbols: - if str(sym) not in symbol_mapping: - symbol_mapping[str(sym)] = str(sym) - return symbol_mapping - - -def add_mapped_nested_sdfg( - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - 
sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, -) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: - if not symbol_mapping: - symbol_mapping = {sym: sym for sym in sdfg.free_symbols} - - nsdfg_node = state.add_nested_sdfg( - sdfg, - None, - set(inputs.keys()), - set(outputs.keys()), - symbol_mapping, - name=sdfg.name, - schedule=schedule, - location=location, - debuginfo=debuginfo, - ) - - map_entry, map_exit = state.add_map( - f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo - ) - - if input_nodes is None: - input_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) - for name, memlet in inputs.items() - } - if output_nodes is None: - output_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) - for name, memlet in outputs.items() - } - if not inputs: - state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) - for name, memlet in inputs.items(): - state.add_memlet_path( - input_nodes[memlet.data], - map_entry, - nsdfg_node, - memlet=memlet, - src_conn=None, - dst_conn=name, - propagate=True, - ) - if not outputs: - state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) - for name, memlet in outputs.items(): - state.add_memlet_path( - nsdfg_node, - map_exit, - output_nodes[memlet.data], - memlet=memlet, - src_conn=name, - dst_conn=None, - propagate=True, - ) - - return nsdfg_node, map_entry, map_exit - - -def unique_name(prefix): - unique_id = getattr(unique_name, "_unique_id", 0) # static variable - setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] - - return f"{prefix}_{unique_id}" - - -def unique_var_name(): - return unique_name("_var") - - -def 
new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: - dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) - shape = [dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) for i in range(ndim)] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i in range(ndim) - ] - return shape, strides - - -def flatten_list(node_list: list[Any]) -> list[Any]: - return list( - itertools.chain.from_iterable( - [flatten_list(e) if isinstance(e, list) else [e] for e in node_list] - ) - ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py deleted file mode 100644 index 653ed4719d..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ /dev/null @@ -1,150 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -import functools -from typing import Callable, Optional, Sequence - -import dace -import factory - -from gt4py._core import definitions as core_defs -from gt4py.next import common, config -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import program_to_fencil -from gt4py.next.otf import languages, recipes, stages, step_types, workflow -from gt4py.next.otf.binding import interface -from gt4py.next.otf.languages import LanguageSettings -from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.type_system import type_specifications as ts - -from . 
import build_sdfg_from_itir - - -@dataclasses.dataclass(frozen=True) -class DaCeTranslator( - workflow.ChainableWorkflowMixin[ - stages.CompilableProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] - ], - step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], -): - auto_optimize: bool = False - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None - use_field_canonical_representation: bool = False - - def _language_settings(self) -> languages.LanguageSettings: - return languages.LanguageSettings( - formatter_key="", formatter_style="", file_extension="sdfg" - ) - - def generate_sdfg( - self, - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - column_axis: Optional[common.Dimension], - ) -> dace.SDFG: - on_gpu = ( - True - if self.device_type in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] - else False - ) - - return build_sdfg_from_itir( - program, - arg_types, - offset_provider_type=offset_provider_type, - auto_optimize=self.auto_optimize, - on_gpu=on_gpu, - column_axis=column_axis, - symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, - load_sdfg_from_file=False, - save_sdfg=False, - use_field_canonical_representation=self.use_field_canonical_representation, - ) - - def __call__( - self, inp: stages.CompilableProgram - ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: - """Generate DaCe SDFG file from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data - - if isinstance(program, itir.Program): - program = program_to_fencil.program_to_fencil(program) - - sdfg = self.generate_sdfg( - program, - inp.args.args, - 
common.offset_provider_to_type(inp.args.offset_provider), - inp.args.column_axis, - ) - - param_types = tuple( - interface.Parameter(param, arg) for param, arg in zip(sdfg.arg_names, inp.args.args) - ) - - module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( - stages.ProgramSource( - entry_point=interface.Function(program.id, param_types), - source_code=sdfg.to_json(), - library_deps=tuple(), - language=languages.SDFG, - language_settings=self._language_settings(), - implicit_domain=inp.data.implicit_domain, - ) - ) - return module - - -class DaCeTranslationStepFactory(factory.Factory): - class Meta: - model = DaCeTranslator - - -def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: - return stages.CompilableSource(program_source=inp, binding_source=None) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( - lambda: config.CMAKE_BUILD_TYPE - ) - use_field_canonical_representation: bool = False - - translation = factory.SubFactory( - DaCeTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), - ) - bindings = _no_bindings - compilation = factory.SubFactory( - dace_workflow.DaCeCompilationStepFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - dace_workflow.convert_args, - device=o.device_type, - use_field_canonical_representation=o.use_field_canonical_representation, - ) - ) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 349d3e9f70..1593ab3ba6 100644 --- a/tests/next_tests/definitions.py +++ 
b/tests/next_tests/definitions.py @@ -11,11 +11,10 @@ import dataclasses import enum import importlib -from typing import Final, Optional, Protocol import pytest -from gt4py.next import allocators as next_allocators, backend as next_backend +from gt4py.next import allocators as next_allocators # Skip definitions @@ -67,10 +66,10 @@ class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): - DACE_CPU = "gt4py.next.program_processors.runners.dace.itir_cpu" - DACE_GPU = "gt4py.next.program_processors.runners.dace.itir_gpu" - GTIR_DACE_CPU = "gt4py.next.program_processors.runners.dace.gtir_cpu" - GTIR_DACE_GPU = "gt4py.next.program_processors.runners.dace.gtir_gpu" + DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu" + DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu" + DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt" + DACE_GPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_gpu_noopt" class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): @@ -139,21 +138,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, 
SKIP, UNSUPPORTED_MESSAGE), -] -GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ +DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), @@ -189,10 +174,16 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST + + [ + (ALL, SKIP, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): Enable once the optimization pipeline is merged + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + + [ + (ALL, SKIP, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): Enable once the optimization pipeline is merged. + OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index f5646c71e4..08904c06f3 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,14 +6,11 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from types import ModuleType -from typing import Optional - import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend, common +from gt4py.next import allocators as gtx_allocators, common as gtx_common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case @@ -34,24 +31,22 @@ try: import dace - from gt4py.next.program_processors.runners.dace import ( - itir_cpu as run_dace_cpu, - itir_gpu as run_dace_gpu, - ) except ImportError: dace: Optional[ModuleType] = None # type:ignore[no-redef] - run_dace_cpu: Optional[next_backend.Backend] = None - run_dace_gpu: Optional[next_backend.Backend] = None pytestmark = pytest.mark.requires_dace def test_sdfgConvertible_laplap(cartesian_case): - # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional - if cartesian_case.backend not in [run_dace_cpu, run_dace_gpu]: + if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - if cartesian_case.backend == run_dace_gpu: + # TODO(ricoh): enable test after adding GTIR support + pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") + + allocator, backend = unstructured_case.allocator, unstructured_case.backend + + if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp else: import numpy as xp @@ -64,13 +59,13 @@ def test_sdfgConvertible_laplap(cartesian_case): def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( in_field, tmp_field ) 
lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( tmp_field, out_field ) @@ -94,13 +89,15 @@ def testee(a: gtx.Field[gtx.Dims[Vertex], gtx.float64], b: gtx.Field[gtx.Dims[Ed @pytest.mark.uses_unstructured_shift def test_sdfgConvertible_connectivities(unstructured_case): - # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional - if unstructured_case.backend not in [run_dace_cpu, run_dace_gpu]: + if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") + # TODO(ricoh): enable test after adding GTIR support + pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") + allocator, backend = unstructured_case.allocator, unstructured_case.backend - if backend == run_dace_gpu: + if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp dace_storage_type = dace.StorageType.GPU_Global diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 794dd06709..1147f4bc3e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -66,11 +66,11 @@ def __gt_allocator__( marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_CPU, + next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, marks=pytest.mark.requires_dace, ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_GPU, + 
next_tests.definitions.OptionalProgramBackendId.DACE_GPU_NO_OPT, marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), ], From a936761243319dbfa2c94e28222dc6962a96f14a Mon Sep 17 00:00:00 2001 From: SF-N Date: Wed, 4 Dec 2024 15:05:43 +0100 Subject: [PATCH 30/43] bug[next]: Fix astype for local fields (#1761) Fix astype by calling `_map` additionally and add corresponding tests Co-authored-by: Edoardo Paone --- src/gt4py/next/ffront/foast_to_gtir.py | 6 +-- .../dace_fieldview/gtir_python_codegen.py | 46 ++++++++++++------- .../ffront_tests/test_execution.py | 16 +++++++ .../ffront_tests/test_foast_to_gtir.py | 16 ++++++- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2c2971f49a..3c65695aec 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -359,11 +359,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: - if isinstance(t[0], ts.FieldType): - return im.cast_as_fieldop(str(new_type))(expr) - else: - assert isinstance(t[0], ts.ScalarType) - return im.call("cast_")(expr, str(new_type)) + return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return create_cast(obj, (node.args[0].type,)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 6aee33c56e..4bdb602f5f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -118,29 +118,41 @@ class PythonCodegen(codegen.TemplatedGenerator): as in the case of 
field domain definitions, for sybolic array shape and map range. """ - SymRef = as_fmt("{id}") Literal = as_fmt("{value}") - def _visit_deref(self, node: gtir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], gtir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def visit_FunCall(self, node: gtir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) + def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str: + if isinstance(node.fun, gtir.Lambda): + # update the mapping from lambda parameters to corresponding argument expressions + lambda_args_map = args_map | { + p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True) + } + return self.visit(node.fun.expr, args_map=lambda_args_map) + elif cpm.is_call_to(node, "deref"): + assert len(node.args) == 1 + if not isinstance(node.args[0], gtir.SymRef): + # shift expressions are not expected in this visitor context + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + return self.visit(node.args[0], args_map=args_map) elif isinstance(node.fun, gtir.SymRef): - args = self.visit(node.args) + args = self.visit(node.args, args_map=args_map) builtin_name = str(node.fun.id) return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str: + symbol = str(node.id) + if symbol in args_map: + return self.visit(args_map[symbol], args_map=args_map) + return symbol + -get_source = PythonCodegen.apply -""" -Specialized visit method for symbolic expressions. +def get_source(node: gtir.Node) -> str: + """ + Specialized visit method for symbolic expressions. 
-Returns: - A string containing the Python code corresponding to a symbolic expression -""" + The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions. + + Returns: + A string containing the Python code corresponding to a symbolic expression + """ + return PythonCodegen.apply(node, args_map={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 0d994d1b22..4eed7f5cde 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -438,6 +438,22 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +def test_astype_int_local_field(unstructured_case): + @gtx.field_operator + def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: + tmp = astype(a(E2V), int64) + return neighbor_sum(tmp, axis=E2VDim) + + e2v_table = unstructured_case.offset_provider["E2V"].ndarray + + cases.verify_with_default_data( + unstructured_case, + testee, + ref=lambda a: np.sum(a.astype(int64)[e2v_table], axis=1, initial=0), + comparison=lambda a, b: np.all(a == b), + ) + + @pytest.mark.uses_tuple_returns def test_astype_on_tuples(cartesian_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 516890ea46..59a8dc961b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -283,9 +283,22 @@ def foo(a: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.cast_as_fieldop("int32")("a") + assert lowered_inlined.expr == reference 
+ + +def test_astype_local_field(): + def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a") + assert lowered.expr == reference @@ -295,10 +308,11 @@ def foo(a: float64): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.call("cast_")("a", "int32") - assert lowered.expr == reference + assert lowered_inlined.expr == reference def test_astype_tuple(): From 10adb2c7b3d26a31f9580218d2f8edc6fe67abbf Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 4 Dec 2024 15:09:45 +0100 Subject: [PATCH 31/43] test[next]: cleanup test markers (#1767) - Remove some test markers related to ITIR. - Fuse `uses_index_builtin` marker into `uses_index_fields`. 
--- pyproject.toml | 3 --- tests/next_tests/definitions.py | 5 ----- .../feature_tests/ffront_tests/test_execution.py | 4 ---- .../feature_tests/iterator_tests/test_program.py | 7 ++----- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e24094fa2..e859c9b4f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -240,17 +240,14 @@ markers = [ 'requires_atlas: tests that require `atlas4py` bindings package', 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', - 'starts_from_gtir_program: tests that require backend to start lowering from GTIR program', 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', - 'uses_lift_expressions: tests that require backend support for lift expressions', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 'uses_origin: tests that require backend support for domain origin', - 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scan: tests that uses scan', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 1593ab3ba6..80b8f4f39b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -85,8 +85,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, 
enum.Enum): # to avoid needing to mark all tests. ALL = "all" REQUIRES_ATLAS = "requires_atlas" -# TODO(havogt): Remove, skipped during refactoring to GTIR -STARTS_FROM_GTIR_PROGRAM = "starts_from_gtir_program" USES_APPLIED_SHIFTS = "uses_applied_shifts" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" @@ -94,10 +92,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_IF_STMTS = "uses_if_stmts" USES_IR_IF_STMTS = "uses_ir_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" -USES_LIFT_EXPRESSIONS = "uses_lift_expressions" USES_NEGATIVE_MODULO = "uses_negative_modulo" USES_ORIGIN = "uses_origin" -USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" @@ -117,7 +113,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" -USES_INDEX_BUILTIN = "uses_index_builtin" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 4eed7f5cde..9de4449ac2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -291,7 +291,6 @@ def testee(a: tuple[cases.VField, cases.EField]) -> cases.VField: ) -@pytest.mark.uses_index_fields @pytest.mark.uses_cartesian_shift def test_scalar_arg_with_field(cartesian_case): @gtx.field_operator @@ -602,7 +601,6 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: 
@pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator def testee(a: cases.VField) -> cases.VField: @@ -722,7 +720,6 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_scan -@pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) @@ -804,7 +801,6 @@ def testee( @pytest.mark.uses_constant_fields @pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> cases.VField: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index f6fd0a48d0..c79f8dbb6b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -42,7 +42,6 @@ def copy_program(inp, out, size): ) -@pytest.mark.starts_from_gtir_program def test_prog(program_processor): program_processor, validate = program_processor @@ -64,8 +63,7 @@ def index_program_simple(out, size): ) -@pytest.mark.starts_from_gtir_program -@pytest.mark.uses_index_builtin +@pytest.mark.uses_index_fields def test_index_builtin(program_processor): program_processor, validate = program_processor @@ -88,8 +86,7 @@ def index_program_shift(out, size): ) -@pytest.mark.starts_from_gtir_program -@pytest.mark.uses_index_builtin +@pytest.mark.uses_index_fields def test_index_builtin_shift(program_processor): program_processor, validate = program_processor From ea616597483aff2b29a6195a7a11071f907dedbb Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 4 Dec 2024 15:36:24 +0100 Subject: [PATCH 32/43] test[next]: Disable iterator tests on 
DaCe GTIR backend (#1768) There are 2 places where the `program_processor` fixture used in tests is configured: ``` tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py tests/next_tests/unit_tests/conftest.py ``` and there are four DaCe backends: ``` DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu" DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu" DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt" DACE_GPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_gpu_noopt" ``` The `DACE_CPU` and `DACE_GPU` backends will be the default backends, that also apply the DaCe optimization pipeline. However, these backends are disabled for now because we are awaiting #1639. The `DACE_CPU_NO_OPT` and `DACE_GPU_NO_OPT` apply the lowering to SDFG but do not run the optimization pipeline. These backends are currently enabled in GT4Py tests. However, we observed failures in some iterator tests controlled by the `program_processor` fixture in `tests/next_tests/unit_tests/conftest.py`, once `DACE_CPU` is enabled. In this PR, we are disabling such tests: we will address these issues in a separate PR. 
--- tests/next_tests/unit_tests/conftest.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index f1269f1ed8..99bc44efa7 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -60,15 +60,6 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), - pytest.param( - (next_tests.definitions.OptionalProgramBackendId.DACE_CPU, True), - marks=pytest.mark.requires_dace, - ), - # TODO(havogt): update tests to use proper allocation - # pytest.param( - # (next_tests.definitions.OptionalProgramBackendId.DACE_GPU, True), - # marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), - # ), ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) From 33c5ba33e07923ed8830a30b2907598d1a32867d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:50:01 +0100 Subject: [PATCH 33/43] feat[dace]: Updated DaCe Transformations (#1639) The [initial version](https://github.com/GridTools/gt4py/pull/1594) of the optimization pipeline only contained a rough draft. Currently this PR contains a copy of the map fusion transformations from DaCe that are currently under [review](https://github.com/spcl/dace/pull/1629). As soon as that PR is merged and DaCe was updated in GT4Py these files will be deleted. This PR collects some general improvements: - [x] More liberal `LoopBlocking` transformation (with tests). - [x] Incorporate `MapFusionParallel` - [x] Using of `can_be_applied_to()` as soon as DaCe is updated (`TrivialGPUMapElimination`, `SerialMapPromoter`). - [x] Looking at strides that the Lowering generates. 
(Partly done) However, it still uses MapFusion implementation that ships with GT4Py and not the one in DaCe. Note: Because of commit 60e4226 this PR must be merged after [PR1768](https://github.com/GridTools/gt4py/pull/1768). --- .../transformations/__init__.py | 42 +- .../{auto_opt.py => auto_optimize.py} | 271 ++--- .../transformations/gpu_utils.py | 666 ++++++++--- .../transformations/local_double_buffering.py | 393 +++++++ .../transformations/loop_blocking.py | 284 ++--- .../transformations/map_fusion_helper.py | 882 ++++++++------ .../transformations/map_fusion_parallel.py | 170 +++ .../transformations/map_fusion_serial.py | 1007 ++++++++++++++++ .../transformations/map_orderer.py | 144 ++- .../transformations/map_promoter.py | 42 +- .../transformations/map_serial_fusion.py | 483 -------- .../transformations/simplify.py | 1010 +++++++++++++++++ .../dace_fieldview/transformations/strides.py | 99 ++ .../dace_fieldview/transformations/util.py | 317 ++++-- tests/next_tests/definitions.py | 10 +- .../transformation_tests/conftest.py | 4 +- .../test_constant_substitution.py | 142 +++ .../test_create_local_double_buffering.py | 239 ++++ .../test_distributed_buffer_relocator.py | 84 ++ .../test_global_self_copy_elimination.py | 148 +++ .../transformation_tests/test_gpu_utils.py | 108 +- .../test_loop_blocking.py | 508 ++++++++- .../test_map_buffer_elimination.py | 264 +++++ .../transformation_tests/test_map_fusion.py | 124 +- .../transformation_tests/test_map_order.py | 100 ++ .../test_move_tasklet_into_map.py | 164 +++ .../test_serial_map_promoter.py | 4 +- .../dace_tests/transformation_tests/util.py | 6 +- 28 files changed, 6118 insertions(+), 1597 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/transformations/{auto_opt.py => auto_optimize.py} (67%) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py create mode 100644 
src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 8852dd6d2d..2232bcef01 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,32 +12,56 @@ that 
explains the general structure and requirements on the SDFGs. """ -from .auto_opt import ( +from .auto_optimize import gt_auto_optimize +from .gpu_utils import ( + GPUSetBlockSize, + gt_gpu_transform_non_standard_memlet, + gt_gpu_transformation, + gt_set_gpu_blocksize, +) +from .local_double_buffering import gt_create_local_double_buffering +from .loop_blocking import LoopBlocking +from .map_fusion_parallel import MapFusionParallel +from .map_fusion_serial import MapFusionSerial +from .map_orderer import MapIterationOrder, gt_set_iteration_order +from .map_promoter import SerialMapPromoter +from .simplify import ( GT_SIMPLIFY_DEFAULT_SKIP_SET, - gt_auto_optimize, + GT4PyGlobalSelfCopyElimination, + GT4PyMapBufferElimination, + GT4PyMoveTaskletIntoMap, gt_inline_nested_sdfg, - gt_set_iteration_order, + gt_reduce_distributed_buffering, gt_simplify, + gt_substitute_compiletime_symbols, ) -from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize -from .loop_blocking import LoopBlocking -from .map_orderer import MapIterationOrder -from .map_promoter import SerialMapPromoter -from .map_serial_fusion import SerialMapFusion +from .strides import gt_change_transient_strides +from .util import gt_find_constant_arguments, gt_make_transients_persistent __all__ = [ "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", + "GT4PyGlobalSelfCopyElimination", + "GT4PyMoveTaskletIntoMap", + "GT4PyMapBufferElimination", "LoopBlocking", "MapIterationOrder", - "SerialMapFusion", + "MapFusionParallel", + "MapFusionSerial", "SerialMapPromoter", "SerialMapPromoterGPU", "gt_auto_optimize", + "gt_change_transient_strides", + "gt_create_local_double_buffering", "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_set_iteration_order", "gt_set_gpu_blocksize", "gt_simplify", + "gt_make_transients_persistent", + "gt_reduce_distributed_buffering", + "gt_find_constant_arguments", + "gt_substitute_compiletime_symbols", + "gt_gpu_transform_non_standard_memlet", ] diff --git 
a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py similarity index 67% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index e070cdfe4e..bc1d21ca05 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -8,10 +8,10 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Final, Iterable, Optional, Sequence +from typing import Any, Optional, Sequence, Union import dace -from dace.transformation import dataflow as dace_dataflow, passes as dace_passes +from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize from gt4py.next import common as gtx_common @@ -20,146 +20,12 @@ ) -GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} -"""Set of simplify passes `gt_simplify()` skips by default. - -The following passes are included: -- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a - symbol or vice versa and at a later point to invert this again. However, this - pass has some problems with this pattern so for the time being it is disabled. -- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. -""" - - -def gt_simplify( - sdfg: dace.SDFG, - validate: bool = True, - validate_all: bool = False, - skip: Optional[Iterable[str]] = None, -) -> Any: - """Performs simplifications on the SDFG in place. - - Instead of calling `sdfg.simplify()` directly, you should use this function, - as it is specially tuned for GridTool based SDFGs. 
- - This function runs the DaCe simplification pass, but the following passes are - replaced: - - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. - - Furthermore, by default, or if `None` is passed fro `skip` the passes listed in - `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. - - Args: - sdfg: The SDFG to optimize. - validate: Perform validation after the pass has run. - validate_all: Perform extensive validation. - skip: List of simplify passes that should not be applied, defaults - to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. - """ - # Ensure that `skip` is a `set` - skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) - - if "InlineSDFGs" not in skip: - gt_inline_nested_sdfg( - sdfg=sdfg, - multistate=True, - permissive=False, - validate=validate, - validate_all=validate_all, - ) - - return dace_passes.SimplifyPass( - validate=validate, - validate_all=validate_all, - verbose=False, - skip=(skip | {"InlineSDFGs"}), - ).apply_pass(sdfg, {}) - - -def gt_set_iteration_order( - sdfg: dace.SDFG, - leading_dim: gtx_common.Dimension, - validate: bool = True, - validate_all: bool = False, -) -> Any: - """Set the iteration order of the Maps correctly. - - Modifies the order of the Map parameters such that `leading_dim` - is the fastest varying one, the order of the other dimensions in - a Map is unspecific. `leading_dim` should be the dimensions were - the stride is one. - - Args: - sdfg: The SDFG to process. - leading_dim: The leading dimensions. - validate: Perform validation during the steps. - validate_all: Perform extensive validation. 
- """ - return sdfg.apply_transformations_once_everywhere( - gtx_transformations.MapIterationOrder( - leading_dim=leading_dim, - ), - validate=validate, - validate_all=validate_all, - ) - - -def gt_inline_nested_sdfg( - sdfg: dace.SDFG, - multistate: bool = True, - permissive: bool = False, - validate: bool = True, - validate_all: bool = False, -) -> dace.SDFG: - """Perform inlining of nested SDFG into their parent SDFG. - - The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. - However, before the inline transformation is run the function will run some - cleaning passes that allows inlining nested SDFGs. - As a side effect, the function will split stages into more states. - - Args: - sdfg: The SDFG that should be processed, will be modified in place and returned. - multistate: Allow inlining of multistate nested SDFG, defaults to `True`. - permissive: Be less strict on the accepted SDFGs. - validate: Perform validation after the transformation has finished. - validate_all: Performs extensive validation. 
- """ - first_iteration = True - i = 0 - while True: - print(f"ITERATION: {i}") - nb_preproccess = sdfg.apply_transformations_repeated( - [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], - validate=False, - validate_all=validate_all, - ) - if (nb_preproccess == 0) and (not first_iteration): - break - - # Create and configure the inline pass - inline_sdfg = dace_passes.InlineSDFGs() - inline_sdfg.progress = False - inline_sdfg.permissive = permissive - inline_sdfg.multistate = multistate - - # Apply the inline pass - nb_inlines = inline_sdfg.apply_pass(sdfg, {}) - - # Check result, if needed and test if we can stop - if validate_all or validate: - sdfg.validate() - if nb_inlines == 0: - break - first_iteration = False - - return sdfg - - def gt_auto_optimize( sdfg: dace.SDFG, gpu: bool, - leading_dim: Optional[gtx_common.Dimension] = None, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, aggressive_fusion: bool = True, max_optimization_rounds_p2: int = 100, make_persistent: bool = True, @@ -169,6 +35,8 @@ def gt_auto_optimize( reuse_transients: bool = False, gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, + constant_symbols: Optional[dict[str, Any]] = None, + assume_pointwise: bool = True, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -184,6 +52,9 @@ def gt_auto_optimize( different aspects of the SDFG. The initial SDFG is assumed to have a very large number of rather simple Maps. + Note, because of how `gt_auto_optimizer()` works it is not save to call + it twice on the same SDFG. + 1. Some general simplification transformations, beyond classical simplify, are applied to the SDFG. 2. Tries to create larger kernels by fusing smaller ones, see @@ -223,20 +94,31 @@ def gt_auto_optimize( gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. 
gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` for _all_ GPU Maps. + constant_symbols: Symbols listed in this `dict` will be replaced by the + respective value inside the SDFG. This might increase performance. + assume_pointwise: Assume that the SDFG has no risk for race condition in + global data access. See the `GT4PyMapBufferElimination` transformation for more. validate: Perform validation during the steps. validate_all: Perform extensive validation. + + Note: + For identifying symbols that can be treated as compile time constants + `gt_find_constant_arguments()` function can be used. + Todo: - - Make sure that `SDFG.simplify()` is not called indirectly, by temporarily - overwriting it with `gt_simplify()`. + - Update the description. The Phases are nice, but they have lost their + link to reality a little bit. + - Improve the determination of the strides and iteration order of the + transients. + - Set padding of transients, i.e. alignment, the DaCe datadescriptor + can do that. + - Handle nested SDFGs better. - Specify arguments to set the size of GPU thread blocks depending on the dimensions. I.e. be able to use a different size for 1D than 2D Maps. - - Add a parallel version of Map fusion. - Implement some model to further guide to determine what we want to fuse. Something along the line "Fuse if operational intensity goes up, but not if we have too much internal space (register pressure). - - Create a custom array elimination pass that honors rule 1. - - Check if a pipeline could be used to speed up some computations. """ device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU @@ -249,20 +131,25 @@ def gt_auto_optimize( # to internal serial maps, such that they do not block fusion? 
# Phase 1: Initial Cleanup - gt_simplify( + gtx_transformations.gt_simplify( sdfg=sdfg, validate=validate, validate_all=validate_all, ) + gtx_transformations.gt_reduce_distributed_buffering(sdfg) + + if constant_symbols: + gtx_transformations.gt_substitute_compiletime_symbols( + sdfg=sdfg, + repl=constant_symbols, + validate=validate, + validate_all=validate_all, + ) + gtx_transformations.gt_simplify(sdfg) + sdfg.apply_transformations_repeated( [ dace_dataflow.TrivialMapElimination, - # TODO(phimuell): The transformation are interesting, but they have - # a bug as they assume that they are not working inside a map scope. - # Before we use them we have to fix them. - # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc - # dace_dataflow.MapReduceFusion, - # dace_dataflow.MapWCRFusion, ], validate=validate, validate_all=validate_all, @@ -278,28 +165,62 @@ def gt_auto_optimize( validate_all=validate_all, ) - # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. - # Currently this only applies fusion inside Maps. + # After we have created big kernels, we will perform some post cleanup. + gtx_transformations.gt_reduce_distributed_buffering(sdfg) sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_inner_maps=True, - ), + [ + gtx_transformations.GT4PyMoveTaskletIntoMap, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), + ], validate=validate, validate_all=validate_all, ) - gt_simplify(sdfg) + + # TODO(phimuell): The `MapReduceFusion` transformation is interesting as + # it moves the initialization of the accumulator at the top, which allows + # further fusing of the accumulator loop. However the transformation has + # a bug, so we can not use it. Furthermore, I have looked at the assembly + # and the compiler is already doing that. + # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc + + # After we have created large kernels we run `dace_dataflow.MapReduceFusion`. 
+ + # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. + # Currently this only applies fusion inside Maps. + gtx_transformations.gt_simplify(sdfg) + while True: + nb_applied = sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_inner_maps=True, + ), + gtx_transformations.MapFusionParallel( + only_inner_maps=True, + only_if_common_ancestor=False, # TODO(phimuell): Should we? + ), + ], + validate=validate, + validate_all=validate_all, + ) + if not nb_applied: + break + gtx_transformations.gt_simplify(sdfg) # Phase 4: Iteration Space # This essentially ensures that the stride 1 dimensions are handled # by the inner most loop nest (CPU) or x-block (GPU) if leading_dim is not None: - gt_set_iteration_order( + gtx_transformations.gt_set_iteration_order( sdfg=sdfg, leading_dim=leading_dim, validate=validate, validate_all=validate_all, ) + # We now ensure that point wise computations are properly double buffered. + # The main reason is to ensure that rule 3 of ADR18 is maintained. + gtx_transformations.gt_create_local_double_buffering(sdfg) + # Phase 5: Apply blocking if blocking_dim is not None: sdfg.apply_transformations_once_everywhere( @@ -342,9 +263,23 @@ def gt_auto_optimize( dace_aoptimize.set_fast_implementations(sdfg, device) # TODO(phimuell): Fix the bug, it uses the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) + + # Now we modify the strides. + gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) + if make_persistent: - # TODO(phimuell): Allow to also to set the lifetime to `SDFG`. - dace_aoptimize.make_transients_persistent(sdfg, device) + gtx_transformations.gt_make_transients_persistent(sdfg=sdfg, device=device) + + if device == dace.DeviceType.GPU: + # NOTE: For unknown reasons the counterpart of the + # `gt_make_transients_persistent()` function in DaCe, resets the + # `wcr_nonatomic` property of every memlet, i.e. makes it atomic. 
+ # However, it does this only for edges on the top level and on GPU. + # For compatibility with DaCe (and until we found out why) the GT4Py + # auto optimizer will emulate this behaviour. + for state in sdfg.states(): + for edge in state.edges(): + edge.data.wcr_nonatomic = False return sdfg @@ -395,9 +330,17 @@ def gt_auto_fuse_top_level_maps( # TODO(phimuell): Add parallel fusion transformation. Should it run after # or with the serial one? sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + # This will lead to the creation of big probably unrelated maps. + # However, it might be good. + only_if_common_ancestor=False, + ), + ], validate=validate, validate_all=validate_all, ) @@ -437,7 +380,7 @@ def gt_auto_fuse_top_level_maps( # The SDFG was modified by the transformations above. The SDFG was # modified. Call Simplify and try again to further optimize. 
- gt_simplify(sdfg, validate=validate, validate_all=validate_all) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) else: raise RuntimeWarning("Optimization of the SDFG did not converge.") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 16c9600a3a..2cd3020180 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -11,10 +11,15 @@ from __future__ import annotations import copy -from typing import Any, Optional, Sequence, Union +from typing import Any, Callable, Final, Optional, Sequence, Union import dace -from dace import properties as dace_properties, transformation as dace_transformation +from dace import ( + dtypes as dace_dtypes, + properties as dace_properties, + transformation as dace_transformation, +) +from dace.codegen.targets import cpp as dace_cpp from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -51,7 +56,9 @@ def gt_gpu_transformation( will avoid the data copy from host to GPU memory. gpu_block_size: The size of a thread block on the GPU. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. + Will only take effect if `gpu_block_size` is specified. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` + Will only take effect if `gpu_block_size` is specified. validate: Perform validation during the steps. validate_all: Perform extensive validation. 
@@ -82,39 +89,197 @@ def gt_gpu_transformation( validate_all=validate_all, simplify=False, ) + # The documentation recommends to run simplify afterwards gtx_transformations.gt_simplify(sdfg) if try_removing_trivial_maps: - # A Tasklet, outside of a Map, that writes into an array on GPU can not work - # `sdfg.appyl_gpu_transformations()` puts Map around it (if said Tasklet - # would write into a Scalar that then goes into a GPU Map, nothing would - # happen. So we might end up with lot of these trivial Maps, that results - # in a single kernel launch. To prevent this we will try to fuse them. - # NOTE: The current implementation has a bug, because promotion and fusion - # are two different steps. Because of this the function will implicitly - # fuse everything together it can find. - # TODO(phimuell): Fix the issue described above. + # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on + # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # we might end up with lots of these trivial Maps, each requiring a separate + # kernel launch. To prevent this we will combine these trivial maps, if + # possible, with their downstream maps. sdfg.apply_transformations_once_everywhere( - TrivialGPUMapPromoter(), + TrivialGPUMapElimination(), validate=False, validate_all=False, ) - sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), - validate=validate, - validate_all=validate_all, - ) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) + + # TODO(phimuell): Fixing the stride problem. + sdfg = gt_gpu_transform_non_standard_memlet( + sdfg=sdfg, + map_postprocess=True, + validate=validate, + validate_all=validate_all, + ) # Set the GPU block size if it is known. 
if gpu_block_size is not None: gt_set_gpu_blocksize( sdfg=sdfg, - gpu_block_size=gpu_block_size, - gpu_launch_bounds=gpu_launch_bounds, - gpu_launch_factor=gpu_launch_factor, + block_size=gpu_block_size, + launch_bounds=gpu_launch_bounds, + launch_factor=gpu_launch_factor, + ) + + if validate_all or validate: + sdfg.validate() + + return sdfg + + +def gt_gpu_transform_non_standard_memlet( + sdfg: dace.SDFG, + map_postprocess: bool, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Transform some non standard Melets to Maps. + + The GPU code generator is not able to handle certain sets of Memlets. To + handle them, the code generator transforms them into copy Maps. The main + issue is that this transformation happens after the auto optimizer, thus + the copy-Maps will most likely have the wrong iteration order. + + This function allows to perform the preprocessing step before the actual + code generation. The function will perform the expansion. If + `map_postprocess` is `True` then the function will also apply MapFusion, + to these newly created copy-Maps and set their iteration order correctly. + + A user should not call this function directly, instead this function is + called by the `gt_gpu_transformation()` function. + + Args: + sdfg: The SDFG that we process. + map_postprocess: Enable post processing of the maps that are created. + See the Note section below. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + + Note: + - Currently the function applies some crude heuristic to determine the + correct loop order. + - This function should be called after `gt_set_iteration_order()` has run. + """ + new_maps: set[dace_nodes.MapEntry] = set() + + # This code is is copied from DaCe's code generator. 
+ for e, state in list(sdfg.all_edges_recursive()): + nsdfg = state.parent + if ( + isinstance(e.src, dace_nodes.AccessNode) + and isinstance(e.dst, dace_nodes.AccessNode) + and e.src.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global + and e.dst.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global + ): + a: dace_nodes.AccessNode = e.src + b: dace_nodes.AccessNode = e.dst + + copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides( + None, nsdfg, state, e, a, b + ) + dims = len(copy_shape) + if dims == 1: + continue + elif dims == 2: + if src_strides[-1] != 1 or dst_strides[-1] != 1: + try: + is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] + is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] + except (TypeError, ValueError): + is_src_cont = False + is_dst_cont = False + if is_src_cont and is_dst_cont: + continue + else: + continue + elif dims > 2: + if not (src_strides[-1] != 1 or dst_strides[-1] != 1): + continue + + # For identifying the new map, we first store all neighbors of `a`. + old_neighbors_of_a: list[dace_nodes.AccessNode] = [ + edge.dst for edge in state.out_edges(a) + ] + + # Turn unsupported copy to a map + try: + dace_transformation.dataflow.CopyToMap.apply_to( + nsdfg, save=False, annotate=False, a=a, b=b + ) + except ValueError: # If transformation doesn't match, continue normally + continue + + # We find the new map by comparing the new neighborhood of `a` with the old one. + new_nodes: set[dace_nodes.MapEntry] = { + edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a + } + assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes) + assert len(new_nodes) == 1 + new_maps.update(new_nodes) + + # If there are no Memlets that are translated to copy-Maps, then we have nothing to do. + if len(new_maps) == 0: + return sdfg + + # This function allows to restrict any fusion operation to the maps + # that we have just created. 
+ def restrict_fusion_to_newly_created_maps( + self: gtx_transformations.map_fusion_helper.MapFusionHelper, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool, + ) -> bool: + return any(new_entry in new_maps for new_entry in [map_entry_1, map_entry_2]) + + # Using the callback to restrict the fusing + sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + ], + validate=validate, + validate_all=validate_all, + ) + + # Now we have to find the maps that were not fused. We rely here on the fact + # that at least one of the map that is involved in fusing still exists. + maps_to_modify: set[dace_nodes.MapEntry] = set() + for nsdfg in sdfg.all_sdfgs_recursive(): + for state in nsdfg.states(): + for map_entry in state.nodes(): + if not isinstance(map_entry, dace_nodes.MapEntry): + continue + if map_entry in new_maps: + maps_to_modify.add(map_entry) + assert 0 < len(maps_to_modify) <= len(new_maps) + + # This is a gross hack, but it is needed, for the following reasons: + # - The transients have C order while the non-transients have (most + # likely) FORTRAN order. So there is not an unique stride dimension. + # - The newly created maps have names that does not reflect GT4Py dimensions, + # thus we can not use `gt_set_iteration_order()`. + # For these reasons we do the simplest thing, which is assuming that the maps + # are created in C order and we must make them in FORTRAN order, which means + # just swapping the order of the map parameters. + # TODO(phimuell): Do it properly. 
+ for me_to_modify in maps_to_modify: + map_to_modify: dace_nodes.Map = me_to_modify.map + map_to_modify.params = list(reversed(map_to_modify.params)) + map_to_modify.range = dace.subsets.Range( + (r1, r2, r3, t) + for (r1, r2, r3), t in zip( + reversed(map_to_modify.range.ranges), reversed(map_to_modify.range.tile_sizes) + ) ) return sdfg @@ -122,131 +287,214 @@ def gt_gpu_transformation( def gt_set_gpu_blocksize( sdfg: dace.SDFG, - gpu_block_size: Optional[Sequence[int | str] | str], - gpu_launch_bounds: Optional[int | str] = None, - gpu_launch_factor: Optional[int] = None, + block_size: Optional[Sequence[int | str] | str], + launch_bounds: Optional[int | str] = None, + launch_factor: Optional[int] = None, + **kwargs: Any, ) -> Any: """Set the block size related properties of _all_ Maps. - See `GPUSetBlockSize` for more information. + It supports the same arguments as `GPUSetBlockSize`, however it also has + versions without `_Xd`, these are used as default for the other maps. + If a version with `_Xd` is specified then it takes precedence. Args: sdfg: The SDFG to process. - gpu_block_size: The size of a thread block on the GPU. + block_size: The size of a thread block on the GPU. launch_bounds: The value for the launch bound that should be used. launch_factor: If no `launch_bounds` was given use the number of threads in a block multiplied by this number. 
""" - xform = GPUSetBlockSize( - block_size=gpu_block_size, - launch_bounds=gpu_launch_bounds, - launch_factor=gpu_launch_factor, - ) - return sdfg.apply_transformations_once_everywhere([xform]) - - -def _gpu_block_parser( - self: GPUSetBlockSize, - val: Any, -) -> None: - """Used by the setter of `GPUSetBlockSize.block_size`.""" - org_val = val - if isinstance(val, (tuple | list)): - pass - elif isinstance(val, str): - val = tuple(x.strip() for x in val.split(",")) - elif isinstance(val, int): - val = (val,) - else: - raise TypeError( - f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." - ) - if 0 < len(val) <= 3: - val = [*val, *([1] * (3 - len(val)))] - else: - raise ValueError(f"Can not parse block size '{org_val}': wrong length") - try: - val = [int(x) for x in val] - except ValueError: - raise TypeError( - f"Currently only block sizes convertible to int are supported, you passed '{val}'." - ) from None - self._block_size = val - + for dim in [1, 2, 3]: + for arg, val in { + "block_size": block_size, + "launch_bounds": launch_bounds, + "launch_factor": launch_factor, + }.items(): + if f"{arg}_{dim}d" not in kwargs: + kwargs[f"{arg}_{dim}d"] = val + return sdfg.apply_transformations_once_everywhere(GPUSetBlockSize(**kwargs)) + + +def _make_gpu_block_parser_for( + dim: int, +) -> Callable[["GPUSetBlockSize", Any], None]: + """Generates a parser for GPU blocks for dimension `dim`. + + The returned function can be used as parser for the `GPUSetBlockSize.block_size_*d` + properties. 
+ """ -def _gpu_block_getter( - self: "GPUSetBlockSize", -) -> tuple[int, int, int]: - """Used as getter in the `GPUSetBlockSize.block_size` property.""" - assert isinstance(self._block_size, (tuple, list)) and len(self._block_size) == 3 - assert all(isinstance(x, int) for x in self._block_size) - return tuple(self._block_size) + def _gpu_block_parser( + self: GPUSetBlockSize, + val: Any, + ) -> None: + """Used by the setter of `GPUSetBlockSize.block_size`.""" + org_val = val + if isinstance(val, (tuple | list)): + pass + elif isinstance(val, str): + val = tuple(x.strip() for x in val.split(",")) + elif isinstance(val, int): + val = (val,) + else: + raise TypeError( + f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." + ) + if len(val) < dim: + raise ValueError( + f"The passed block size only covers {len(val)} dimensions, but dimension was {dim}." + ) + if 0 < len(val) <= 3: + val = [*val, *([1] * (3 - len(val)))] + else: + raise ValueError(f"Can not parse block size '{org_val}': wrong length") + try: + val = [int(x) for x in val] + except ValueError: + raise TypeError( + f"Currently only block sizes convertible to int are supported, you passed '{val}'." + ) from None + + # Remove over specification. 
+ for i in range(dim, 3): + val[i] = 1 + setattr(self, f"_block_size_{dim}d", tuple(val)) + + return _gpu_block_parser + + +def _make_gpu_block_getter_for( + dim: int, +) -> Callable[["GPUSetBlockSize"], tuple[int, int, int]]: + """Makes the getter for the block size of dimension `dim`.""" + + def _gpu_block_getter( + self: "GPUSetBlockSize", + ) -> tuple[int, int, int]: + """Used as getter in the `GPUSetBlockSize.block_size` property.""" + return getattr(self, f"_block_size_{dim}d") + + return _gpu_block_getter + + +def _gpu_launch_bound_parser( + block_size: tuple[int, int, int], + launch_bounds: int | str | None, + launch_factor: int | None = None, +) -> str | None: + """Used by the `GPUSetBlockSize.__init__()` method to parse the launch bounds.""" + if launch_bounds is None and launch_factor is None: + return None + elif launch_bounds is None and launch_factor is not None: + return str(int(launch_factor) * block_size[0] * block_size[1] * block_size[2]) + elif launch_bounds is not None and launch_factor is None: + assert isinstance(launch_bounds, (str, int)) + return str(launch_bounds) + else: + raise ValueError("Specified both `launch_bounds` and `launch_factor`.") @dace_properties.make_properties class GPUSetBlockSize(dace_transformation.SingleStateTransformation): """Sets the GPU block size on GPU Maps. - The transformation will apply to all Maps that have a GPU schedule, regardless - of their dimensionality. + The `block_size` is either a sequence, of up to three integers or a string + of up to three numbers, separated by comma (`,`). The first number is the size + of the block in `x` direction, the second for the `y` direction and the third + for the `z` direction. Missing values will be filled with `1`. - The `gpu_block_size` is either a sequence, of up to three integers or a string - of up to three numbers, separated by comma (`,`). 
- The first number is the size of the block in `x` direction, the second for the - `y` direction and the third for the `z` direction. Missing values will be filled - with `1`. + A different value for the GPU block size and launch bound can be specified for + maps of dimension 1, 2 or 3 (all maps with higher dimensions are considered + three dimensional). If no value is specified then the block size `(32, 1, 1)` + will be used an no launch bound will be be emitted. Args: - block_size: The size of a thread block on the GPU. - launch_bounds: The value for the launch bound that should be used. - launch_factor: If no `launch_bounds` was given use the number of threads - in a block multiplied by this number. + block_size_Xd: The size of a thread block on the GPU for `X` dimensional maps. + launch_bounds_Xd: The value for the launch bound that should be used for `X` + dimensional maps. + launch_factor_Xd: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number, for maps of dimension `X`. - Todo: - Add the possibility to specify other bounds for 1, 2, or 3 dimensional maps. + Note: + - You should use the `gt_set_gpu_blocksize()` function. + - "Over specification" is ignored, i.e. if `(32, 3, 1)` is passed as block + size for 1 dimensional maps, then it is changed to `(32, 1, 1)`. 
""" - block_size = dace_properties.Property( - dtype=None, - allow_none=False, - default=(32, 1, 1), - setter=_gpu_block_parser, - getter=_gpu_block_getter, - desc="Size of the block size a GPU Map should have.", - ) + _block_size_default: Final[tuple[int, int, int]] = (32, 1, 1) - launch_bounds = dace_properties.Property( + block_size_1d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(1), + getter=_make_gpu_block_getter_for(1), + desc="Block size for 1 dimensional GPU maps.", + ) + launch_bounds_1d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 1 dimensional map.", + ) + block_size_2d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(2), + getter=_make_gpu_block_getter_for(2), + desc="Block size for 2 dimensional GPU maps.", + ) + launch_bounds_2d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 2 dimensional map.", + ) + block_size_3d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(3), + getter=_make_gpu_block_getter_for(3), + desc="Block size for 3 dimensional GPU maps.", + ) + launch_bounds_3d = dace_properties.Property( dtype=str, allow_none=True, default=None, - desc="Set the launch bound property of the map.", + desc="Set the launch bound property for 3 dimensional map.", ) - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + # Pattern matching + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, - block_size: Sequence[int | str] | str | None = None, - launch_bounds: int | str | None = None, - launch_factor: int | None = None, + block_size_1d: Sequence[int | str] | str | None = None, + block_size_2d: Sequence[int | str] | str | None 
= None, + block_size_3d: Sequence[int | str] | str | None = None, + launch_bounds_1d: int | str | None = None, + launch_bounds_2d: int | str | None = None, + launch_bounds_3d: int | str | None = None, + launch_factor_1d: int | None = None, + launch_factor_2d: int | None = None, + launch_factor_3d: int | None = None, ) -> None: super().__init__() - if block_size is not None: - self.block_size = block_size - - if launch_factor is not None: - assert launch_bounds is None - self.launch_bounds = str( - int(launch_factor) * self.block_size[0] * self.block_size[1] * self.block_size[2] - ) - elif launch_bounds is None: - self.launch_bounds = None - elif isinstance(launch_bounds, (str, int)): - self.launch_bounds = str(launch_bounds) - else: - raise TypeError( - f"Does not know how to parse '{launch_bounds}' as 'launch_bounds' argument." - ) + if block_size_1d is not None: + self.block_size_1d = block_size_1d + if block_size_2d is not None: + self.block_size_2d = block_size_2d + if block_size_3d is not None: + self.block_size_3d = block_size_3d + self.launch_bounds_1d = _gpu_launch_bound_parser( + self.block_size_1d, launch_bounds_1d, launch_factor_1d + ) + self.launch_bounds_2d = _gpu_launch_bound_parser( + self.block_size_2d, launch_bounds_2d, launch_factor_2d + ) + self.launch_bounds_3d = _gpu_launch_bound_parser( + self.block_size_3d, launch_bounds_3d, launch_factor_3d + ) @classmethod def expressions(cls) -> Any: @@ -266,7 +514,6 @@ def can_be_applied( - If the map is at global scope. - If if the schedule of the map is correct. 
""" - scope = graph.scope_dict() if scope[self.map_entry] is not None: return False @@ -282,35 +529,69 @@ def apply( sdfg: dace.SDFG, ) -> None: """Modify the map as requested.""" - self.map_entry.map.gpu_block_size = self.block_size - if self.launch_bounds is not None: # Note empty string has a meaning in DaCe - self.map_entry.map.gpu_launch_bounds = self.launch_bounds + gpu_map: dace_nodes.Map = self.map_entry.map + if len(gpu_map.params) == 1: + block_size = self.block_size_1d + launch_bounds = self.launch_bounds_1d + elif len(gpu_map.params) == 2: + block_size = self.block_size_2d + launch_bounds = self.launch_bounds_2d + else: + block_size = self.block_size_3d + launch_bounds = self.launch_bounds_3d + gpu_map.gpu_block_size = block_size + if launch_bounds is not None: # Note: empty string has a meaning in DaCe + gpu_map.gpu_launch_bounds = launch_bounds @dace_properties.make_properties -class TrivialGPUMapPromoter(dace_transformation.SingleStateTransformation): - """Serial Map promoter for empty GPU maps. +class TrivialGPUMapElimination(dace_transformation.SingleStateTransformation): + """Eliminate certain kind of trivial GPU maps. - In CPU mode a Tasklet can be outside of a map, however, this is not - possible in GPU mode. For this reason DaCe wraps such Tasklets in a - trivial Map. - This transformation will look for such Maps and promote them, such - that they can be fused with downstream maps. + A tasklet outside of map can not write to GPU memory, this can only be done + from within a map (a scalar is possible). For that reason DaCe's GPU + transformation wraps such tasklets in trivial maps. + Under certain condition the transformation will fuse the trivial tasklet with + a downstream (serial) map. + + Args: + do_not_fuse: If `True` then the maps are not fused together. + only_gpu_maps: Only apply to GPU maps; `True` by default. Note: - This transformation should not be run on its own, instead it is run within the context of `gt_gpu_transformation()`. 
- This transformation must be run after the GPU Transformation. - - Currently the transformation does not do the fusion on its own. - Instead map fusion must be run afterwards. - - The transformation assumes that the upper Map is a trivial Tasklet. - Which should be the majority of all cases. """ + only_gpu_maps = dace_properties.Property( + dtype=bool, + default=True, + desc="Only promote maps that are GPU maps (debug option).", + ) + do_not_fuse = dace_properties.Property( + dtype=bool, + default=False, + desc="Only perform the promotion, do not fuse.", + ) + # Pattern Matching - trivial_map_exit = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - second_map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + trivial_map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + second_map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + do_not_fuse: Optional[bool] = None, + only_gpu_maps: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if only_gpu_maps is not None: + self.only_gpu_maps = only_gpu_maps + if do_not_fuse is not None: + self.do_not_fuse = do_not_fuse @classmethod def expressions(cls) -> Any: @@ -332,63 +613,118 @@ def can_be_applied( The tests includes: - Schedule of the maps. - If the map is trivial. - - If the trivial map was not used to define a symbol. - - Intermediate access node can only have in and out degree of 1. - - The trivial map exit can only have one output. + - Tests if the maps can be fused. 
""" trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit trivial_map: dace_nodes.Map = trivial_map_exit.map trivial_map_entry: dace_nodes.MapEntry = graph.entry_node(trivial_map_exit) second_map: dace_nodes.Map = self.second_map_entry.map - access_node: dace_nodes.AccessNode = self.access_node # The kind of maps we are interested only have one parameter. if len(trivial_map.params) != 1: return False - - # Check if it is a GPU map - for map_to_check in [trivial_map, second_map]: - if map_to_check.schedule not in [ - dace.dtypes.ScheduleType.GPU_Device, - dace.dtypes.ScheduleType.GPU_Default, - ]: - return False - - # Check if the map is trivial. for rng in trivial_map.range.ranges: if rng[0] != rng[1]: return False - # Now we have to ensure that the symbol is not used inside the scope of the - # map, if it is, then the symbol is just there to define a symbol. - scope_view = graph.scope_subgraph( - trivial_map_entry, - include_entry=False, - include_exit=False, - ) - if any(map_param in scope_view.free_symbols for map_param in trivial_map.params): - return False + # If we do not not fuse, then the second map can not be trivial. + # If we would not prevent that case then we would match these two + # maps again and again. + if self.do_not_fuse and len(second_map.params) <= 1: + for rng in second_map.range.ranges: + if rng[0] == rng[1]: + return False + + # We now check that the Memlets do not depend on the map parameter. + # This is important for the `can_be_applied_to()` check we do below + # because we can avoid calling the replace function. + scope = graph.scope_subgraph(trivial_map_entry) + trivial_map_param: str = trivial_map.params[0] + for edge in scope.edges(): + if trivial_map_param in edge.data.free_symbols: + return False - # ensuring that the trivial map exit and the intermediate node have degree - # one is a cheap way to ensure that the map can be merged into the - # second map. 
- if graph.in_degree(access_node) != 1: - return False - if graph.out_degree(access_node) != 1: - return False - if graph.out_degree(trivial_map_exit) != 1: - return False + # Check if only GPU maps are involved (this is more a testing debug feature). + if self.only_gpu_maps: + for map_to_check in [trivial_map, second_map]: + if map_to_check.schedule not in [ + dace.dtypes.ScheduleType.GPU_Device, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + + # Now we check if the two maps can be fused together. For that we have to + # do a temporary promotion, it is important that we do not perform the + # renaming. If the old symbol is still used, it is used inside a tasklet + # so it would show up (temporarily) as free symbol. + org_trivial_map_params = copy.deepcopy(trivial_map.params) + org_trivial_map_range = copy.deepcopy(trivial_map.range) + try: + self._promote_map(graph, replace_trivail_map_parameter=False) + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + map_exit_1=trivial_map_exit, + intermediate_access_node=self.access_node, + map_entry_2=self.second_map_entry, + ): + return False + finally: + trivial_map.params = org_trivial_map_params + trivial_map.range = org_trivial_map_range return True def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the Map Promoting. - The function essentially copies the parameters and the ranges from the - bottom map to the top one. + The function will first perform the promotion of the trivial map and then + perform the merging of the two maps in one go. """ + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit + second_map_entry: dace_nodes.MapEntry = self.second_map_entry + access_node: dace_nodes.AccessNode = self.access_node + + # Promote the maps. + self._promote_map(graph) + + # Perform the fusing if requested. 
+ if not self.do_not_fuse: + gtx_transformations.MapFusionSerial.apply_to( + sdfg=sdfg, + map_exit_1=trivial_map_exit, + intermediate_access_node=access_node, + map_entry_2=second_map_entry, + verify=True, + ) + + def _promote_map( + self, + state: dace.SDFGState, + replace_trivail_map_parameter: bool = True, + ) -> None: + """Performs the map promoting. + + Essentially this function will copy the parameters and the range from + the non trivial map (`self.second_map_entry.map`) to the trivial map + (`self.trivial_map_exit.map`). + + If `replace_trivail_map_parameter` is `True` (the default value), then the + function will also remove the trivial map parameter with its value. + """ + assert isinstance(self.trivial_map_exit, dace_nodes.MapExit) + assert isinstance(self.second_map_entry, dace_nodes.MapEntry) + assert isinstance(self.access_node, dace_nodes.AccessNode) + + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit trivial_map: dace_nodes.Map = self.trivial_map_exit.map + trivial_map_entry: dace_nodes.MapEntry = state.entry_node(trivial_map_exit) second_map: dace_nodes.Map = self.second_map_entry.map + # If requested then replace the map variable with its value. + if replace_trivail_map_parameter: + scope = state.scope_subgraph(trivial_map_entry) + scope.replace(trivial_map.params[0], trivial_map.range[0][0]) + + # Now copy parameter and the ranges from the second to the trivial map. 
trivial_map.params = copy.deepcopy(second_map.params) trivial_map.range = copy.deepcopy(second_map.range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py new file mode 100644 index 0000000000..52f1de3d0c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py @@ -0,0 +1,393 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +import copy + +import dace +from dace import ( + data as dace_data, + dtypes as dace_dtypes, + symbolic as dace_symbolic, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def gt_create_local_double_buffering( + sdfg: dace.SDFG, +) -> int: + """Modifies the SDFG such that point wise data dependencies are stable. + + Rule 3 of the ADR18, guarantees that if data is input and output to a map, + then it must be a non transient array and it must only have point wise + dependency. This means that every index that is read is also written by + the same thread and no other thread reads or writes to the same location. + However, because the dataflow inside a map is partially asynchronous + it might happen if something is read multiple times, i.e. Tasklets, + the data might already be overwritten. + This function will scan the SDFG for potential cases and insert an + access node to cache this read. This is essentially a double buffer, but + it is not needed that the whole data is stored, but only the working set + of a single thread. 
+ """ + + processed_maps = 0 + for nsdfg in sdfg.all_sdfgs_recursive(): + processed_maps += _create_local_double_buffering_non_recursive(nsdfg) + return processed_maps + + +def _create_local_double_buffering_non_recursive( + sdfg: dace.SDFG, +) -> int: + """Implementation of the point wise transformation. + + This function does not handle nested SDFGs. + """ + # First we call `EdgeConsolidation`, because of that we know that + # every incoming edge of a `MapEntry` refers to distinct data. + # We do this to simplify our implementation. + edge_consolidation = dace_transformation.passes.ConsolidateEdges() + edge_consolidation.apply_pass(sdfg, None) + + processed_maps = 0 + for state in sdfg.states(): + scope_dict = state.scope_dict() + for node in state.nodes(): + if not isinstance(node, dace_nodes.MapEntry): + continue + if scope_dict[node] is not None: + continue + inout_nodes = _check_if_map_must_be_handled( + map_entry=node, + state=state, + sdfg=sdfg, + ) + if inout_nodes is not None: + processed_maps += _add_local_double_buffering_to( + map_entry=node, + inout_nodes=inout_nodes, + state=state, + sdfg=sdfg, + ) + return processed_maps + + +def _add_local_double_buffering_to( + inout_nodes: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]], + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> int: + """Adds the double buffering to `map_entry` for `inout_nodes`. + + The function assumes that there is only in incoming edge per data + descriptor at the map entry. If the data is needed multiple times, + then the distribution must be done inside the map. + + The function will now channel all reads to the data descriptor + through an access node, this ensures that the read happens + before the write. 
+ """ + processed_maps = 0 + for inout_node in inout_nodes.values(): + _add_local_double_buffering_to_single_data( + map_entry=map_entry, + inout_node=inout_node, + state=state, + sdfg=sdfg, + ) + processed_maps += 1 + return processed_maps + + +def _add_local_double_buffering_to_single_data( + inout_node: tuple[dace_nodes.AccessNode, dace_nodes.AccessNode], + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None: + """Adds the local double buffering for a single data.""" + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + input_node, output_node = inout_node + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert len(input_edges) == 1 + assert len(output_edges) == 1 + inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + inner_write_edges = _get_inner_edges(output_edges[0], map_exit, state, True) + + # For now we assume that all read the same, which is checked below. + new_double_inner_buff_shape_raw = dace_symbolic.overapproximate( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state).size() + ) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + squeezed_dims: list[int] = [] # These are the dimensions we removed. + new_double_inner_buff_shape: list[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_double_inner_buff_shape_raw, input_node.desc(sdfg).shape) + ): + if full_dim_size == 1: # Must be kept! + new_double_inner_buff_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. 
+ squeezed_dims.append(dim) + else: + new_double_inner_buff_shape.append(proposed_dim_size) + + new_double_inner_buff_name: str = f"__inner_double_buffer_for_{input_node.data}" + # Now generate the intermediate data container. + if len(new_double_inner_buff_shape) == 0: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_scalar( + new_double_inner_buff_name, + dtype=input_node.desc(sdfg).dtype, + transient=True, + storage=dace_dtypes.StorageType.Register, + find_new_name=True, + ) + else: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_transient( + new_double_inner_buff_name, + shape=new_double_inner_buff_shape, + dtype=input_node.desc(sdfg).dtype, + find_new_name=True, + storage=dace_dtypes.StorageType.Register, + ) + new_double_inner_buff_node = state.add_access(new_double_inner_buff_name) + + # Now reroute the data flow through the new access node. + for old_inner_read_edge in inner_read_edges: + # To do handle the case the memlet is "fancy" + state.add_edge( + new_double_inner_buff_node, + None, + old_inner_read_edge.dst, + old_inner_read_edge.dst_conn, + dace.Memlet( + data=new_double_inner_buff_name, + subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + other_subset=copy.deepcopy( + old_inner_read_edge.data.get_dst_subset(old_inner_read_edge, state) + ), + ), + ) + state.remove_edge(old_inner_read_edge) + + # Now create a connection between the map entry and the intermediate node. + state.add_edge( + map_entry, + inner_read_edges[0].src_conn, + new_double_inner_buff_node, + None, + dace.Memlet( + data=input_node.data, + subset=copy.deepcopy( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state) + ), + other_subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + ), + ) + + # To really ensure that a read happens before a write, we have to sequence + # the read first. 
We do this by connecting the double buffer node with + # empty Memlets to the last row of nodes that writes to the global buffer. + # This is needed to handle the case that some other data path performs the + # write. + # TODO(phimuell): Add a test that only performs this when it is really needed. + for inner_write_edge in inner_write_edges: + state.add_nedge( + new_double_inner_buff_node, + inner_write_edge.src, + dace.Memlet(), + ) + + +def _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node: dace_nodes.AccessNode, + sdfg: dace.SDFG, + known_nodes: dict[str, dace_nodes.AccessNode], +) -> bool: + """Internal function used by `_check_if_map_must_be_handled()` to classify nodes. + + If the function returns `True` it means that the input/output, does not + violates an internal constraint, i.e. can be handled by + `_ensure_that_map_is_pointwise()`. If appropriate the function will add the + node to `known_nodes`. I.e. in case of a transient the function will return + `True` but will not add it to `known_nodes`. + """ + + # This case is indicating that the `ConsolidateEdges` has not fully worked. + # Currently the transformation implementation assumes that this is the + # case, so we can not handle this case. + # TODO(phimuell): Implement this case. + if data_node.data in known_nodes: + return False + data_desc: dace_data.Data = data_node.desc(sdfg) + + # The conflict can only occur for global data, because transients + # are only written once. + if data_desc.transient: + return False + + # Currently we do not handle view, as they need to be traced. + # TODO(phimuell): Implement + if gtx_transformations.util.is_view(data_desc, sdfg): + return False + + # TODO(phimuell): Check if there is a access node on the inner side, then we do not have to do it. + + # Now add the node to the list. 
+ assert all(data_node is not known_node for known_node in known_nodes.values()) + known_nodes[data_node.data] = data_node + return True + + +def _get_inner_edges( + outer_edge: dace.sdfg.graph.MultiConnectorEdge, + scope_node: dace_nodes.MapExit | dace_nodes.MapEntry, + state: dace.SDFG, + outgoing_edge: bool, +) -> list[dace.sdfg.graph.MultiConnectorEdge]: + """Gets the edges on the inside of a map.""" + if outgoing_edge: + assert isinstance(scope_node, dace_nodes.MapExit) + conn_name = outer_edge.src_conn[4:] + return list(state.in_edges_by_connector(scope_node, connector="IN_" + conn_name)) + else: + assert isinstance(scope_node, dace_nodes.MapEntry) + conn_name = outer_edge.dst_conn[3:] + return list(state.out_edges_by_connector(scope_node, connector="OUT_" + conn_name)) + + +def _check_if_map_must_be_handled( + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None | dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]]: + """Check if the map should be processed to uphold rule 3. + + Essentially the function will check if there is a potential read-write + conflict. The function assumes that `ConsolidateEdges` has already run. + + If there is a possible data race the function will return a `dict`, that + maps the name of the data to the access nodes that are used as input and + output to the Map. + + Otherwise, the function returns `None`. It is, however, important that + `None` does not means that there is no possible race condition. It could + also means that the function that implements the buffering, i.e. + `_ensure_that_map_is_pointwise()`, is unable to handle this case. + + Todo: + Improve the function + """ + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + + # Find all the data that is accessed. Views are resolved. + input_datas: dict[str, dace_nodes.AccessNode] = {} + output_datas: dict[str, dace_nodes.AccessNode] = {} + + # Determine which nodes are possible conflicting. 
+ for in_edge in state.in_edges(map_entry): + if in_edge.data.is_empty(): + continue + if not isinstance(in_edge.src, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if in_edge.dst_conn and not in_edge.dst_conn.startswith("IN_"): + # TODO(phimuell): It is very unlikely that a Dynamic Map Range causes + # this particular data race, so we ignore it for the time being. + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=in_edge.src, + sdfg=sdfg, + known_nodes=input_datas, + ): + continue + for out_edge in state.out_edges(map_exit): + if out_edge.data.is_empty(): + continue + if not isinstance(out_edge.dst, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=out_edge.dst, + sdfg=sdfg, + known_nodes=output_datas, + ): + continue + + # Double buffering is only needed if there inout arguments. + inout_datas: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]] = { + dname: (input_datas[dname], output_datas[dname]) + for dname in input_datas + if dname in output_datas + } + if len(inout_datas) == 0: + return None + + # TODO(phimuell): What about the case that some data descriptor needs double + # buffering, but some do not? + for inout_data_name in list(inout_datas.keys()): + input_node, output_node = inout_datas[inout_data_name] + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert ( + len(input_edges) == 1 + ), f"Expected a single connection between input node and map entry, but found {len(input_edges)}." + assert ( + len(output_edges) == 1 + ), f"Expected a single connection between map exit and write back node, but found {len(output_edges)}." + + # If there is only one edge on the inside of the map, that goes into an + # AccessNode, then we assume it is double buffered. 
+ inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + if ( + len(inner_read_edges) == 1 + and isinstance(inner_read_edges[0].dst, dace_nodes.AccessNode) + and not gtx_transformations.util.is_view(inner_read_edges[0].dst, sdfg) + ): + inout_datas.pop(inout_data_name) + continue + + inner_read_subsets = [ + inner_read_edge.data.get_src_subset(inner_read_edge, state) + for inner_read_edge in inner_read_edges + ] + assert all(inner_read_subset is not None for inner_read_subset in inner_read_subsets) + inner_write_subsets = [ + inner_write_edge.data.get_dst_subset(inner_write_edge, state) + for inner_write_edge in _get_inner_edges(output_edges[0], map_exit, state, True) + ] + # TODO(phimuell): Also implement a check that the volume equals the size of the subset. + assert all(inner_write_subset is not None for inner_write_subset in inner_write_subsets) + + # For being point wise the subsets must be compatible. The correct check would be: + # - The write sets are unique. + # - For every read subset there exists one matching write subset. It could + # be that there are many equivalent read subsets. + # - For every write subset there exists at least one matching read subset. + # The current implementation only checks if all are the same. + # TODO(phimuell): Implement the real check. 
+ all_inner_subsets = inner_read_subsets + inner_write_subsets + if not all( + all_inner_subsets[0] == all_inner_subsets[i] for i in range(1, len(all_inner_subsets)) + ): + return None + + if len(inout_datas) == 0: + return None + + return inout_datas diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index d7326e1131..d401c06f15 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -59,20 +59,11 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): desc="Name of the iteration variable on which to block (must be an exact match);" " 'I' in the above description.", ) - independent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - desc="Set of nodes that are independent of the blocking parameter.", - ) - dependent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - desc="Set of nodes that are dependent on the blocking parameter.", - ) + # Set of nodes that are independent of the blocking parameter. 
+ _independent_nodes: Optional[set[dace_nodes.AccessNode]] + _dependent_nodes: Optional[set[dace_nodes.AccessNode]] - outer_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + outer_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, @@ -86,6 +77,8 @@ def __init__( self.blocking_parameter = blocking_parameter if blocking_size is not None: self.blocking_size = blocking_size + self._independent_nodes = None + self._dependent_nodes = None @classmethod def expressions(cls) -> Any: @@ -125,6 +118,8 @@ def can_be_applied( return False if not self.partition_map_output(graph, sdfg): return False + self._independent_nodes = None + self._dependent_nodes = None return True @@ -137,7 +132,6 @@ def apply( Performs the operation described in the doc string. """ - # Now compute the partitions of the nodes. self.partition_map_output(graph, sdfg) @@ -153,10 +147,8 @@ def apply( state=graph, sdfg=sdfg, ) - - # Clear the old partitions - self.independent_nodes = None - self.dependent_nodes = None + self._independent_nodes = None + self._dependent_nodes = None def _prepare_inner_outer_maps( self, @@ -269,8 +261,8 @@ def partition_map_output( """ # Clear the previous partition. - self.independent_nodes = set() - self.dependent_nodes = None + self._independent_nodes = set() + self._dependent_nodes = None while True: # Find all the nodes that we have to classify in this iteration. 
@@ -279,9 +271,9 @@ def partition_map_output( nodes_to_classify: set[dace_nodes.Node] = { edge.dst for edge in state.out_edges(self.outer_entry) } - for independent_node in self.independent_nodes: + for independent_node in self._independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) - nodes_to_classify.difference_update(self.independent_nodes) + nodes_to_classify.difference_update(self._independent_nodes) # Now classify each node found_new_independent_node = False @@ -294,7 +286,7 @@ def partition_map_output( # Check if the partition exists. if class_res is None: - self.independent_nodes = None + self._independent_nodes = None return False if class_res is True: found_new_independent_node = True @@ -305,10 +297,10 @@ def partition_map_output( # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. - self.dependent_nodes = { + self._dependent_nodes = { edge.dst for edge in state.out_edges(self.outer_entry) - if edge.dst not in self.independent_nodes + if edge.dst not in self._independent_nodes } return True @@ -333,7 +325,7 @@ def _classify_node( Returns: The function returns `True` if `node_to_classify` is considered independent. - In this case the function will add the node to `self.independent_nodes`. + In this case the function will add the node to `self._independent_nodes`. If the function returns `False` the node was classified as a dependent node. The function will return `None` if the node can not be classified, in this case the partition does not exist. @@ -343,23 +335,50 @@ def _classify_node( state: The state containing the map. sdfg: The SDFG that is processed. """ + assert self._independent_nodes is not None # silence MyPy outer_entry: dace_nodes.MapEntry = self.outer_entry # for caching. + outer_exit: dace_nodes.MapExit = state.exit_node(outer_entry) + + # The node needs to have an input and output. 
+ if state.in_degree(node_to_classify) == 0 or state.out_degree(node_to_classify) == 0: + return None # We are only able to handle certain kind of nodes, so screening them. if isinstance(node_to_classify, dace_nodes.Tasklet): if node_to_classify.side_effects: - # TODO(phimuell): Think of handling it. return None + + # A Tasklet must write to an AccessNode, because otherwise there would + # be nothing that could be used to cache anything. Furthermore, this + # AccessNode must be outside of the inner loop, i.e. be independent. + # TODO: Make this check stronger to ensure that there is always an + # AccessNode that is independent. + if not all( + isinstance(out_edge.dst, dace_nodes.AccessNode) + for out_edge in state.out_edges(node_to_classify) + if not out_edge.data.is_empty() + ): + return False + elif isinstance(node_to_classify, dace_nodes.AccessNode): # AccessNodes need to have some special properties. node_desc: dace.data.Data = node_to_classify.desc(sdfg) - if isinstance(node_desc, dace.data.View): # Views are forbidden. return None - if node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: - # The access node has to life fully within the scope. + + # The access node inside either has scope lifetime or is a scalar. + if isinstance(node_desc, dace.data.Scalar): + pass + elif node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: return None + + elif isinstance(node_to_classify, dace_nodes.MapEntry): + # We classify `MapEntries` as dependent nodes, we could now start + # looking if the whole map is independent, but it is currently an + # overkill. + return False + else: # Any other node type we can not handle, so the partition can not exist. # TODO(phimuell): Try to handle certain kind of library nodes. @@ -376,29 +395,12 @@ def _classify_node( # for these classification to make sense the partition has to exist in the # first place. - # Either all incoming edges of a node are empty or none of them. 
If it has - # empty edges, they are only allowed to come from the map entry. - found_empty_edges, found_nonempty_edges = False, False - for in_edge in in_edges: - if in_edge.data.is_empty(): - found_empty_edges = True - if in_edge.src is not outer_entry: - # TODO(phimuell): Lift this restriction. - return None - else: - found_nonempty_edges = True - - # Test if we found a mixture of empty and nonempty edges. - if found_empty_edges and found_nonempty_edges: - return None - assert ( - found_empty_edges or found_nonempty_edges - ), f"Node '{node_to_classify}' inside '{outer_entry}' without an input connection." - - # Requiring that all output Memlets are non empty implies, because we are - # inside a scope, that there exists an output. - if any(out_edge.data.is_empty() for out_edge in state.out_edges(node_to_classify)): - return None + # There are some very small requirements that we impose on the output edges. + for out_edge in state.out_edges(node_to_classify): + # We consider nodes that are directly connected to the outer map exit as + # dependent. This is an implementation detail to avoid some hard cases. + if out_edge.dst is outer_exit: + return False # Now we have ensured that the partition exists, thus we will now evaluate # if the node is independent or dependent. @@ -413,7 +415,7 @@ def _classify_node( # Now we have to look at incoming edges individually. # We will inspect the subset of the Memlet to see if they depend on the # block variable. If this loop ends normally, then we classify the node - # as independent and the node is added to `independent_nodes`. + # as independent and the node is added to `_independent_nodes`. for in_edge in in_edges: memlet: dace.Memlet = in_edge.data src_subset: dace_subsets.Subset | None = memlet.src_subset @@ -436,11 +438,11 @@ def _classify_node( # The edge must either originate from `outer_entry` or from an independent # node if not it is dependent. 
- if not (in_edge.src is outer_entry or in_edge.src in self.independent_nodes): + if not (in_edge.src is outer_entry or in_edge.src in self._independent_nodes): return False # Loop ended normally, thus we classify the node as independent. - self.independent_nodes.add(node_to_classify) + self._independent_nodes.add(node_to_classify) return True def _rewire_map_scope( @@ -467,116 +469,138 @@ def _rewire_map_scope( state: The state of the map. sdfg: The SDFG we operate on. """ + assert self._independent_nodes is not None and self._dependent_nodes is not None # Contains the nodes that are already have been handled. relocated_nodes: set[dace_nodes.Node] = set() # We now handle all independent nodes, this means that all of their - # _output_ edges have to go through the new inner map and the Memlets need - # modifications, because of the block parameter. - for independent_node in self.independent_nodes: - for out_edge in state.out_edges(independent_node): + # _output_ edges have to go through the new inner map and the Memlets + # need modifications, because of the block parameter. + for independent_node in self._independent_nodes: + for out_edge in list(state.out_edges(independent_node)): edge_dst: dace_nodes.Node = out_edge.dst relocated_nodes.add(edge_dst) # If destination of this edge is also independent we do not need # to handle it, because that node will also be before the new # inner serial map. - if edge_dst in self.independent_nodes: + if edge_dst in self._independent_nodes: continue # Now split `out_edge` such that it passes through the new inner entry. # We do not need to modify the subsets, i.e. replacing the variable # on which we block, because the node is independent and the outgoing # new inner map entry iterate over the blocked variable. 
- new_map_conn = inner_entry.next_connector() - dace_helpers.redirect_edge( - state=state, - edge=out_edge, - new_dst=inner_entry, - new_dst_conn="IN_" + new_map_conn, + if out_edge.data.is_empty(): + # `out_edge` is an empty Memlet that ensures its source, which is + # independent, is sequenced before its destination, which is + # dependent. We now have to split it into two. + # TODO(phimuell): Can we remove this edge? Is the map enough to + # ensure proper sequencing? + new_in_conn = None + new_out_conn = None + new_memlet_outside = dace.Memlet() + + elif not isinstance(independent_node, dace_nodes.AccessNode): + # For syntactical reasons there must be an access node on the + # outside of the (inner) scope, that acts as cache. The + # classification and this preconditions on SDFG should ensure + # that, but there are a few super hard edge cases. + # TODO(phimuell): Add an intermediate here in this case + raise NotImplementedError() + + else: + # NOTE: This creates more connections that are ultimately + # necessary. However, figuring out which one to use and if + # it would be valid, is very complicated, so we don't do it. + new_map_conn = inner_entry.next_connector(try_name=out_edge.data.data) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_outside = dace.Memlet.from_array( + out_edge.data.data, sdfg.arrays[out_edge.data.data] + ) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + + state.add_edge( + out_edge.src, + out_edge.src_conn, + inner_entry, + new_in_conn, + new_memlet_outside, ) - # TODO(phimuell): Check if there might be a subset error. 
state.add_edge( inner_entry, - "OUT_" + new_map_conn, + new_out_conn, out_edge.dst, out_edge.dst_conn, copy.deepcopy(out_edge.data), ) - inner_entry.add_in_connector("IN_" + new_map_conn) - inner_entry.add_out_connector("OUT_" + new_map_conn) + state.remove_edge(out_edge) # Now we handle the dependent nodes, they differ from the independent nodes - # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. - for dependent_node in self.dependent_nodes: + # in that they _after_ the new inner map entry. Thus, we have to modify + # their incoming edges. + for dependent_node in self._dependent_nodes: for in_edge in state.in_edges(dependent_node): edge_src: dace_nodes.Node = in_edge.src - # Since the independent nodes were already processed, and they process - # their output we have to check for this. We do this by checking if - # the source of the edge is the new inner map entry. + # The incoming edge of a dependent node (before any processing) either + # starts at: + # - The outer map. + # - An other dependent node. + # - An independent node. + # The last case was already handled by the loop above. if edge_src is inner_entry: + # Edge originated originally at an independent node, but was + # already handled by the loop above. assert dependent_node in relocated_nodes - continue - # A dependent node has at least one connection to the outer map entry. - # And these are the only connections that we must handle, since other - # connections come from independent nodes, and were already handled - # or are inner nodes. - if edge_src is not outer_entry: - continue - - # If we encounter an empty Memlet we just just attach it to the - # new inner map entry. Note the partition function ensures that - # either all edges are empty or non. - if in_edge.data.is_empty(): - assert ( - edge_src is outer_entry - ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." 
+ elif edge_src is not outer_entry: + # Edge originated at an other dependent node. There is nothing + # that we have to do. + # NOTE: We can not test if `edge_src` is in `self._dependent_nodes` + # because it only contains the dependent nodes that are directly + # connected to the map entry. + assert edge_src not in self._independent_nodes + + elif in_edge.data.is_empty(): + # The dependent node has an empty Memlet to the other map. + # Since the inner map is sequenced after the outer map, + # we will simply reconnect the edge to the inner map. + # TODO(phimuell): Are there situations where this makes problems. dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) - continue - # Because of the definition of a dependent node and the processing - # order, their incoming edges either point to the outer map or - # are already handled. - assert ( - edge_src is outer_entry - ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." - edge_conn: str = in_edge.src_conn[4:] - - # Must be before the handling of the modification below - # Note that this will remove the original edge from the SDFG. - dace_helpers.redirect_edge( - state=state, - edge=in_edge, - new_src=inner_entry, - new_src_conn="OUT_" + edge_conn, - ) - - # In a valid SDFG only one edge can go into an input connector of a Map. - if "IN_" + edge_conn in inner_entry.in_connectors: - # We have found this edge multiple times already. - # To ensure that there is no error, we will create a new - # Memlet that reads the whole array. - piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) - data_name = piping_edge.data.data - piping_edge.data = dace.Memlet.from_array( - data_name, sdfg.arrays[data_name], piping_edge.data.wcr + elif edge_src is outer_entry: + # This dependent node originated at the outer map. Thus we have to + # split the edge, such that it now passes through the inner map. 
+ new_map_conn = inner_entry.next_connector(try_name=in_edge.src_conn[4:]) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_inner = dace.Memlet.from_array( + in_edge.data.data, sdfg.arrays[in_edge.data.data] ) - - else: - # This is the first time we found this connection. - # so we just create the edge. state.add_edge( - outer_entry, - "OUT_" + edge_conn, + in_edge.src, + in_edge.src_conn, inner_entry, - "IN_" + edge_conn, + new_in_conn, + new_memlet_inner, + ) + state.add_edge( + inner_entry, + new_out_conn, + in_edge.dst, + in_edge.dst_conn, copy.deepcopy(in_edge.data), ) - inner_entry.add_in_connector("IN_" + edge_conn) - inner_entry.add_out_connector("OUT_" + edge_conn) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + state.remove_edge(in_edge) + + else: + raise NotImplementedError("Unknown node configuration.") # In certain cases it might happen that we need to create an empty # Memlet between the outer map entry and the inner one. @@ -593,7 +617,7 @@ def _rewire_map_scope( # This is simple reconnecting, there would be possibilities for improvements # but we do not use them for now. for in_edge in state.in_edges(outer_exit): - edge_conn = in_edge.dst_conn[3:] + edge_conn = inner_exit.next_connector(in_edge.dst_conn[3:]) dace_helpers.redirect_edge( state=state, edge=in_edge, @@ -610,5 +634,9 @@ def _rewire_map_scope( inner_exit.add_in_connector("IN_" + edge_conn) inner_exit.add_out_connector("OUT_" + edge_conn) + # There is an invalid cache state in the SDFG, that makes the memlet + # propagation fail, to clear the cache we call the hash function. + # See: https://github.com/spcl/dace/issues/1703 + _ = sdfg.hash_sdfg() # TODO(phimuell): Use a less expensive method. 
dace.sdfg.propagation.propagate_memlets_state(sdfg, state) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index ec33e7ea63..eceb07ed82 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -6,89 +6,106 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Implements helper functions for the map fusion transformations. +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements Helper functionaliyies for map fusion -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. """ -import functools -import itertools -from typing import Any, Optional, Sequence, Union + +# ruff: noqa + +import copy +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, Callable, TypeAlias import dace -from dace import ( - data as dace_data, - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes, validation as dace_validation -from dace.transformation import helpers as dace_helpers - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import util - - -@dace_properties.make_properties -class MapFusionHelper(dace_transformation.SingleStateTransformation): - """Contains common part of the fusion for parallel and serial Map fusion. 
- - The transformation assumes that the SDFG obeys the principals outlined in - [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md). - The main advantage of this structure is, that it is rather easy to determine - if a transient is used anywhere else. This check, performed by - `is_interstate_transient()`. It is further speeded up by cashing some computation, - thus such an object should not be used after interstate optimizations were applied - to the SDFG. +from dace import data, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, nodes, validation +from dace.transformation import helpers + +FusionCallback: TypeAlias = Callable[ + ["MapFusionHelper", nodes.MapEntry, nodes.MapEntry, dace.SDFGState, dace.SDFG, bool], bool +] +"""Callback for the map fusion transformation to check if a fusion should be performed. +""" + + +@properties.make_properties +class MapFusionHelper(transformation.SingleStateTransformation): + """Common parts of the parallel and serial map fusion transformation. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Note: + If `strict_dataflow` mode is enabled then the transformation will not remove + _direct_ data flow dependency from the graph. Furthermore, the transformation + will not remove size 1 dimensions of intermediate it creates. + This is a compatibility mode, that will limit the applicability of the + transformation, but might help transformations that do not fully analyse + the graph. 
""" - only_toplevel_maps = dace_properties.Property( + only_toplevel_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are in the top level.", ) - only_inner_maps = dace_properties.Property( + only_inner_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) - shared_transients = dace_properties.DictProperty( - key_type=dace.SDFG, - value_type=set[str], - default=None, - allow_none=True, - desc="Maps SDFGs to the set of array transients that can not be removed. " - "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", + strict_dataflow = properties.Property( + dtype=bool, + default=False, + desc="If `True` then the transformation will ensure a more stricter data flow.", ) + # Callable that can be specified by the user, if it is specified, it should be + # a callable with the same signature as `can_be_fused()`. If the function returns + # `False` then the fusion will be rejected. + _apply_fusion_callback: Optional[FusionCallback] + + # Maps SDFGs to the set of data that can not be removed, + # because they transmit data _between states_, such data will be made 'shared'. + # This variable acts as a cache, and is managed by 'is_shared_data()'. 
+ _shared_data: Dict[SDFG, Set[str]] + def __init__( self, only_inner_maps: Optional[bool] = None, only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + apply_fusion_callback: Optional[FusionCallback] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) + self._shared_data = {} + self._apply_fusion_callback = None if only_toplevel_maps is not None: self.only_toplevel_maps = bool(only_toplevel_maps) if only_inner_maps is not None: self.only_inner_maps = bool(only_inner_maps) - self.shared_transients = {} + if strict_dataflow is not None: + self.strict_dataflow = bool(strict_dataflow) + if apply_fusion_callback is not None: + self._apply_fusion_callback = apply_fusion_callback @classmethod def expressions(cls) -> bool: - raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + raise RuntimeError("The `MapFusionHelper` is not a transformation on its own.") def can_be_fused( self, - map_entry_1: dace_nodes.MapEntry, - map_entry_2: dace_nodes.MapEntry, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, permissive: bool = False, @@ -97,13 +114,11 @@ def can_be_fused( This function only checks constrains that are common between serial and parallel map fusion process, which includes: + - The registered callback, if specified. - The scope of the maps. - The scheduling of the maps. - The map parameters. - However, for performance reasons, the function does not check if the node - decomposition exists. - Args: map_entry_1: The entry of the first (in serial case the top) map. map_exit_2: The entry of the second (in serial case the bottom) map. @@ -111,6 +126,13 @@ def can_be_fused( sdfg: The SDFG itself. permissive: Currently unused. """ + # Consult the callback if defined. 
+ if self._apply_fusion_callback is not None: + if not self._apply_fusion_callback( + self, map_entry_1, map_entry_2, graph, sdfg, permissive + ): + return False + if self.only_inner_maps and self.only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") @@ -128,26 +150,22 @@ def can_be_fused( elif self.only_toplevel_maps: if scope[map_entry_1] is not None: return False - # TODO(phimuell): Figuring out why this is here. - elif util.is_nested_sdfg(sdfg): - return False - # We will now check if there exists a "remapping" that we can use. - # NOTE: The serial map promoter depends on the fact that this is the - # last check. - if not self.map_parameter_compatible( - map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg + # We will now check if there exists a remapping of the map parameter + if ( + self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) + is None ): return False return True - @staticmethod def relocate_nodes( - from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - state: dace.SDFGState, - sdfg: dace.SDFG, + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, ) -> None: """Move the connectors and edges from `from_node` to `to_nodes` node. @@ -156,6 +174,7 @@ def relocate_nodes( once for the entry and then for the exit. While it does not remove the node themselves if guarantees that the `from_node` has degree zero. + The function assumes that the parameter renaming was already done. Args: from_node: Node from which the edges should be removed. 
@@ -165,22 +184,23 @@ def relocate_nodes( """ # Now we relocate empty Memlets, from the `from_node` to the `to_node` - for empty_edge in filter(lambda e: e.data.is_empty(), state.out_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) - for empty_edge in filter(lambda e: e.data.is_empty(), state.in_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) # We now ensure that there is only one empty Memlet from the `to_node` to any other node. # Although it is allowed, we try to prevent it. - empty_targets: set[dace_nodes.Node] = set() - for empty_edge in filter(lambda e: e.data.is_empty(), state.all_edges(to_node)): + empty_targets: Set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): if empty_edge.dst in empty_targets: state.remove_edge(empty_edge) empty_targets.add(empty_edge.dst) # We now determine which edges we have to migrate, for this we are looking at # the incoming edges, because this allows us also to detect dynamic map ranges. - for edge_to_move in state.in_edges(from_node): + # TODO(phimuell): If there is already a connection to the node, reuse this. + for edge_to_move in list(state.in_edges(from_node)): assert isinstance(edge_to_move.dst_conn, str) if not edge_to_move.dst_conn.startswith("IN_"): @@ -200,36 +220,32 @@ def relocate_nodes( raise RuntimeError( # Might fail because of out connectors. f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." 
) - dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) from_node.remove_in_connector(dmr_symbol) - # There is no other edge that we have to consider, so we just end here - continue - - # We have a Passthrough connection, i.e. there exists a matching `OUT_`. - old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix - new_conn = to_node.next_connector(old_conn) - - to_node.add_in_connector("IN_" + new_conn) - for e in state.in_edges_by_connector(from_node, "IN_" + old_conn): - dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - to_node.add_out_connector("OUT_" + new_conn) - for e in state.out_edges_by_connector(from_node, "OUT_" + old_conn): - dace_helpers.redirect_edge( - state, e, new_src=to_node, new_src_conn="OUT_" + new_conn - ) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) + else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) # Check if we succeeded. 
if state.out_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", sdfg, sdfg.node_id(state), ) if state.in_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", sdfg, sdfg.node_id(state), @@ -237,330 +253,442 @@ def relocate_nodes( assert len(from_node.in_connectors) == 0 assert len(from_node.out_connectors) == 0 - @staticmethod - def map_parameter_compatible( - map_1: dace_nodes.Map, - map_2: dace_nodes.Map, - state: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - ) -> bool: - """Checks if the parameters of `map_1` are compatible with `map_2`. + def find_parameter_remapping( + self, first_map: nodes.Map, second_map: nodes.Map + ) -> Union[Dict[str, str], None]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + Parameters that already have the correct name and compatible range, are not + included in the return value, thus the keys and values are always different. + If no renaming at all is _needed_, i.e. all parameter have the same name and + range, then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. - The check follows the following rules: - - The names of the map variables must be the same, i.e. no renaming - is performed. - - The ranges must be the same. + Args: + first_map: The first map (these parameters will be replaced). + second_map: The second map, these parameters acts as source. 
""" - range_1: dace_subsets.Range = map_1.range - params_1: Sequence[str] = map_1.params - range_2: dace_subsets.Range = map_2.range - params_2: Sequence[str] = map_2.params - - # The maps are only fuseable if we have an exact match in the parameter names - # this is because we do not do any renaming. This is in accordance with the - # rules. - if set(params_1) != set(params_2): - return False - # Maps the name of a parameter to the dimension index - param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} - param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # The ranges, however, we apply some post processing to them. + simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) # noqa: E731 + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) for param, rng in zip(first_params, first_map.range) + } + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) + } + + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and vice versa. + unmapped_second_params: Set[str] = set(second_params) + unused_first_params: Set[str] = set(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we identify the parameters that already have the correct name. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[param] + second_rng = second_rngs[param] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. 
+ # Because the names are already the same, we do not have to enter them + # in the `final_mapping` + unmapped_second_params.discard(param) + unused_first_params.discard(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[unmapped_second_param] + assert unmapped_second_param not in final_mapping + + # Now look in all not yet used parameters of the first map which to use. + for candidate_param in unused_first_params: + candidate_rng = first_rngs[candidate_param] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + unused_first_params.discard(candidate_param) + break + else: + # We did not find a candidate, so the remapping does not exist + return None - # To fuse the two maps the ranges must have the same ranges - for pname in params_1: - idx_1 = param_dim_map_1[pname] - idx_2 = param_dim_map_2[pname] - # TODO(phimuell): do we need to call simplify? - if range_1[idx_1] != range_2[idx_2]: - return False + assert len(unused_first_params) == 0 + assert len(final_mapping) == len(unmapped_second_params) + return final_mapping - return True + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correct. The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + Args: + first_map: The first map (these are the final parameter). + second_map: The second map, this map will be replaced. + second_map_entry: The entry node of the second map. 
+ state: The SDFGState on which we operate. + """ + # Compute the replacement dict. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] + first_map=first_map, second_map=second_map + ) + + if repl_dict is None: + raise RuntimeError("The replacement does not exist") + if len(repl_dict) == 0: + return + + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) + # Why is this thing is symbolic and not in replace? + symbolic.safe_replace( + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) - def is_interstate_transient( + # For some odd reason the replace function does not modify the range and + # parameter of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) + + def is_shared_data( self, - transient: Union[str, dace_nodes.AccessNode], + data: nodes.AccessNode, sdfg: dace.SDFG, - state: dace.SDFGState, ) -> bool: - """Tests if `transient` is an interstate transient, an can not be removed. - - Essentially this function checks if a transient might be needed in a - different state in the SDFG, because it transmit information from - one state to the other. - If only the name of the data container is passed the function will - first look for an corresponding access node. + """Tests if `data` is interstate data, an can not be removed. - The set of these "interstate transients" is computed once per SDFG. - The result is then cached internally for later reuse. + Interstate data is used to transmit data between multiple state or by + extension within the state. Thus it must be classified as a shared output. + This function will go through the SDFG to and collect the names of all data + container that should be classified as shared. Note that this is an over + approximation as it does not take the location into account, i.e. "is no longer + used". Args: transient: The transient that should be checked. 
sdfg: The SDFG containing the array. - state: If given the state the node is located in. + + Note: + The function computes the this set once for every SDFG and then caches it. + There is no mechanism to detect if the cache must be evicted. However, + as long as no additional data is added, there is no problem. """ + if sdfg not in self._shared_data: + self._compute_shared_data(sdfg) + return data.data in self._shared_data[sdfg] - # The following builds upon the HACK MD document and not on ADR0018. - # Therefore the numbers are slightly different, but both documents - # essentially describes the same SDFG. - # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - # the set of such transients is partially given by all source access dace_nodes. - # Because of rule 3 we also include all scalars in this set, as an over - # approximation. Furthermore, because simplify might violate rule 3, - # we also include the sink dace_nodes. - - # See if we have already computed the set - if sdfg in self.shared_transients: - shared_sdfg_transients: set[str] = self.shared_transients[sdfg] - else: - # SDFG is not known so we have to compute the set. - shared_sdfg_transients = set() - for state_to_scan in sdfg.all_states(): - # TODO(phimuell): Use `all_nodes_recursive()` once it is available. - shared_sdfg_transients.update( - [ - node.data - for node in itertools.chain( - state_to_scan.source_nodes(), state_to_scan.sink_nodes() - ) - if isinstance(node, dace_nodes.AccessNode) - and sdfg.arrays[node.data].transient - ] + def _compute_shared_data( + self, + sdfg: dace.SDFG, + ) -> None: + """Updates the internal set of shared data/interstate data of `self` for `sdfg`. + + See the documentation for `self.is_shared_data()` for a description. + + Args: + sdfg: The SDFG for which the set of shared data should be computed. + """ + # Shared data of this SDFG. + shared_data: Set[str] = set() + + # All global data can not be removed, so it must always be shared. 
+ for data_name, data_desc in sdfg.arrays.items(): + if not data_desc.transient: + shared_data.add(data_name) + elif isinstance(data_desc, dace.data.Scalar): + shared_data.add(data_name) + + # We go through all states and classify the nodes/data: + # - Data is referred to in different states. + # - The access node is a view (both have to survive). + # - Transient sink or source node. + # - The access node has output degree larger than 1 (input degrees larger + # than one, will always be partitioned as shared anyway). + prevously_seen_data: Set[str] = set() + interstate_read_symbols: Set[str] = set() + for state in sdfg.nodes(): + for access_node in state.data_nodes(): + if access_node.data in shared_data: + # The data was already classified to be shared data + pass + + elif access_node.data in prevously_seen_data: + # We have seen this data before, either in this state or in + # a previous one, but we did not classifies it as shared back then + shared_data.add(access_node.data) + + if state.in_degree(access_node) == 0: + # (Transient) sink nodes are used in other states, or simplify + # will get rid of them. + shared_data.add(access_node.data) + + elif ( + state.out_degree(access_node) != 1 + ): # state.out_degree() == 0 or state.out_degree() > 1 + # The access node is either a source node (it is shared in another + # state) or the node has a degree larger than one, so it is used + # in this state somewhere else. + shared_data.add(access_node.data) + + elif self.is_view(node=access_node, sdfg=sdfg): + # To ensure that the write to the view happens, both have to be shared. + viewed_data: str = self.track_view( + view=access_node, state=state, sdfg=sdfg + ).data + shared_data.update([access_node.data, viewed_data]) + prevously_seen_data.update([access_node.data, viewed_data]) + + else: + # The node was not classified as shared data, so we record that + # we saw it. 
Note that a node that was immediately classified + # as shared node will never be added to this set, but a data + # that was found twice will be inside this list. + prevously_seen_data.add(access_node.data) + + # Now we are collecting all symbols that interstate edges read from. + for edge in sdfg.edges(): + interstate_read_symbols.update(edge.data.read_symbols()) + + # We also have to keep everything the edges referrers to and is an array. + shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) + + # Update the internal cache + self._shared_data[sdfg] = shared_data + + def _compute_multi_write_data( + self, + state: SDFGState, + sdfg: SDFG, + ) -> Set[str]: + """Computes data inside a _single_ state, that is written multiple times. + + Essentially this function computes the set of data that does not follow + the single static assignment idiom. The function also resolves views. + If an access node, refers to a view, not only the view itself, but also + the data it refers to is added to the set. + + Args: + state: The state that should be examined. + sdfg: The SDFG object. + + Note: + This information is used by the partition function (in case strict data + flow mode is enabled), in strict data flow mode only. The current + implementation is rather simple as it only checks if a data is written + to multiple times in the same state. + """ + data_written_to: Set[str] = set() + multi_write_data: Set[str] = set() + + for access_node in state.data_nodes(): + if state.in_degree(access_node) == 0: + continue + if access_node.data in data_written_to: + multi_write_data.add(access_node.data) + elif self.is_view(access_node, sdfg): + # This is an over approximation. 
+ multi_write_data.update( + [access_node.data, self.track_view(access_node, state, sdfg).data] ) - self.shared_transients[sdfg] = shared_sdfg_transients - - if isinstance(transient, str): - name = transient - matching_access_nodes = [node for node in state.data_nodes() if node.data == name] - # Rule 8: There is only one access node per state for data. - assert len(matching_access_nodes) == 1 - transient = matching_access_nodes[0] - else: - assert isinstance(transient, dace_nodes.AccessNode) - name = transient.data + data_written_to.add(access_node.data) + return multi_write_data - desc: dace_data.Data = sdfg.arrays[name] - if not desc.transient: - return True - if isinstance(desc, dace_data.Scalar): - return True # Scalars can not be removed by fusion anyway. + def is_node_reachable_from( + self, + graph: Union[dace.SDFG, dace.SDFGState], + begin: nodes.Node, + end: nodes.Node, + ) -> bool: + """Test if the node `end` can be reached from `begin`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end` the function returns `True`. If the node is never found `False` is + returned. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The node that should be located. + """ - # Rule 8: If degree larger than one then it is used within the state. - if state.out_degree(transient) > 1: - return True + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) - # Now we check if it is used in a different state. 
- return name in shared_sdfg_transients + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() - def partition_first_outputs( + while len(to_visit) > 0: + node: nodes.Node = to_visit.pop() + if node == end: + return True + elif node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) + + # We never found `end` + return False + + def get_access_set( self, - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - ) -> Union[ - tuple[ - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - ], - None, - ]: - """Partition the output edges of `map_exit_1` for serial map fusion. - - The output edges of the first map are partitioned into three distinct sets, - defined as follows: - - - Pure Output Set `\mathbb{P}`: - These edges exits the first map and does not enter the second map. These - outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit, enters an access node, from - where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not used anywhere else, thus it can - be removed. - - Shared Intermediate Set `\mathbb{S}`: - These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and have to be - recreated as output of the second map. - - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". 
+ + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is an `MapExit` on the set of outgoing edges. The function will + then determine all access nodes that have a connection through these edges + to the scope nodes (edges that does not lead to access nodes are ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. + scope_node: The scope node that should be evaluated. + state: The state in which we operate. """ - # The three outputs set. - pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) # noqa: E731 + other_node = lambda e: e.src # noqa: E731 + else: + get_edges = lambda node: state.out_edges(node) # noqa: E731 + other_node = lambda e: e.dst # noqa: E731 + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) + if isinstance(node, nodes.AccessNode) + } - # Set of intermediate nodes that we have already processed. - processed_inter_nodes: set[dace_nodes.Node] = set() + return access_set - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): - intermediate_node: dace_nodes.Node = out_edge.dst + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + repl_dict: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. 
+ + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. - # We already processed the node, this should indicate that we should - # run simplify again, or we should start implementing this case. - if intermediate_node in processed_inter_nodes: - return None - processed_inter_nodes.add(intermediate_node) - - # Now let's look at all nodes that are downstream of the intermediate node. - # This, among other things, will tell us, how we have to handle this node. - downstream_nodes = util.all_nodes_between( - graph=state, - begin=intermediate_node, - end=map_entry_2, + Args: + node: The access node that should be examined. + scope_node: We are only interested in data that flows through this node. + state: The state in which we operate. + sdfg: The SDFG object. + """ + + # Is the node used for reading or for writing. + # This influences how we have to proceed. + if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset # noqa: E731 + get_inner_edges = lambda e: state.out_edges_by_connector( + scope_node, "OUT_" + e.dst_conn[3:] + ) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset # noqa: E731 + get_inner_edges = lambda e: state.in_edges_by_connector( + scope_node, "IN_" + e.src_conn[4:] ) - # If `downstream_nodes` is `None` this means that `map_entry_2` was never - # reached, thus `intermediate_node` does not enter the second map and - # the node is a pure output node. - if downstream_nodes is None: - pure_outputs.add(out_edge) - continue + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, "Could not find any subsets." 
+ assert not any(subset is None for subset in found_subsets) - # The following tests are _after_ we have determined if we have a pure - # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge, whereas - # handling intermediate nodes is much more complicated. + found_subsets = copy.deepcopy(found_subsets) + if repl_dict: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(repl_dict, subset.replace) - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None + return found_subsets - # In case the intermediate has more than one entry, all must come from the - # first map, otherwise we can not fuse them. Currently we restrict this - # even further by saying that it has only one incoming Memlet. - if state.in_degree(intermediate_node) != 1: - return None + def is_view( + self, + node: nodes.AccessNode, + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node.desc(sdfg) + return isinstance(node_desc, data.View) - # It can happen that multiple edges converges at the `IN_` connector - # of the first map exit, but there is only one edge leaving the exit. - # It is complicate to handle this, so for now we ignore it. - # TODO(phimuell): Handle this case properly. - inner_collector_edges = list( - state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) - ) - if len(inner_collector_edges) > 1: - return None + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. 
It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, dace_nodes.AccessNode): - return None - intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, dace_data.View): - return None + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. - # There are some restrictions we have on intermediate dace_nodes. The first one - # is that we do not allow WCR, this is because they need special handling - # which is currently not implement (the DaCe transformation has this - # restriction as well). The second one is that we can reduce the - # intermediate node and only feed a part into the second map, consider - # the case `b = a + 1; return b + 2`, where we have arrays. In this - # example only a single element must be available to the second map. - # However, this is hard to check so we will make a simplification. - # First, we will not check it at the producer, but at the consumer point. - # There we assume if the consumer does _not consume the whole_ - # intermediate array, then we can decompose the intermediate, by setting - # the map iteration index to zero and recover the shape, see - # implementation in the actual fusion routine. - # This is an assumption that is in most cases correct, but not always. - # However, doing it correctly is extremely complex. - for _, produce_edge in util.find_upstream_producers(state, out_edge): - if produce_edge.data.wcr is not None: - return None - - if len(downstream_nodes) == 0: - # There is nothing between intermediate node and the entry of the - # second map, thus the edge belongs either in `\mathbb{S}` or - # `\mathbb{E}`. - - # This is a very special situation, i.e. 
the access node has many - # different connections to the second map entry, this is a special - # case that we do not handle. - # TODO(phimuell): Handle this case. - if state.out_degree(intermediate_node) != 1: - return None - - # Certain nodes need more than one element as input. As explained - # above, in this situation we assume that we can naturally decompose - # them iff the node does not consume that whole intermediate. - # Furthermore, it can not be a dynamic map range or a library node. - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) - for consumer_node, feed_edge in consumers: - # TODO(phimuell): Improve this approximation. - if ( - intermediate_size != 1 - ) and feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range. - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - - # Note that "remove" has a special meaning here, regardless of the - # output of the check function, from within the second map we remove - # the intermediate, it has more the meaning of "do we need to - # reconstruct it after the second map again?" - if self.is_interstate_transient(intermediate_node, sdfg, state): - shared_outputs.add(out_edge) - else: - exclusive_outputs.add(out_edge) - continue + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ - else: - # There is not only a single connection from the intermediate node to - # the second map, but the intermediate has more connections, thus - # the node might belong to the shared output. Of the many different - # possibilities, we only consider a single case: - # - The intermediate has a single connection to the second map, that - # fulfills the restriction outlined above. 
- # - All other connections have no connection to the second map. - found_second_entry = False - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - for edge in state.out_edges(intermediate_node): - if edge.dst is map_entry_2: - if found_second_entry: # The second map was found again. - return None - found_second_entry = True - consumers = util.find_downstream_consumers(state=state, begin=edge) - for consumer_node, feed_edge in consumers: - if feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - else: - # Ensure that there is no path that leads to the second map. - after_intermdiate_node = util.all_nodes_between( - graph=state, begin=edge.dst, end=map_entry_2 - ) - if after_intermdiate_node is not None: - return None - # If we are here, then we know that the node is a shared output - shared_outputs.add(out_edge) - continue + # Test if it is a view at all, if not return the passed node as source. + if not self.is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src # noqa: E731 + elif curr_edge.src_conn == "views": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst # noqa: E731 + else: + raise RuntimeError( + f"Failed to determine the direction of the view '{view}' | {curr_edge}." 
+ ) - assert exclusive_outputs or shared_outputs or pure_outputs - assert len(processed_inter_nodes) == sum( - len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] - ) - return (pure_outputs, exclusive_outputs, shared_outputs) + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while self.is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py new file mode 100644 index 0000000000..19412b9dfa --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py @@ -0,0 +1,170 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements the parallel map fusing transformation. + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. +""" + +from typing import Any, Optional, Set, Union + +import dace +from dace import properties, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from . import map_fusion_helper as mfh + + +@properties.make_properties +class MapFusionParallel(mfh.MapFusionHelper): + """The `MapFusionParallel` transformation allows to merge two parallel maps. + + While the `MapFusionSerial` transformation fuses maps that are sequentially + connected through an intermediate node, this transformation is able to fuse any + two maps that are not sequential and in the same scope. 
+ + Args: + only_if_common_ancestor: Only perform fusion if both Maps share at least one + node as direct ancestor. This will increase the locality of the merge. + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Note: + This transformation only matches the entry nodes of the Map, but will also + modify the exit nodes of the Maps. + """ + + map_entry_1 = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + + only_if_common_ancestor = properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps share a node as parent.", + ) + + def __init__( + self, + only_if_common_ancestor: Optional[bool] = None, + **kwargs: Any, + ) -> None: + if only_if_common_ancestor is not None: + self.only_if_common_ancestor = only_if_common_ancestor + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + # This just matches _any_ two Maps inside a state. + state = graph.OrderedMultiDiConnectorGraph() + state.add_nodes_from([cls.map_entry_1, cls.map_entry_2]) + return [state] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Checks if the fusion can be done. + + The function checks the general fusing conditions and if the maps are parallel. 
+ """ + map_entry_1: nodes.MapEntry = self.map_entry_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + + # Check the structural properties of the maps, this will also ensure that + # the two maps are in the same scope and the parameters can be renamed + if not self.can_be_fused( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + graph=graph, + sdfg=sdfg, + permissive=permissive, + ): + return False + + # Since the match expression matches any two Maps, we have to ensure that + # the maps are parallel. The `can_be_fused()` function already verified + # if they are in the same scope. + if not self.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): + return False + + # Test if they have they share a node as direct ancestor. + if self.only_if_common_ancestor: + # This assumes that there is only one access node per data container in the state. + ancestors_1: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} + if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): + return False + + return True + + def is_parallel( + self, + graph: SDFGState, + node1: nodes.Node, + node2: nodes.Node, + ) -> bool: + """Tests if `node1` and `node2` are parallel. + + The nodes are parallel if `node2` can not be reached from `node1` and vice versa. + + Args: + graph: The graph to traverse. + node1: The first node to check. + node2: The second node to check. + """ + + # In order to be parallel they must be in the same scope. + scope = graph.scope_dict() + if scope[node1] != scope[node2]: + return False + + # The `all_nodes_between()` function traverse the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. 
+ if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): + return False + elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): + return False + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map fusing. + + Essentially, the function relocate all edges from the scope nodes (`MapEntry` + and `MapExit`) of the second map to the scope nodes of the first map. + """ + + map_entry_1: nodes.MapEntry = self.map_entry_1 + map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_entry_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + + for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): + self.relocate_nodes( + from_node=from_node, + to_node=to_node, + state=graph, + sdfg=sdfg, + ) + # The relocate function does not remove the node, so we must do it. + graph.remove_node(from_node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py new file mode 100644 index 0000000000..2cdcc455d4 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py @@ -0,0 +1,1007 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements the serial map fusing transformation. + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. 
+""" + +import copy +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import dace +from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from . import map_fusion_helper as mfh + + +@properties.make_properties +class MapFusionSerial(mfh.MapFusionHelper): + """Fuse two serial maps together. + + The transformation combines two maps into one that are connected through some + access nodes. Conceptually this transformation removes the exit of the first + or upper map and the entry of the lower or second map and then rewrites the + connections appropriately. Depending on the situation the transformation will + either fully remove or make the intermediate a new output of the second map. + + By default, the transformation does not use the strict data flow mode, see + `MapFusionHelper` for more, however, it might be useful in come cases to enable + it, especially in the context of DaCe's auto optimizer. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Notes: + - This transformation modifies more nodes than it matches. + - After the transformation has been applied simplify should be run to remove + some dead data flow, that was introduced to ensure validity. + - A `MapFusionSerial` object can be initialized and be reused. However, + after new access nodes are added to any state, it is no longer valid + to use the object. + + Todo: + - Consider the case that only shared nodes are created (thus no inspection of + the graph is needed) and make all shared. Then use the dead dataflow + elimination transformation to get rid of the ones we no longer need. 
+ - Increase the applicability. + """ + + map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) + intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. + """ + return [ + dace.sdfg.utils.node_path_graph( + cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2 + ) + ] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + - Satisfies general requirements, see `MapFusionHelper.can_be_fused()`. + - Tests if the decomposition exists. + - Tests if there are read write dependencies. + """ + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + + # This essentially test the structural properties of the two Maps. 
+ if not self.can_be_fused( + map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg + ): + return False + + # Test for read-write conflicts + if self.has_read_write_dependency( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + state=graph, + sdfg=sdfg, + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + return True + + def has_read_write_dependency( + self, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps to be fused. + + The function checks two different things. + - The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets. + - The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as inputs or outputs + at the same time. However, the function will not check for read write + conflicts in this set, this is done in the partition function. + + Returns: + `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` + is returned. + + Args: + map_entry_1: The entry node of the first map. + map_entry_2: The entry node of the second map. + state: The state on which we operate. + sdfg: The SDFG on which we operate. 
+ """ + map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) + map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) + access_sets.append({node.data: node for node in access_set}) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. + if len(access_set) != len(access_sets[-1]): + return True + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve them. + # We also already get the name of the data container. + # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append( + { + self.track_view(node, state, sdfg).data + if self.is_view(node, sdfg) + else node.data + for node in unresolved_set.values() + } + ) + # If the resolved and unresolved names do not have the same length. + # Then different views point to the same location, which we forbid + if len(unresolved_set) != len(resolved_sets[-1]): + return False + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We do not allow that the first and second map each write to the same data. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # If there is no overlap in what is (totally) read and written, there will be no conflict. + # This must come before the check of disjoint write. 
+ if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): + return False + + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection( + read_map_2.values() + ) + + # If the number are different then a data is accessed through multiple nodes. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): + return True + + # This is the names of the node that are used as input of the first map and + # as output of the second map. We have to ensure that there is no data + # dependency between these nodes. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # If a data container is used as input and output then it can not be a view (simplicity) + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # A data container can be used as input and output. But we do not allow that + # it is also used as intermediate or exchange data. This is an important check. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True + + # Get the replacement dict for changing the map variables from the subsets of + # the second map. + repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) + + # Now we inspect if there is a read write dependency, between data that is + # used as input and output of the fused map. There is no problem is they + # are pointwise, i.e. 
in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: + all_subsets: List[subsets.Subset] = [] + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self.find_subsets( + node=read_map_1[inout_data_name], + scope_node=map_entry_1, + state=state, + sdfg=sdfg, + repl_dict=None, + ) + ) + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. + all_subsets.extend( + self.find_subsets( + node=write_map_2[inout_data_name], + scope_node=map_exit_2, + state=state, + sdfg=sdfg, + repl_dict=repl_dict, + ) + ) + # Now we can test if these subsets are point wise + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + + # No read write dependency was found. + return False + + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: + """Point wise means that they are all the same. + + If a series of subsets are point wise it means that all Memlets, access + the same data. This is an important property because the whole map fusion + is build upon this. + If the subsets originates from different maps, then they must have been + renamed. + + Args: + subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. 
The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. + # This means that the data accesses, described by this transformation is + # point wise + return True + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Optional[subsets.Range] = None, + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + Args: + original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + intermediate_desc: The original intermediate data descriptor. + map_params: The parameter of the final map. + """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + final_offset = subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError( + f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'." + ) + + if producer_offset is not None: + # Here we are correcting some parts that over approximate (which partially + # does under approximate) might screw up. 
Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more. + final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled the function is rather strict if an + output can be added to either intermediate set and might fail to compute + the partition, even if it would exist. 
+ + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + the function returns `None`. + + Args: + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. + """ + # The three outputs set. + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Compute the renaming that for translating the parameter of the _second_ + # map to the ones used by the first map. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] + first_map=map_exit_1.map, + second_map=map_entry_2.map, + ) + assert repl_dict is not None + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # These are the data that is written to multiple times in _this_ state. + # If a data is written to multiple time in a state, it could be + # classified as shared. However, it might happen that the node has zero + # degree. This is not a problem as the maps also induced a before-after + # relationship. But some DaCe transformations do not catch this. + # Thus we will never modify such intermediate nodes and fail instead. + if self.strict_dataflow: + multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) + else: + multi_write_data = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. 
+ # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate can only have one incoming degree. It might be possible + # to handle multiple incoming edges, if they all come from the top map. + # However, the resulting SDFG might be invalid. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. + if state.in_degree(intermediate_node) != 1: + return None + + # If the second map is not reachable from the intermediate node, then + # the output is pure and we can end here. + if not self.is_node_reachable_from( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ): + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + if self.is_view(intermediate_node, sdfg): + return None + + # Checks if the intermediate node refers to data that is accessed by + # _other_ access nodes in _this_ state. If this is the case then never + # touch this intermediate node. + # TODO(phimuell): Technically it would be enough to turn the node into + # a shared output node, because this will still fulfil the dependencies. + # However, some DaCe transformation can not handle this properly, so we + # are _forced_ to reject this node. 
+ if intermediate_node.data in multi_write_data: + return None + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(producer_edges) > 1: + return None + + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # - No dynamic Memlets. + # Furthermore, we will also extract the subsets, i.e. the location they + # modify inside the intermediate array. + # Since we do not allow for WCR, we do not check if the producer subsets intersects. 
+ producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view( + producer_edge.src, sdfg + ): + return None + if producer_edge.data.dynamic: + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Check if the producer do not intersect + if len(producer_subsets) == 1: + pass + elif len(producer_subsets) == 2: + if producer_subsets[0].intersects(producer_subsets[1]): + return None + else: + for i, psbs1 in enumerate(producer_subsets): + for j, psbs2 in enumerate(producer_subsets): + if i == j: + continue + if psbs1.intersects(psbs2): + return None + + # Now we determine the consumer of nodes. For this we are using the edges + # leaves the second map entry. It is not necessary to find the actual + # consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edge in state.out_edges(intermediate_node): + # If the second map entry is not immediately reachable from the intermediate + # node, then ensure that there is not path that goes to it. + if intermediate_node_out_edge.dst is not map_entry_2: + if self.is_node_reachable_from( + graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2 + ): + return None + continue + + # Ensure that the second map is found exactly once. + # TODO(phimuell): Lift this restriction. + if found_second_map: + return None + found_second_map = True + + # The output of the top map can not define a dynamic map range in the + # second map. + if not intermediate_node_out_edge.dst_conn.startswith("IN_"): + return None + + # Now we look at all edges that leave the second map entry, i.e. 
the + # edges that feeds the consumer and define what is read inside the map. + # We do not check them, but collect them and inspect them. + # NOTE: The subset still uses the old iteration variables. + for inner_consumer_edge in state.out_edges_by_connector( + map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:] + ): + if inner_consumer_edge.data.src_subset is None: + return None + if inner_consumer_edge.data.dynamic: + # TODO(phimuell): Is this restriction necessary, I am not sure. + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert ( + found_second_map + ), f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert len(consumer_subsets) != 0 + + # The consumer still uses the original symbols of the second map, so we must rename them. + if repl_dict: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace( + mapping=repl_dict, replace_callback=consumer_subset.replace + ) + + # Now we are checking if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we look if the producer covers the consumer. A consumer must + # be covered by exactly one producer. + for consumer_subset in consumer_subsets: + nb_coverings = sum( + producer_subset.covers(consumer_subset) for producer_subset in producer_subsets + ) + if nb_coverings != 1: + return None + + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "removed" here means that it is reconstructed by a new + # output of the second map. + if self.is_shared_data(intermediate_node, sdfg): + # The intermediate data is used somewhere else, either in this or another state. + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. 
+ exclusive_outputs.add(out_edge) + + assert len(processed_inter_nodes) == sum( + len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] + ) + return (pure_outputs, exclusive_outputs, shared_outputs) + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + + Args: + graph: The SDFG state we are operating on. + sdfg: The SDFG we are operating on. + """ + # NOTE: `self.map_*` actually stores the ID of the node. + # once we start adding and removing nodes it seems that their ID changes. + # Thus we have to save them here, this is a known behaviour in DaCe. + assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit_1, nodes.MapExit) + assert isinstance(self.map_entry_2, nodes.MapEntry) + + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + assert output_partition is not None # Make MyPy happy. 
+ pure_outputs, exclusive_outputs, shared_outputs = output_partition + + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(map_exit_1)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=map_exit_1, + to_node=map_exit_2, + state=graph, + sdfg=sdfg, + ) + + # Above we have handled the input of the second map and moved them + # to the first map, now we must move the output of the first map + # to the second one, as this one is used. + self.relocate_nodes( + from_node=map_entry_2, + to_node=map_entry_1, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [map_exit_1, map_entry_2]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + map_exit_2.map = map_entry_1.map + + def handle_intermediate_set( + self, + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + map_exit_2: nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. 
+ + Args: + intermediate_outputs: The set of outputs, that should be processed. + state: The state in which the map is processed. + sdfg: The SDFG that should be optimized. + map_exit_1: The exit of the first/top map. + map_entry_2: The entry of the second map. + map_exit_2: The exit of the second map. + is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + Notes: + Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ + + map_params = map_exit_1.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + if full_dim_size == 1: # Must be kept! 
+ new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + storage=dtypes.StorageType.Register, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + storage=dtypes.StorageType.Register, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + assert pre_exit_edge.data.data == inter_name + assert pre_exit_edge.data.dst_subset is not None + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + ) + + # Memlets have a lot of additional informations, such as dynamic. 
+ # To ensure that we get all of them, we will now copy them and modify + # the one that was originally there. We also hope that propagate will + # set the rest for us correctly. + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) + + # New we will reroute the output Memlet, thus it will no longer pass + # through the Map exit but through the newly created intermediate. + # NOTE: We will delete the previous edge later. + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( + include_self=False + ): + producer_edge = producer_tree.edge + + # Associate the (already existing) Memlet with the new data. + # TODO(phimuell): Improve the code below to remove the check. + assert producer_edge.data.data == inter_name + producer_edge.data.data = new_inter_name + + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). 
+ conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == map_entry_2: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + assert inner_edge.data.data == inter_name # DIRECTION!! + + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now we create a new connection that instead reads from the new + # intermediate, instead of the old one. For this we use the + # old Memlet as template. However it is not fully initialized. + new_inner_memlet = copy.deepcopy(inner_edge.data) + new_inner_memlet.data = new_inter_name + + # Now we replace the edge from the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. 
+ if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( + include_self=False + ): + assert consumer_tree.edge.data.data == inter_name + + consumer_edge = consumer_tree.edge + consumer_edge.data.data = new_inter_name + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. + for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + map_entry_2.remove_in_connector(in_conn_name) + map_entry_2.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + map_exit_1.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. 
This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. + final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert pre_exit_edge.data.data == inter_name + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = map_exit_2.next_connector() + state.add_edge( + new_inter_node, + None, + map_exit_2, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, + ) + state.add_edge( + map_exit_2, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) + map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + + map_exit_1.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index 4b34dd6adc..8fb41c7d0a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -16,12 +16,42 @@ from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util +def gt_set_iteration_order( + sdfg: dace.SDFG, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, + validate: bool = True, + validate_all: bool = False, +) -> Any: + """Set the iteration order of the Maps correctly. + + Modifies the order of the Map parameters such that `leading_dim` + is the fastest varying one, the order of the other dimensions in + a Map is unspecific. `leading_dim` should be the dimensions were + the stride is one. + + Args: + sdfg: The SDFG to process. + leading_dim: The leading dimensions. 
+ validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + """ + return sdfg.apply_transformations_once_everywhere( + MapIterationOrder( + leading_dims=leading_dim, + ), + validate=validate, + validate_all=validate_all, + ) + + @dace_properties.make_properties class MapIterationOrder(dace_transformation.SingleStateTransformation): """Modify the order of the iteration variables. The iteration order, while irrelevant from an SDFG point of view, is highly - relevant in code, and the fastest varying index ("inner most loop" in CPU or + relevant in code and the fastest varying index ("inner most loop" in CPU or "x block dimension" in GPU) should be associated with the stride 1 dimension of the array. This transformation will reorder the map indexes such that this is the case. @@ -29,9 +59,18 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): While the place of the leading dimension is clearly defined, the order of the other loop indexes, after this transformation is unspecified. + The transformation accepts either a single dimension or a list of dimensions. + In case a list is passed this is interpreted as priorities. + Assuming we have the `leading_dim=[EdgeDim, VertexDim]`, then we have the + following: + - `Map[EdgeDim, KDim, VertexDim]` -> `Map[KDim, VertexDim, EdgeDim]`. + - `Map[VertexDim, KDim]` -> `Map[KDim, VertexDim]`. + - `Map[EdgeDim, KDim]` -> `Map[KDim, EdgeDim]`. + - `Map[CellDim, KDim]` -> `Map[CellDim, KDim]` (no modification). + Args: - leading_dim: A GT4Py dimension object that identifies the dimension that - is supposed to have stride 1. + leading_dim: GT4Py dimensions that are associated with the dimension that is + supposed to have stride 1. If it is a list it is used as a ranking. 
Note: The transformation does follow the rules outlines in @@ -44,25 +83,33 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): - Maybe also process the parameters to bring them in a canonical order. """ - leading_dim = dace_properties.Property( - dtype=str, + leading_dims = dace_properties.ListProperty( + element_type=str, allow_none=True, - desc="Dimension that should become the leading dimension.", + default=None, + desc="Dimensions that should become the leading dimension.", ) - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, - leading_dim: Optional[Union[gtx_common.Dimension, str]] = None, + leading_dims: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - if isinstance(leading_dim, gtx_common.Dimension): - self.leading_dim = gtx_dace_fieldview_util.get_map_variable(leading_dim) - elif leading_dim is not None: - self.leading_dim = leading_dim + if isinstance(leading_dims, (gtx_common.Dimension, str)): + leading_dims = [leading_dims] + if isinstance(leading_dims, list): + self.leading_dims = [ + leading_dim + if isinstance(leading_dim, str) + else gtx_dace_fieldview_util.get_map_variable(leading_dim) + for leading_dim in leading_dims + ] @classmethod def expressions(cls) -> Any: @@ -80,16 +127,15 @@ def can_be_applied( Essentially the function checks if the selected dimension is inside the map, and if so, if it is on the right place. 
""" - - if self.leading_dim is None: + if self.leading_dims is None: return False map_entry: dace_nodes.MapEntry = self.map_entry map_params: Sequence[str] = map_entry.map.params - map_var: str = self.leading_dim + processed_dims: set[str] = set(self.leading_dims) - if map_var not in map_params: + if not any(map_param in processed_dims for map_param in map_params): return False - if map_params[-1] == map_var: # Already at the correct location + if self.compute_map_param_order() is None: return False return True @@ -104,22 +150,52 @@ def apply( `self.leading_dim` the last map variable (this is given by the structure of DaCe's code generator). """ + map_object: dace_nodes.Map = self.map_entry.map + new_map_params_order: list[int] = self.compute_map_param_order() # type: ignore[assignment] # Guaranteed to be not `None`. + + def reorder(what: list[Any]) -> list[Any]: + assert isinstance(what, list) + return [what[new_pos] for new_pos in new_map_params_order] + + map_object.params = reorder(map_object.params) + map_object.range.ranges = reorder(map_object.range.ranges) + map_object.range.tile_sizes = reorder(map_object.range.tile_sizes) + + def compute_map_param_order(self) -> Optional[list[int]]: + """Computes the new iteration order of the matched map. + + The function returns a list, the value at index `i` indicates the old dimension + that should be put at the new location. If the order is already correct then + `None` is returned. + """ map_entry: dace_nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params - map_var: str = self.leading_dim - - # This implementation will just swap the variable that is currently the last - # with the one that should be the last. 
- dst_idx = -1 - src_idx = map_params.index(map_var) - - for to_process in [ - map_entry.map.params, - map_entry.map.range.ranges, - map_entry.map.range.tile_sizes, - ]: - assert isinstance(to_process, list) - src_val = to_process[src_idx] - dst_val = to_process[dst_idx] - to_process[dst_idx] = src_val - to_process[src_idx] = dst_val + org_mapping: dict[str, int] = {map_param: i for i, map_param in enumerate(map_params)} + leading_dims: list[str] = self.leading_dims + + # We divide the map parameters into two groups, the one we care and the others. + map_params_to_order: set[str] = { + map_param for map_param in map_params if map_param in leading_dims + } + + # If there is nothing to order, then we are done. + if not map_params_to_order: + return None + + # We start with all parameters that we ignore/do not care about. + new_map_params: list[str] = [ + map_param for map_param in map_params if map_param not in leading_dims + ] + + # Because how code generation works the leading dimension must be the most + # left one. Because this is also `self.leading_dims[0]` we have to process + # then in reverse order. 
+ for map_param_to_check in reversed(leading_dims): + if map_param_to_check in map_params_to_order: + new_map_params.append(map_param_to_check) + assert len(map_params) == len(new_map_params) + + if map_params == new_map_params: + return None + + return [org_mapping[new_map_param] for new_map_param in new_map_params] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 19818fd3d1..46d46c4bbe 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -299,9 +299,9 @@ class SerialMapPromoter(BaseMapPromoter): """ # Pattern Matching - exit_first_map = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - entry_second_map = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + exit_first_map = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + entry_second_map = dace_transformation.PatternNode(dace_nodes.MapEntry) @classmethod def expressions(cls) -> Any: @@ -346,17 +346,11 @@ def _test_if_promoted_maps_can_be_fused( ) -> bool: """This function checks if the promoted maps can be fused by map fusion. - This function assumes that `self.can_be_applied()` returned `True`. + This function assumes that `super().self.can_be_applied()` returned `True`. Args: state: The state in which we operate. sdfg: The SDFG we process. - - Note: - The current implementation uses a very hacky way to test this. - - Todo: - Find a better way of doing it. 
""" first_map_exit: dace_nodes.MapExit = self.exit_first_map access_node: dace_nodes.AccessNode = self.access_node @@ -373,23 +367,17 @@ def _test_if_promoted_maps_can_be_fused( # This will lead to a promotion of the map, this is needed that # Map fusion can actually inspect them. self.apply(graph=state, sdfg=sdfg) - - # Now create the map fusion object that we can then use to check if - # the fusion is possible or not. - serial_fuser = gtx_transformations.SerialMapFusion( - only_inner_maps=self.only_inner_maps, - only_toplevel_maps=self.only_toplevel_maps, - ) - candidate = { - type(serial_fuser).map_exit1: first_map_exit, - type(serial_fuser).access_node: access_node, - type(serial_fuser).map_entry2: second_map_entry, - } - state_id = sdfg.node_id(state) - serial_fuser.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) - - # Now use the serial fuser to see if fusion would succeed - if not serial_fuser.can_be_applied(state, 0, sdfg): + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + expr_index=0, + options={ + "only_inner_maps": self.only_inner_maps, + "only_toplevel_maps": self.only_toplevel_maps, + }, + map_exit_1=first_map_exit, + intermediate_access_node=access_node, + map_entry_2=second_map_entry, + ): return False finally: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py deleted file mode 100644 index bca5aa2268..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ /dev/null @@ -1,483 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements the serial map fusing transformation. 
- -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. -""" - -import copy -from typing import Any, Union - -import dace -from dace import ( - dtypes as dace_dtypes, - properties as dace_properties, - subsets as dace_subsets, - symbolic as dace_symbolic, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import map_fusion_helper - - -@dace_properties.make_properties -class SerialMapFusion(map_fusion_helper.MapFusionHelper): - """Specialized replacement for the map fusion transformation that is provided by DaCe. - - As its name is indicating this transformation is only able to handle Maps that - are in sequence. Compared to the native DaCe transformation, this one is able - to handle more complex cases of connection between the maps. In that sense, it - is much more similar to DaCe's `SubgraphFusion` transformation. - - Things that are improved, compared to the native DaCe implementation: - - Nested Maps. - - Temporary arrays and the correct propagation of their Memlets. - - Top Maps that have multiple outputs. - - Conceptually this transformation removes the exit of the first or upper map - and the entry of the lower or second map and then rewrites the connections - appropriately. - - This transformation assumes that an SDFG obeys the structure that is outlined - [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that - reason it is not true replacement of the native DaCe transformation. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - - Notes: - - This transformation modifies more nodes than it matches! 
- """ - - map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - - def __init__( - self, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - """Get the match expression. - - The transformation matches the exit node of the top Map that is connected to - an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on the - matched nodes, but more or less on anything that has an incoming connection - from the first Map or an outgoing connection to the second Map entry. - """ - return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] - - def can_be_applied( - self, - graph: Union[dace.SDFGState, dace.SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the matched Maps can be merged. - - The two Maps are mergeable iff: - - The `can_be_fused()` of the base succeed, which checks some basic constraints. - - The decomposition exists and at least one of the intermediate sets - is not empty. - """ - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - - # This essentially test the structural properties of the two Maps. - if not self.can_be_fused( - map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg - ): - return False - - # Two maps can be serially fused if the node decomposition exists and - # at least one of the intermediate output sets is not empty. The state - # of the pure outputs is irrelevant for serial map fusion. 
- output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, - ) - if output_partition is None: - return False - _, exclusive_outputs, shared_outputs = output_partition - if not (exclusive_outputs or shared_outputs): - return False - return True - - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: - """Performs the serial Map fusing. - - The function first computes the map decomposition and then handles the - three sets. The pure outputs are handled by `relocate_nodes()` while - the two intermediate sets are handled by `handle_intermediate_set()`. - - By assumption we do not have to rename anything. - - Args: - graph: The SDFG state we are operating on. - sdfg: The SDFG we are operating on. - """ - # NOTE: `self.map_*` actually stores the ID of the node. - # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here, this is a known behaviour in DaCe. - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) - - map_exit_1: dace_nodes.MapExit = self.map_exit1 - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - assert output_partition is not None # Make MyPy happy. 
- pure_outputs, exclusive_outputs, shared_outputs = output_partition - - if len(exclusive_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=exclusive_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=True, - ) - if len(shared_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=shared_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=False, - ) - assert pure_outputs == set(graph.out_edges(map_exit_1)) - if len(pure_outputs) != 0: - self.relocate_nodes( - from_node=map_exit_1, - to_node=map_exit_2, - state=graph, - sdfg=sdfg, - ) - - # Above we have handled the input of the second map and moved them - # to the first map, now we must move the output of the first map - # to the second one, as this one is used. - self.relocate_nodes( - from_node=map_entry_2, - to_node=map_entry_1, - state=graph, - sdfg=sdfg, - ) - - for node_to_remove in [map_exit_1, map_entry_2]: - assert graph.degree(node_to_remove) == 0 - graph.remove_node(node_to_remove) - - # Now turn the second output node into the output node of the first Map. - map_exit_2.map = map_entry_1.map - - @staticmethod - def handle_intermediate_set( - intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - map_exit_2: dace_nodes.MapExit, - is_exclusive_set: bool, - ) -> None: - """This function handles the intermediate sets. - - The function is able to handle both the shared and exclusive intermediate - output set, see `partition_first_outputs()`. The main difference is that - in exclusive mode the intermediate nodes will be fully removed from - the SDFG. While in shared mode the intermediate node will be preserved. 
- - Args: - intermediate_outputs: The set of outputs, that should be processed. - state: The state in which the map is processed. - sdfg: The SDFG that should be optimized. - map_exit_1: The exit of the first/top map. - map_entry_2: The entry of the second map. - map_exit_2: The exit of the second map. - is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. - - Notes: - Before the transformation the `state` does not have to be valid and - after this function has run the state is (most likely) invalid. - - Todo: - Rewrite using `MemletTree`. - """ - - # Essentially this function removes the AccessNode between the two maps. - # However, we still need some temporary memory that we can use, which is - # just much smaller, i.e. a scalar. But all Memlets inside the second map - # assumes that the intermediate memory has the bigger shape. - # To fix that we will create this replacement dict that will replace all - # occurrences of the iteration variables of the second map with zero. - # Note that this is still not enough as the dimensionality might be different. - memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} - - # Now we will iterate over all intermediate edges and process them. - # If not stated otherwise the comments assume that we run in exclusive mode. - for out_edge in intermediate_outputs: - # This is the intermediate node that, that we want to get rid of. - # In shared mode we want to recreate it after the second map. - inter_node: dace_nodes.AccessNode = out_edge.dst - inter_name = inter_node.data - inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape - - # Now we will determine the shape of the new intermediate. This size of - # this temporary is given by the Memlet that goes into the first map exit. 
- pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) - ) - if len(pre_exit_edges) != 1: - raise NotImplementedError() - pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # That are known to cause some troubles, so we will now remove them. - squeezed_dims: list[int] = [] # These are the dimensions we removed. - new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - # Order of checks is important! - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - - # This is the name of the new "intermediate" node that we will create. - # It will only have the shape `new_inter_shape` which is basically its - # output within one Map iteration. - # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" - - # Now generate the intermediate data container. 
- if len(new_inter_shape) == 0: - assert pre_exit_edge.data.subset.num_elements() == 1 - is_scalar = True - new_inter_name, new_inter_desc = sdfg.add_scalar( - new_inter_name, - dtype=inter_desc.dtype, - transient=True, - storage=dace_dtypes.StorageType.Register, - find_new_name=True, - ) - - else: - assert (pre_exit_edge.data.subset.num_elements() > 1) or all( - x == 1 for x in new_inter_shape - ) - is_scalar = False - new_inter_name, new_inter_desc = sdfg.add_transient( - new_inter_name, - shape=new_inter_shape, - dtype=inter_desc.dtype, - find_new_name=True, - ) - new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) - - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # we will delete the previous edge later. - pre_exit_memlet: dace.Memlet = pre_exit_edge.data - new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) - - # We might operate on a different array, but the check below, ensures - # that we do not change the direction of the Memlet. - assert pre_exit_memlet.data == inter_name - new_pre_exit_memlet.data = new_inter_name - - # Now we have to modify the subset of the Memlet. - # Before the subset of the Memlet was dependent on the Map variables, - # however, this is no longer the case, as we removed them. This change - # has to be reflected in the Memlet. - # NOTE: Assert above ensures that the below is correct. - new_pre_exit_memlet.replace(memlet_repl) - if is_scalar: - new_pre_exit_memlet.subset = "0" - new_pre_exit_memlet.other_subset = None - else: - new_pre_exit_memlet.subset.pop(squeezed_dims) - - # Now we create the new edge between the producer and the new output - # (the new intermediate node). We will remove the old edge further down. 
- new_pre_exit_edge = state.add_edge( - pre_exit_edge.src, - pre_exit_edge.src_conn, - new_inter_node, - None, - new_pre_exit_memlet, - ) - - # We just have handled the last Memlet, but we must actually handle the - # whole producer side, i.e. the scope of the top Map. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): - producer_edge = producer_tree.edge - - # Ensure the correctness of the rerouting below. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - - # Will not change the direction, because of test above! - producer_edge.data.data = new_inter_name - producer_edge.data.replace(memlet_repl) - if is_scalar: - producer_edge.data.dst_subset = "0" - elif producer_edge.data.dst_subset is not None: - producer_edge.data.dst_subset.pop(squeezed_dims) - - # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the newly - # created intermediate into the second map. We do this by finding - # the input connectors on the map entry, such that we know where we - # have to reroute inside the Map. - # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: set[str] = set() - for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: - assert inter_node_out_edge.dst_conn.startswith("IN_") - conn_names.add(inter_node_out_edge.dst_conn) - else: - # If we found another target than the second map entry from the - # intermediate node it means that the node _must_ survive, - # i.e. we are not in exclusive mode. - assert not is_exclusive_set - - # Now we will reroute the connections inside the second map, i.e. - # instead of consuming the old intermediate node, they will now - # consume the new intermediate node. 
- for in_conn_name in conn_names: - out_conn_name = "OUT_" + in_conn_name[3:] - - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - - # The create the first Memlet to transmit information, within - # the second map, we do this again by copying and modifying - # the original Memlet. - # NOTE: Test above is important to ensure the direction of the - # Memlet and the correctness of the code below. - new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.replace(memlet_repl) - new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. - - # Now remove the old edge, that started the second map entry. - # Also add the new edge that started at the new intermediate. - state.remove_edge(inner_edge) - new_inner_edge = state.add_edge( - new_inter_node, - None, - inner_edge.dst, - inner_edge.dst_conn, - new_inner_memlet, - ) - - # Now we do subset modification to ensure that nothing failed. - if is_scalar: - new_inner_memlet.src_subset = "0" - elif new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.pop(squeezed_dims) - - # Now clean the Memlets of that tree to use the new intermediate node. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): - consumer_edge = consumer_tree.edge - assert consumer_edge.data.data == inter_name - consumer_edge.data.data = new_inter_name - consumer_edge.data.replace(memlet_repl) - if is_scalar: - consumer_edge.data.src_subset = "0" - elif consumer_edge.data.subset is not None: - consumer_edge.data.subset.pop(squeezed_dims) - - # The edge that leaves the second map entry was already deleted. - # We will now delete the edges that brought the data. 
- for edge in state.in_edges_by_connector(map_entry_2, in_conn_name): - assert edge.src == inter_node - state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) - - if is_exclusive_set: - # In exclusive mode the old intermediate node is no longer needed. - assert state.degree(inter_node) == 1 - state.remove_edge_and_connectors(out_edge) - state.remove_node(inter_node) - - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) - del sdfg.arrays[inter_name] - - else: - # This is the shared mode, so we have to recreate the intermediate - # node, but this time it is at the exit of the second map. - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - - # This is the Memlet that goes from the map internal intermediate - # temporary node to the Map output. This will essentially restore - # or preserve the output for the intermediate node. It is important - # that we use the data that `preExitEdge` was used. 
- new_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert new_exit_memlet.data == inter_name - new_exit_memlet.subset = pre_exit_edge.data.dst_subset - new_exit_memlet.other_subset = ( - "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) - ) - - new_pre_exit_conn = map_exit_2.next_connector() - state.add_edge( - new_inter_node, - None, - map_exit_2, - "IN_" + new_pre_exit_conn, - new_exit_memlet, - ) - state.add_edge( - map_exit_2, - "OUT_" + new_pre_exit_conn, - inter_node, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), - ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) - - map_exit_1.remove_out_connector(out_edge.src_conn) - state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py new file mode 100644 index 0000000000..6b7bd1b6d5 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -0,0 +1,1010 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +"""The GT4Py specific simplification pass.""" + +import collections +import copy +import uuid +from typing import Any, Final, Iterable, Optional, TypeAlias + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes +from dace.transformation import ( + dataflow as dace_dataflow, + pass_pipeline as dace_ppl, + passes as dace_passes, +) + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} +"""Set of simplify passes `gt_simplify()` skips by default. + +The following passes are included: +- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a + symbol or vice versa and at a later point to invert this again. However, this + pass has some problems with this pattern so for the time being it is disabled. +- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. +""" + + +def gt_simplify( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, + skip: Optional[Iterable[str]] = None, +) -> Optional[dict[str, Any]]: + """Performs simplifications on the SDFG in place. + + Instead of calling `sdfg.simplify()` directly, you should use this function, + as it is specially tuned for GridTool based SDFGs. + + This function runs the DaCe simplification pass, but the following passes are + replaced: + - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. + + Further, the function will run the following passes in addition to DaCe simplify: + - `GT4PyGlobalSelfCopyElimination`: Special copy pattern that in the context + of GT4Py based SDFG behaves as a no op. 
+ + Furthermore, by default, or if `None` is passed for `skip` the passes listed in + `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. + + Args: + sdfg: The SDFG to optimize. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + skip: List of simplify passes that should not be applied, defaults + to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. + + Note: + Currently DaCe does not provide a way to inject or exchange sub passes in + simplify. The custom inline pass is run at the beginning and the array + elimination at the end. The whole process is run inside a loop that ensures + that `gt_simplify()` results in a fix point. + """ + # Ensure that `skip` is a `set` + skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) + + result: Optional[dict[str, Any]] = None + + at_least_one_xtrans_run = True + + while at_least_one_xtrans_run: + at_least_one_xtrans_run = False + + if "InlineSDFGs" not in skip: + inline_res = gt_inline_nested_sdfg( + sdfg=sdfg, + multistate=True, + permissive=False, + validate=validate, + validate_all=validate_all, + ) + if inline_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(inline_res) + + simplify_res = dace_passes.SimplifyPass( + validate=validate, + validate_all=validate_all, + verbose=False, + skip=(skip | {"InlineSDFGs"}), + ).apply_pass(sdfg, {}) + + if simplify_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(simplify_res) + + if "GT4PyGlobalSelfCopyElimination" not in skip: + self_copy_removal_result = sdfg.apply_transformations_repeated( + GT4PyGlobalSelfCopyElimination(), + validate=validate, + validate_all=validate_all, + ) + if self_copy_removal_result > 0: + at_least_one_xtrans_run = True + result = result or {} + result.setdefault("GT4PyGlobalSelfCopyElimination", 0) + result["GT4PyGlobalSelfCopyElimination"] += self_copy_removal_result + + return result + + +def gt_inline_nested_sdfg( + sdfg: 
dace.SDFG, + multistate: bool = True, + permissive: bool = False, + validate: bool = True, + validate_all: bool = False, +) -> Optional[dict[str, int]]: + """Perform inlining of nested SDFG into their parent SDFG. + + The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. + However, before the inline transformation is run the function will run some + cleaning passes that allows inlining nested SDFGs. + As a side effect, the function will split stages into more states. + + Args: + sdfg: The SDFG that should be processed, will be modified in place and returned. + multistate: Allow inlining of multistate nested SDFG, defaults to `True`. + permissive: Be less strict on the accepted SDFGs. + validate: Perform validation after the transformation has finished. + validate_all: Performs extensive validation. + """ + first_iteration = True + nb_preproccess_total = 0 + nb_inlines_total = 0 + while True: + nb_preproccess = sdfg.apply_transformations_repeated( + [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], + validate=False, + validate_all=validate_all, + ) + nb_preproccess_total += nb_preproccess + if (nb_preproccess == 0) and (not first_iteration): + break + + # Create and configure the inline pass + inline_sdfg = dace_passes.InlineSDFGs() + inline_sdfg.progress = False + inline_sdfg.permissive = permissive + inline_sdfg.multistate = multistate + + # Apply the inline pass + # The pass returns `None` no indicate "nothing was done" + nb_inlines = inline_sdfg.apply_pass(sdfg, {}) or 0 + nb_inlines_total += nb_inlines + + # Check result, if needed and test if we can stop + if validate_all or validate: + sdfg.validate() + if nb_inlines == 0: + break + first_iteration = False + + result: dict[str, int] = {} + if nb_inlines_total != 0: + result["InlineSDFGs"] = nb_inlines_total + if nb_preproccess_total != 0: + result["PruneSymbols|PruneConnectors"] = nb_preproccess_total + return result if result else None + + +def 
gt_substitute_compiletime_symbols(
+    sdfg: dace.SDFG,
+    repl: dict[str, Any],
+    validate: bool = False,
+    validate_all: bool = False,
+) -> None:
+    """Substitutes symbols that are known at compile time with their value.
+
+    Some symbols are known to have a constant value. This function will remove these
+    symbols from the SDFG and replace them with the value.
+    An example where this makes sense are strides that are known to be one.
+
+    Args:
+        sdfg: The SDFG to process.
+        repl: Maps the name of the symbol to the value it should be replaced with.
+        validate: Perform validation at the end of the function.
+        validate_all: Perform validation also on intermediate steps.
+    """
+
+    # We will use the `replace` function of the top SDFG, however, lower levels
+    # are handled using ConstantPropagation.
+    sdfg.replace_dict(repl)
+
+    const_prop = dace_passes.ConstantPropagation()
+    const_prop.recursive = True
+    const_prop.progress = False
+
+    const_prop.apply_pass(
+        sdfg=sdfg,
+        initial_symbols=repl,
+        _=None,
+    )
+    gt_simplify(
+        sdfg=sdfg,
+        validate=validate,
+        validate_all=validate_all,
+    )
+    dace.sdfg.propagation.propagate_memlets_sdfg(sdfg)
+
+
+def gt_reduce_distributed_buffering(
+    sdfg: dace.SDFG,
+) -> Optional[dict[dace.SDFG, dict[dace.SDFGState, set[str]]]]:
+    """Removes distributed write back buffers."""
+    pipeline = dace_ppl.Pipeline([DistributedBufferRelocator()])
+    all_result = {}
+
+    for rsdfg in sdfg.all_sdfgs_recursive():
+        ret = pipeline.apply_pass(rsdfg, {})
+        if ret is not None:
+            all_result[rsdfg] = ret
+
+    return all_result
+
+
+@dace_properties.make_properties
+class GT4PyGlobalSelfCopyElimination(dace_transformation.SingleStateTransformation):
+    """Remove global self copy.
+
+    This transformation matches the following case `(G) -> (T) -> (G)`, i.e. `G`
+    is read from and written to at the same time, however, in between is `T`
+    used as a buffer. In the example above `G` is a global memory and `T` is a
+    temporary.
This situation is generated by the lowering if the data node is + not needed (because the computation on it is only conditional). + + In case `G` refers to global memory rule 3 of ADR-18 guarantees that we can + only have a point wise dependency of the output on the input. + This transformation will remove the write into `G`, i.e. we thus only have + `(G) -> (T)`. The read of `G` and the definition of `T`, will only be removed + if `T` is not used downstream. If it is used `T` will be maintained. + """ + + node_read_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + node_tmp = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + node_write_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.node_read_g, cls.node_tmp, cls.node_write_g)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + read_g = self.node_read_g + write_g = self.node_write_g + tmp_node = self.node_tmp + g_desc = read_g.desc(sdfg) + tmp_desc = tmp_node.desc(sdfg) + + # NOTE: We do not check if `G` is read downstream. + if read_g.data != write_g.data: + return False + if g_desc.transient: + return False + if not tmp_desc.transient: + return False + if graph.in_degree(read_g) != 0: + return False + if graph.out_degree(read_g) != 1: + return False + if graph.degree(tmp_node) != 2: + return False + if graph.in_degree(write_g) != 1: + return False + if graph.out_degree(write_g) != 0: + return False + if graph.scope_dict()[read_g] is not None: + return False + + return True + + def _is_read_downstream( + self, + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + ) -> bool: + """Scans for reads to `data_to_look`. 
+ + The function will go through states that are reachable from `start_state` + (including) and test if there is a read to the data container `data_to_look`. + It will return `True` the first time it finds such a node. + It is important that the matched nodes, i.e. `self.node_{read_g, write_g, tmp}` + are ignored. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + + Todo: + Port this function to use DaCe pass pipeline. + """ + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + return gtx_transformations.util.is_accessed_downstream( + start_state=start_state, + sdfg=sdfg, + data_to_look=data_to_look, + nodes_to_ignore={read_g, write_g, tmp_node}, + ) + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + # We first check if `T`, the intermediate is not used downstream. In this + # case we can remove the read to `G` and `T` itself from the SDFG. + # We have to do this check before, because the matching is not fully stable. + is_tmp_used_downstream = self._is_read_downstream( + start_state=graph, sdfg=sdfg, data_to_look=tmp_node.data + ) + + # The write to `G` can always be removed. + graph.remove_node(write_g) + + # Also remove the read to `G` and `T` from the SDFG if possible. + if not is_tmp_used_downstream: + graph.remove_node(read_g) + graph.remove_node(tmp_node) + # It could still be used in a parallel branch. 
+            try:
+                sdfg.remove_data(tmp_node.data, validate=True)
+            except ValueError as e:
+                if not str(e).startswith(f"Cannot remove data descriptor {tmp_node.data}:"):
+                    raise
+
+
+AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode]
+"""Describes an access node and the state in which it is located.
+"""
+
+
+@dace_properties.make_properties
+class DistributedBufferRelocator(dace_transformation.Pass):
+    """Moves the final write back of the results to where it is needed.
+
+    In certain cases, especially in case where we have an `if`, the result is computed
+    in each branch and then in the join state written back. Thus there is some
+    additional storage needed.
+    The transformation will look for the following situation:
+    - A transient data container, called `src_cont`, is written into another
+        container, called `dst_cont`, which is not transient.
+    - The access node of `src_cont` has an in degree of zero and an out degree of one.
+    - The access node of `dst_cont` has an in degree of one and an
+        out degree of zero (this might be lifted).
+    - `src_cont` is not used afterwards.
+    - `dst_cont` is only used to implement the buffering.
+
+    The function will relocate the writing of `dst_cont` to where `src_cont` is
+    written, which might be multiple locations.
+    It will also remove the writing back.
+    It is advised that after this transformation simplify is run again.
+
+    Note:
+        Essentially this transformation removes the double buffering of `dst_cont`.
+        Because we ensure that `dst_cont` is non transient this is okay, as our
+        rule guarantees this.
+
+    Todo:
+        - Allow that `dst_cont` can also be transient.
+        - Allow that `dst_cont` does not need to be a sink node, this is most
+            likely most relevant if it is transient.
+        - Check if `dst_cont` is used between where we want to place it and
+            where it is currently used.
+ """ + + def modifies(self) -> dace_ppl.Modifies: + return dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes + + def should_reapply(self, modified: dace_ppl.Modifies) -> bool: + return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes) + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return { + dace_transformation.passes.StateReachability, + dace_transformation.passes.AccessSets, + } + + def apply_pass( + self, sdfg: dace.SDFG, pipeline_results: dict[str, Any] + ) -> Optional[dict[dace.SDFGState, set[str]]]: + reachable: dict[dace.SDFGState, set[dace.SDFGState]] = pipeline_results[ + "StateReachability" + ][sdfg.cfg_id] + access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]] = pipeline_results[ + "AccessSets" + ][sdfg.cfg_id] + result: dict[dace.SDFGState, set[str]] = collections.defaultdict(set) + + to_relocate = self._find_candidates(sdfg, reachable, access_sets) + if len(to_relocate) == 0: + return None + self._relocate_write_backs(sdfg, to_relocate) + + for (wb_an, wb_state), _ in to_relocate: + result[wb_state].add(wb_an.data) + + return result + + def _relocate_write_backs( + self, + sdfg: dace.SDFG, + to_relocate: list[tuple[AccessLocation, list[AccessLocation]]], + ) -> None: + """Perform the actual relocation.""" + for (wb_an, wb_state), def_locations in to_relocate: + # Get the memlet that we have to replicate. + wb_edge = next(iter(wb_state.out_edges(wb_an))) + wb_memlet: dace.Memlet = wb_edge.data + final_dest_name: str = wb_edge.dst.data + + for def_an, def_state in def_locations: + def_state.add_edge( + def_an, + wb_edge.src_conn, + def_state.add_access(final_dest_name), + wb_edge.dst_conn, + copy.deepcopy(wb_memlet), + ) + + # Now remove the old node and if the old target become isolated + # remove that as well. 
+ old_dst = wb_edge.dst + wb_state.remove_node(wb_an) + if wb_state.degree(old_dst) == 0: + wb_state.remove_node(old_dst) + + def _find_candidates( + self, + sdfg: dace.SDFG, + reachable: dict[dace.SDFGState, set[dace.SDFGState]], + access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]], + ) -> list[tuple[AccessLocation, list[AccessLocation]]]: + """Determines all temporaries that have to be relocated. + + Returns: + A list of tuples. The first element element of the tuple is an + `AccessLocation` that describes where the temporary is read. + The second element is a list of `AccessLocation`s that describes + where the temporary is defined. + """ + # All nodes that are used as distributed buffers. + candidate_src_cont: list[AccessLocation] = [] + + # Which `src_cont` access node is written back to which global memory. + src_cont_to_global: dict[dace_nodes.AccessNode, str] = {} + + for state in sdfg.states(): + # These are the possible targets we want to write into. + candidate_dst_nodes: set[dace_nodes.AccessNode] = { + node + for node in state.sink_nodes() + if ( + isinstance(node, dace_nodes.AccessNode) + and state.in_degree(node) == 1 + and (not node.desc(sdfg).transient) + ) + } + if len(candidate_dst_nodes) == 0: + continue + + for src_cont in state.source_nodes(): + if not isinstance(src_cont, dace_nodes.AccessNode): + continue + if not src_cont.desc(sdfg).transient: + continue + if state.out_degree(src_cont) != 1: + continue + dst_candidate: dace_nodes.AccessNode = next( + iter(edge.dst for edge in state.out_edges(src_cont)) + ) + if dst_candidate not in candidate_dst_nodes: + continue + candidate_src_cont.append((src_cont, state)) + src_cont_to_global[src_cont] = dst_candidate.data + + if len(candidate_src_cont) == 0: + return [] + + # Now we have to find the places where the temporary sources are defined. + # I.e. This is also the location where the original value is defined. 
+ result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: + return { + src_state + for src_state in sdfg.states() + if dst_state in reachable[src_state] and dst_state is not src_state + } + + for src_cont in candidate_src_cont: + def_locations: list[AccessLocation] = [] + for upstream_state in find_upstream_states(src_cont[1]): + if src_cont[0].data in access_sets[upstream_state][1]: + def_locations.extend( + (data_node, upstream_state) + for data_node in upstream_state.data_nodes() + if data_node.data == src_cont[0].data + ) + if len(def_locations) != 0: + result_candidates.append((src_cont, def_locations)) + + # This transformation removes `src_cont` by writing its content directly + # to `dst_cont`, at the point where it is defined. + # For this transformation to be valid the following conditions have to be met: + # - Between the definition of `src_cont` and the write back to `dst_cont`, + # `dst_cont` can not be accessed. + # - Between the definitions of `src_cont` and the point where it is written + # back, `src_cont` can only be accessed in the range that is written back. + # - After the write back point, `src_cont` shall not be accessed. This + # restriction could be lifted. + # + # To keep the implementation simple, we use the conditions: + # - `src_cont` is only accessed were it is defined and at the write back + # point. + # - Between the definitions of `src_cont` and the write back point, + # `dst_cont` is not used. + + result: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + for wb_localation, def_locations in result_candidates: + for def_node, def_state in def_locations: + # Test if `src_cont` is only accessed where it is defined and + # where it is written back. 
+ if gtx_transformations.util.is_accessed_downstream( + start_state=def_state, + sdfg=sdfg, + data_to_look=wb_localation[0].data, + nodes_to_ignore={def_node, wb_localation[0]}, + ): + break + # check if the global data is not used between the definition of + # `dst_cont` and where its written back. We allow one exception, + # if the global data is used in the state the distributed temporary + # is defined is used only for reading then it is ignored. This is + # allowed because of rule 3 of ADR0018. + glob_nodes_in_def_state = { + dnode + for dnode in def_state.data_nodes() + if dnode.data == src_cont_to_global[wb_localation[0]] + } + if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state): + break + if gtx_transformations.util.is_accessed_downstream( + start_state=def_state, + sdfg=sdfg, + data_to_look=src_cont_to_global[wb_localation[0]], + nodes_to_ignore=glob_nodes_in_def_state, + states_to_ignore={wb_localation[1]}, + ): + break + else: + result.append((wb_localation, def_locations)) + + return result + + +@dace_properties.make_properties +class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation): + """Moves a Tasklet, with no input into a map. + + Tasklets without inputs, are mostly used to generate constants. + However, if they are outside a Map, then this constant value is an + argument to the kernel, and can not be used by the compiler. + + This transformation moves such Tasklets into a Map scope. 
+ """ + + tasklet = dace_transformation.PatternNode(dace_nodes.Tasklet) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.tasklet, cls.access_node, cls.map_entry)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Data = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + if graph.in_degree(tasklet) != 0: + return False + if graph.out_degree(tasklet) != 1: + return False + if tasklet.has_side_effects(sdfg): + return False + if tasklet.code_init.as_string: + return False + if tasklet.code_exit.as_string: + return False + if tasklet.code_global.as_string: + return False + if tasklet.state_fields: + return False + if not isinstance(access_desc, dace_data.Scalar): + return False + if not access_desc.transient: + return False + if not any( + edge.dst_conn and edge.dst_conn.startswith("IN_") + for edge in graph.out_edges(access_node) + if edge.dst is map_entry + ): + return False + # NOTE: We allow that the access node is used in multiple places. + + return True + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Scalar = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + # Find _a_ connection that leads from the access node to the map. 
+ edge_to_map = next( + iter( + edge + for edge in graph.out_edges(access_node) + if edge.dst is map_entry and edge.dst_conn.startswith("IN_") + ) + ) + connector_name: str = edge_to_map.dst_conn[3:] + + # This is the tasklet that we will put inside the map, note we have to do it + # this way to avoid some name clash stuff. + inner_tasklet: dace_nodes.Tasklet = graph.add_tasklet( + name=f"{tasklet.label}__clone_{str(uuid.uuid1()).replace('-', '_')}", + outputs=tasklet.out_connectors.keys(), + inputs=set(), + code=tasklet.code, + language=tasklet.language, + debuginfo=tasklet.debuginfo, + ) + inner_desc: dace_data.Scalar = access_desc.clone() + inner_data_name: str = sdfg.add_datadesc(access_node.data, inner_desc, find_new_name=True) + inner_an: dace_nodes.AccessNode = graph.add_access(inner_data_name) + + # Connect the tasklet with the map entry and the access node. + graph.add_nedge(map_entry, inner_tasklet, dace.Memlet()) + graph.add_edge( + inner_tasklet, + next(iter(inner_tasklet.out_connectors.keys())), + inner_an, + None, + dace.Memlet(f"{inner_data_name}[0]"), + ) + + # Now we will reroute the edges went through the inner map, through the + # inner access node instead. + for old_inner_edge in list( + graph.out_edges_by_connector(map_entry, "OUT_" + connector_name) + ): + # We now modify the downstream data. This is because we no longer refer + # to the data outside but the one inside. + self._modify_downstream_memlets( + state=graph, + edge=old_inner_edge, + old_data=access_node.data, + new_data=inner_data_name, + ) + + # After we have changed the properties of the MemletTree of `edge` + # we will now reroute it, such that the inner access node is used. 
+            graph.add_edge(
+                inner_an,
+                None,
+                old_inner_edge.dst,
+                old_inner_edge.dst_conn,
+                old_inner_edge.data,
+            )
+            graph.remove_edge(old_inner_edge)
+            map_entry.remove_in_connector("IN_" + connector_name)
+            map_entry.remove_out_connector("OUT_" + connector_name)
+
+        # Now we can remove the map connection between the outer/old access
+        # node and the map.
+        graph.remove_edge(edge_to_map)
+
+        # The data is no longer referenced in this state, so we can potentially
+        # remove it.
+        if graph.out_degree(access_node) == 0:
+            if not gtx_transformations.util.is_accessed_downstream(
+                start_state=graph,
+                sdfg=sdfg,
+                data_to_look=access_node.data,
+                nodes_to_ignore={access_node},
+            ):
+                graph.remove_nodes_from([tasklet, access_node])
+                # Needed if data is accessed in a parallel branch.
+                try:
+                    sdfg.remove_data(access_node.data, validate=True)
+                except ValueError as e:
+                    if not str(e).startswith(f"Cannot remove data descriptor {access_node.data}:"):
+                        raise
+
+    def _modify_downstream_memlets(
+        self,
+        state: dace.SDFGState,
+        edge: dace.sdfg.graph.MultiConnectorEdge,
+        old_data: str,
+        new_data: str,
+    ) -> None:
+        """Replaces the data along the tree defined by `edge`.
+
+        The function will traverse the MemletTree defined by `edge`.
+        Any Memlet that refers to `old_data` will be replaced with
+        `new_data`.
+
+        Args:
+            state: The state in which we operate.
+            edge: The edge defining the MemletTree.
+            old_data: The name of the data that should be replaced.
+            new_data: The name of the new data the Memlet should refer to.
+        """
+        mtree: dace.memlet.MemletTree = state.memlet_tree(edge)
+        for tedge in mtree.traverse_children(True):
+            # Because we only change the name of the data, we do not change the
+            # direction of the Memlet, so `{src, dst}_subset` will remain the same.
+ if tedge.edge.data.data == old_data: + tedge.edge.data.data = new_data + + +@dace_properties.make_properties +class GT4PyMapBufferElimination(dace_transformation.SingleStateTransformation): + """Allows to remove unneeded buffering at map output. + + The transformation matches the case `MapExit -> (T) -> (G)`, where `T` is an + AccessNode referring to a transient and `G` an AccessNode that refers to non + transient memory. + If the following conditions are met then `T` is removed. + - `T` is not used to filter computations, i.e. what is written into `G` + is covered by what is written into `T`. + - `T` is not used anywhere else. + - `G` is not also an input to the map, except there is only a pointwise + dependency in `G`, see the note below. + - Everything needs to be at top scope. + + Notes: + - Rule 3 of ADR18 should guarantee that any valid GT4Py program meets the + point wise dependency in `G`, for that reason it is possible to disable + this test by specifying `assume_pointwise`. + + Todo: + - Implement a real pointwise test. 
+ """ + + map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + tmp_ac = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + glob_ac = dace_transformation.PatternNode(dace_nodes.AccessNode) + + assume_pointwise = dace_properties.Property( + dtype=bool, + default=False, + desc="Dimensions that should become the leading dimension.", + ) + + def __init__( + self, + assume_pointwise: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if assume_pointwise is not None: + self.assume_pointwise = assume_pointwise + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_exit, cls.tmp_ac, cls.glob_ac)] + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return {dace_transformation.passes.ConsolidateEdges} + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tmp_ac: dace_nodes.AccessNode = self.tmp_ac + glob_ac: dace_nodes.AccessNode = self.glob_ac + tmp_desc: dace_data.Data = tmp_ac.desc(sdfg) + glob_desc: dace_data.Data = glob_ac.desc(sdfg) + + if not tmp_desc.transient: + return False + if glob_desc.transient: + return False + if graph.in_degree(tmp_ac) != 1: + return False + if any(gtx_transformations.util.is_view(ac, sdfg) for ac in [tmp_ac, glob_ac]): + return False + if len(glob_desc.shape) != len(tmp_desc.shape): + return False + + # Test if we are on the top scope (it is likely). + if graph.scope_dict()[glob_ac] is not None: + return False + + # Now perform if we are point wise + if not self._perform_pointwise_test(graph, sdfg): + return False + + # Test if `tmp` is only anywhere else, this is important for removing it. 
+ if graph.out_degree(tmp_ac) != 1: + return False + if gtx_transformations.util.is_accessed_downstream( + start_state=graph, + sdfg=sdfg, + data_to_look=tmp_ac.data, + nodes_to_ignore={tmp_ac}, + ): + return False + + # Now we ensure that `tmp` is not used to filter out some computations. + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + + tmp_in_subset = map_to_tmp_edge.data.get_dst_subset(map_to_tmp_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + if tmp_in_subset is None: + tmp_in_subset = dace_subsets.Range.from_array(tmp_desc) + if tmp_out_subset is None: + tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) + if glob_in_subset is None: + return False + + # TODO(phimuell): Do we need simplify in the check. + # TODO(phimuell): Restrict this to having the same size. + if tmp_out_subset != tmp_in_subset: + return False + return True + + def _perform_pointwise_test( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Test if `G` is only point wise accessed. + + This function will also consider the `assume_pointwise` property. + """ + map_exit: dace_nodes.MapExit = self.map_exit + map_entry: dace_nodes.MapEntry = state.entry_node(map_exit) + glob_ac: dace_nodes.AccessNode = self.glob_ac + glob_data: str = glob_ac.data + + # First we check if `G` is also an input to this map. + conflicting_inputs: set[dace_nodes.AccessNode] = set() + for in_edge in state.in_edges(map_entry): + if not isinstance(in_edge.src, dace_nodes.AccessNode): + continue + + # Find the source of this data, if it is a view we trace it to + # its origin. + src_node: dace_nodes.AccessNode = gtx_transformations.util.track_view( + in_edge.src, state, sdfg + ) + + # Test if there is a conflict; We do not store the source but the + # actual node that is adjacent. 
+            if src_node.data == glob_data:
+                conflicting_inputs.add(in_edge.src)
+
+        # If there are no conflicting inputs, then we are point wise.
+        # This is an implementation detail that makes life simpler.
+        if len(conflicting_inputs) == 0:
+            return True
+
+        # If we can assume pointwise computations, then we do not have to do
+        # anything.
+        if self.assume_pointwise:
+            return True
+
+        # Currently the only test that we do is, if we have a view, then we
+        # are not point wise.
+        # TODO(phimuell): Improve/implement this.
+        return any(gtx_transformations.util.is_view(node, sdfg) for node in conflicting_inputs)
+
+    def apply(
+        self,
+        graph: dace.SDFGState | dace.SDFG,
+        sdfg: dace.SDFG,
+    ) -> None:
+        # Remove the intermediate buffer `tmp` and
+        # propagate the resulting shift into the map.
+        map_exit: dace_nodes.MapExit = self.map_exit
+        tmp_ac: dace_nodes.AccessNode = self.tmp_ac
+        tmp_desc: dace_data.Data = tmp_ac.desc(sdfg)
+        tmp_data = tmp_ac.data
+        glob_ac: dace_nodes.AccessNode = self.glob_ac
+        glob_data = glob_ac.data
+
+        map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac))
+        tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac))
+
+        glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph)
+        tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph)
+        if tmp_out_subset is None:
+            tmp_out_subset = dace_subsets.Range.from_array(tmp_desc)
+        assert glob_in_subset is not None
+
+        # We now remove the `tmp` node, and create a new connection between
+        # the global node and the map exit.
+        new_map_to_glob_edge = graph.add_edge(
+            map_exit,
+            map_to_tmp_edge.src_conn,
+            glob_ac,
+            tmp_to_glob_edge.dst_conn,
+            dace.Memlet(
+                data=glob_ac.data,
+                subset=copy.deepcopy(glob_in_subset),
+            ),
+        )
+        graph.remove_edge(map_to_tmp_edge)
+        graph.remove_edge(tmp_to_glob_edge)
+        graph.remove_node(tmp_ac)
+
+        # We can not unconditionally remove the data `tmp` refers to, because
+        # it could be that in a parallel branch the `tmp` is also defined.
+ try: + sdfg.remove_data(tmp_ac.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_ac.data}:"): + raise + + # Now we must modify the memlets inside the map scope, because + # they now write into `G` instead of `tmp`, which has a different + # offset. + # NOTE: Assumes that `tmp_out_subset` and `tmp_in_subset` are the same. + correcting_offset = glob_in_subset.offset_new(tmp_out_subset, negative=True) + mtree = graph.memlet_tree(new_map_to_glob_edge) + for tree in mtree.traverse_children(include_self=False): + curr_edge = tree.edge + curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, graph) + if curr_edge.data.data == tmp_data: + curr_edge.data.data = glob_data + if curr_dst_subset is not None: + curr_dst_subset.offset(correcting_offset, negative=False) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py new file mode 100644 index 0000000000..4e254f2880 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -0,0 +1,99 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dace +from dace import data as dace_data + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def gt_change_transient_strides( + sdfg: dace.SDFG, + gpu: bool, +) -> dace.SDFG: + """Modifies the strides of transients. + + The function will analyse the access patterns and set the strides of + transients in the optimal way. + The function should run after all maps have been created. + + Args: + sdfg: The SDFG to process. + gpu: If the SDFG is supposed to run on the GPU. 
+ + Note: + Currently the function will not scan the access pattern. Instead it will + either use FORTRAN order for GPU or C order (which is assumed to be the + default, so it is a no-op). + + Todo: + - Implement the estimation correctly. + - Handle the case of nested SDFGs correctly; on the outside a transient, + but on the inside a non transient. + """ + # TODO(phimuell): Implement this function correctly. + + # We assume that by default we have C order which is already correct, + # so in this case we have a no-op + if not gpu: + return sdfg + + for nsdfg in sdfg.all_sdfgs_recursive(): + # TODO(phimuell): Handle the case when transient goes into nested SDFG + # on the inside it is a non transient, so it is ignored. + _gt_change_transient_strides_non_recursive_impl(nsdfg) + + +def _gt_change_transient_strides_non_recursive_impl( + sdfg: dace.SDFG, +) -> None: + """Essentially this function just changes the stride to FORTRAN order.""" + for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + desc: dace_data.Array = sdfg.arrays[top_level_transient] + ndim = len(desc.shape) + if ndim <= 1: + continue + # We assume that everything is in C order initially, to get FORTRAN order + # we simply have to reverse the order. + new_stride_order = list(range(ndim)) + desc.set_strides_from_layout(*new_stride_order) + + +def _find_toplevel_transients( + sdfg: dace.SDFG, + only_arrays: bool = False, +) -> set[str]: + """Find all top level transients in the SDFG. + + The function will scan the SDFG, ignoring nested ones, and return the + name of all transients that have an access node at the top level. + However, it will ignore access nodes that refer to registers. 
+ """ + top_level_transients: set[str] = set() + for state in sdfg.states(): + scope_dict = state.scope_dict() + for dnode in state.data_nodes(): + data: str = dnode.data + if scope_dict[dnode] is not None: + if data in top_level_transients: + top_level_transients.remove(data) + continue + elif data in top_level_transients: + continue + elif gtx_transformations.util.is_view(dnode, sdfg): + continue + desc: dace_data.Data = dnode.desc(sdfg) + + if not desc.transient: + continue + elif only_arrays and not isinstance(desc, dace_data.Array): + continue + top_level_transients.add(data) + return top_level_transients diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index 29bae7bbe0..29c099eecf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -8,153 +8,220 @@ """Common functionality for the transformations/optimization pipeline.""" -from typing import Iterable, Union +from typing import Any, Container, Optional, Union import dace -from dace.sdfg import graph as dace_graph, nodes as dace_nodes +from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes +from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -def is_nested_sdfg( - sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], -) -> bool: - """Tests if `sdfg` is a NestedSDFG.""" - if isinstance(sdfg, dace.SDFGState): - sdfg = sdfg.parent - if isinstance(sdfg, dace_nodes.NestedSDFG): - return True - elif isinstance(sdfg, dace.SDFG): - return sdfg.parent_nsdfg_node is not None - raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") - - -def all_nodes_between( - graph: dace.SDFG | dace.SDFGState, - begin: dace_nodes.Node, - end: dace_nodes.Node, - reverse: bool = False, -) -> 
set[dace_nodes.Node] | None: - """Find all nodes that are reachable from `begin` but bound by `end`. - - Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end`, this edge is ignored. It will thus found any node that is reachable - from `begin` by a path that does not involve `end`. The returned set will - never contain `end` nor `begin`. In case `end` is never found the function - will return `None`. - - If `reverse` is set to `True` the function will start exploring at `end` and - follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + +def gt_make_transients_persistent( + sdfg: dace.SDFG, + device: dace.DeviceType, +) -> dict[int, set[str]]: + """ + Changes the lifetime of certain transients to `Persistent`. + + A persistent lifetime means that the transient is allocated only the very first + time the SDFG is executed and only deallocated if the underlying `CompiledSDFG` + object goes out of scope. The main advantage is, that memory must not be + allocated every time the SDFG is run. The downside is that the SDFG can not be + called by different threads. Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The terminator node of the DFS. - reverse: Perform a backward DFS. - - Notes: - - The returned set will also contain the nodes of path that starts at - `begin` and ends at a node that is not `end`. + sdfg: The SDFG to process. + device: The device type. + + Returns: + A `dict` mapping SDFG IDs to a set of transient arrays that + were made persistent. + + Note: + This function is based on a similar function in DaCe. However, the DaCe + function does, for unknown reasons, also reset the `wcr_nonatomic` property, + but only for GPU. 
""" + result: dict[int, set[str]] = {} + for nsdfg in sdfg.all_sdfgs_recursive(): + fsyms: set[str] = nsdfg.free_symbols + modify_lifetime: set[str] = set() + not_modify_lifetime: set[str] = set() + + for state in nsdfg.states(): + for dnode in state.data_nodes(): + if dnode.data in not_modify_lifetime: + continue - def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: - return ( - (edge.src for edge in graph.in_edges(node)) - if reverse - else (edge.dst for edge in graph.out_edges(node)) - ) + if dnode.data in nsdfg.constants_prop: + not_modify_lifetime.add(dnode.data) + continue - if reverse: - begin, end = end, begin + desc = dnode.desc(nsdfg) + if not desc.transient or type(desc) not in {dace.data.Array, dace.data.Scalar}: + not_modify_lifetime.add(dnode.data) + continue + if desc.storage == dace.StorageType.Register: + not_modify_lifetime.add(dnode.data) + continue - to_visit: list[dace_nodes.Node] = [begin] - seen: set[dace_nodes.Node] = set() + if desc.lifetime == dace.AllocationLifetime.External: + not_modify_lifetime.add(dnode.data) + continue - while len(to_visit) > 0: - node: dace_nodes.Node = to_visit.pop() - if node != end and node not in seen: - to_visit.extend(next_nodes(node)) - seen.add(node) + try: + # The symbols describing the total size must be a subset of the + # free symbols of the SDFG (symbols passed as argument). + # NOTE: This ignores the renaming of symbols through the + # `symbol_mapping` property of nested SDFGs. + if not set(map(str, desc.total_size.free_symbols)).issubset(fsyms): + not_modify_lifetime.add(dnode.data) + continue + except AttributeError: # total_size is an integer / has no free symbols + pass - # If `end` was not found we have to return `None` to indicate this. - if end not in seen: - return None + # Make it persistent. + modify_lifetime.add(dnode.data) - # `begin` and `end` are not included in the output set. - return seen - {begin, end} + # Now setting the lifetime. 
+ result[nsdfg.cfg_id] = modify_lifetime - not_modify_lifetime + for aname in result[nsdfg.cfg_id]: + nsdfg.arrays[aname].lifetime = dace.AllocationLifetime.Persistent + return result -def find_downstream_consumers( - state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, - reverse: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Find all downstream connectors of `begin`. - - A consumer, in for this function, is any node that is neither an entry nor - an exit node. The function returns a set of pairs, the first element is the - node that acts as consumer and the second is the edge that leads to it. - By setting `only_tasklets` the nodes the function finds are only Tasklets. - - To find this set the function starts a search at `begin`, however, it is also - possible to pass an edge as `begin`. - If `reverse` is `True` the function essentially finds the producers that are - upstream. + +def gt_find_constant_arguments( + call_args: dict[str, Any], + include: Optional[Container[str]] = None, +) -> dict[str, Any]: + """Scans the calling arguments for compile time constants. + + The output of this function can be used as input to + `gt_substitute_compiletime_symbols()`, which then removes these symbols. + + By specifying `include` it is possible to force the function to include + additional arguments, that would not be matched otherwise. Importantly, + their value is not checked. Args: - state: The state in which to look for the consumers. - begin: The initial node that from which the search starts. - only_tasklets: Return only Tasklets. - reverse: Follow the reverse direction. + call_args: The full list of arguments that will be passed to the SDFG. + include: List of arguments that should be included. 
""" - if isinstance(begin, dace_graph.MultiConnectorEdge): - to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] - else: - to_visit = state.in_edges(begin) if reverse else state.out_edges(begin) + if include is None: + include = set() + ret_value: dict[str, Any] = {} - seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + for name, value in call_args.items(): + if name in include or (dace_utils.is_field_symbol(name) and value == 1): + ret_value[name] = value - while len(to_visit) > 0: - curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst - - if curr_edge in seen: - continue - seen.add(curr_edge) - - if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): - if not reverse: - # In forward mode a Map entry could also mean the definition of a - # dynamic map range. - if isinstance(next_node, dace_nodes.MapEntry) and ( - not curr_edge.dst_conn.startswith("IN_") - ): - if not only_tasklets: - found.add((next_node, curr_edge)) - continue - target_conn = curr_edge.dst_conn[3:] - new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) - else: - target_conn = curr_edge.src_conn[4:] - new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) - to_visit.extend(new_edges) + return ret_value - elif isinstance(next_node, dace_nodes.Tasklet) or not only_tasklets: - # We have found a consumer. - found.add((next_node, curr_edge)) - return found +def is_accessed_downstream( + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + nodes_to_ignore: Optional[set[dace_nodes.AccessNode]] = None, + states_to_ignore: Optional[set[dace.SDFGState]] = None, +) -> bool: + """Scans for accesses to the data container `data_to_look`. 
+ The function will go through states that are reachable from `start_state` + (included) and test if there is an AccessNode that refers to `data_to_look`. + It will return `True` the first time it finds such a node. -def find_upstream_producers( + The function will ignore all nodes that are listed in `nodes_to_ignore`. + Furthermore, states listed in `states_to_ignore` will be ignored, i.e. + handled as they did not exist. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + nodes_to_ignore: Ignore these nodes. + states_to_ignore: Ignore these states. + """ + seen_states: set[dace.SDFGState] = set() + to_visit: list[dace.SDFGState] = [start_state] + ign_dnodes: set[dace_nodes.AccessNode] = nodes_to_ignore or set() + ign_states: set[dace.SDFGState] = states_to_ignore or set() + + while len(to_visit) > 0: + state = to_visit.pop() + seen_states.add(state) + for dnode in state.data_nodes(): + if dnode.data != data_to_look: + continue + if dnode in ign_dnodes: + continue + if state.out_degree(dnode) != 0: + return True # There is a read operation + + # Look for new states, also scan the interstate edges. 
+ for out_edge in sdfg.out_edges(state): + if out_edge.dst in ign_states: + continue + if data_to_look in out_edge.data.read_symbols(): + return True + if out_edge.dst in seen_states: + continue + to_visit.append(out_edge.dst) + + return False + + +def is_view( + node: Union[dace_nodes.AccessNode, dace_data.Data], + sdfg: dace.SDFG, +) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: dace_data.Data = node.desc(sdfg) if isinstance(node, dace_nodes.AccessNode) else node + return isinstance(node_desc, dace_data.View) + + +def track_view( + view: dace_nodes.AccessNode, state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" - return find_downstream_consumers( - state=state, - begin=begin, - only_tasklets=only_tasklets, - reverse=True, - ) + sdfg: dace.SDFG, +) -> dace_nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ + + # Test if it is a view at all, if not return the passed node as source. + if not is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src # noqa: E731 + elif curr_edge.src_conn == "views": + # The view is used for writing. 
+ next_node = lambda curr_edge: curr_edge.dst # noqa: E731 + else: + raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") + + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 80b8f4f39b..d7413f32d7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -169,14 +169,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST - + [ - (ALL, SKIP, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): Enable once the optimization pipeline is merged - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST - + [ - (ALL, SKIP, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): Enable once the optimization pipeline is merged. 
+ OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py index e85ef6ad1f..0eb0bf39c2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py @@ -11,7 +11,7 @@ import pytest -@pytest.fixture() +@pytest.fixture(autouse=True) def set_dace_settings() -> Generator[None, None, None]: """Sets the common DaCe settings for the tests. @@ -24,6 +24,6 @@ def set_dace_settings() -> Generator[None, None, None]: import dace with dace.config.temporary_config(): - dace.Config.set("optimizer", "match_exception", value=False) + dace.Config.set("optimizer", "match_exception", value=True) dace.Config.set("compiler", "allow_view_arguments", value=True) yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py new file mode 100644 index 0000000000..04a4f098ef --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py @@ -0,0 +1,142 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + + +def test_constant_substitution(): + sdfg, nsdfg = _make_sdfg() + + # Ensure that `One` is present. + assert len(sdfg.symbols) == 2 + assert len(nsdfg.sdfg.symbols) == 2 + assert len(nsdfg.symbol_mapping) == 2 + assert "One" in sdfg.symbols + assert "One" in nsdfg.sdfg.symbols + assert "One" in nsdfg.symbol_mapping + assert "One" == str(nsdfg.symbol_mapping["One"]) + assert all(str(desc.strides[1]) == "One" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[1]) == "One" for desc in nsdfg.sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" in sdfg.used_symbols(True) + + # Now replace `One` with 1 + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, {"One": 1}) + + assert len(sdfg.symbols) == 1 + assert len(nsdfg.sdfg.symbols) == 1 + assert len(nsdfg.symbol_mapping) == 1 + assert "One" not in sdfg.symbols + assert "One" not in nsdfg.sdfg.symbols + assert "One" not in nsdfg.symbol_mapping + assert all(desc.strides[1] == 1 and len(desc.strides) == 2 for desc in sdfg.arrays.values()) + assert all( + desc.strides[1] == 1 and len(desc.strides) == 2 for desc in nsdfg.sdfg.arrays.values() + ) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" not in sdfg.used_symbols(True) + + +def _make_nested_sdfg() -> dace.SDFG: + sdfg = dace.SDFG("nested") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABC": + sdfg.add_array( + name=name, + dtype=dace.float64, 
+ shape=(N, N), + strides=(N, One), + transient=False, + ) + state = sdfg.add_state(is_start_block=True) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("B[__i0, __i1]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_sdfg() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: + sdfg = dace.SDFG("outer_sdfg") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABCD": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, One), + transient=False, + ) + sdfg.arrays["C"].transient = True + + first_state: dace.SDFGState = sdfg.add_state(is_start_block=True) + nested_sdfg: dace.SDFG = _make_nested_sdfg() + nsdfg = first_state.add_nested_sdfg( + nested_sdfg, + parent=sdfg, + inputs={"A", "B"}, + outputs={"C"}, + symbol_mapping={"One": "One", "N": "N"}, + ) + first_state.add_edge( + first_state.add_access("A"), + None, + nsdfg, + "A", + dace.Memlet("A[0:N, 0:N]"), + ) + first_state.add_edge( + first_state.add_access("B"), + None, + nsdfg, + "B", + dace.Memlet("B[0:N, 0:N]"), + ) + first_state.add_edge( + nsdfg, + "C", + first_state.add_access("C"), + None, + dace.Memlet("C[0:N, 0:N]"), + ) + + second_state: dace.SDFGState = sdfg.add_state_after(first_state) + second_state.add_mapped_tasklet( + "outer_computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("C[__i0, __i1]"), + }, + code="__out = __in0 * __in1", + outputs={"__out": dace.Memlet("D[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg, nsdfg diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py new file mode 100644 index 0000000000..3d9201c603 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py @@ -0,0 +1,239 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace import data as dace_data + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _create_sdfg_double_read_part_1( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_1", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 1.0" + ) + + state.add_edge(A_in, None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("A[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("A"), None, dace.Memlet("A[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read_part_2( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_2", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 3.0" + ) + + state.add_edge(A_in, 
None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("B[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("B"), None, dace.Memlet("B[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read( + version: int, +) -> tuple[dace.SDFG]: + sdfg = dace.SDFG(util.unique_name(f"double_read_version_{version}")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in = state.add_access("A") + me, mx = state.add_map("map", ndrange={"__i0": "0:10"}) + + if version == 0: + _create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 0) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 1) + elif version == 1: + _create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 1) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 0) + else: + raise ValueError(f"Does not know version {version}") + sdfg.validate() + return sdfg + + +def test_local_double_buffering_double_read_sdfg(): + sdfg0 = _create_sdfg_double_read(0) + sdfg1 = _create_sdfg_double_read(1) + args0 = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + args1 = copy.deepcopy(args0) + + count0 = gtx_transformations.gt_create_local_double_buffering(sdfg0) + assert count0 == 1 + + count1 = gtx_transformations.gt_create_local_double_buffering(sdfg1) + assert count1 == 1 + + sdfg0(**args0) + sdfg1(**args1) + for name in args0: + assert np.allclose(args0[name], args1[name]), f"Failed verification in '{name}'." 
+ + +def test_local_double_buffering_no_connection(): + """There is no direct connection between read and write.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_connection")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in, B, A_out = (state.add_access(name) for name in "ABA") + + comp_tskl, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + input_nodes={A_in}, + output_nodes={B}, + external_edges=True, + ) + + fill_tasklet = state.add_tasklet( + name="fill_tasklet", + inputs=set(), + code="__out = 2.", + outputs={"__out"}, + ) + state.add_nedge(me, fill_tasklet, dace.Memlet()) + state.add_edge(fill_tasklet, "__out", mx, "IN_1", dace.Memlet("A[__i0]")) + state.add_edge(mx, "OUT_1", A_out, None, dace.Memlet("A[0:10]")) + mx.add_in_connector("IN_1") + mx.add_out_connector("OUT_1") + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 1 + + # Ensure that a second application of the transformation does not run again. + count_again = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count_again == 0 + + # Find the newly created access node. + comp_tasklet_producers = [in_edge.src for in_edge in state.in_edges(comp_tskl)] + assert len(comp_tasklet_producers) == 1 + new_double_buffer = comp_tasklet_producers[0] + assert isinstance(new_double_buffer, dace_nodes.AccessNode) + assert not any(new_double_buffer.data == name for name in "AB") + assert isinstance(new_double_buffer.desc(sdfg), dace_data.Scalar) + assert new_double_buffer.desc(sdfg).transient + + # The newly created access node, must have an empty Memlet to the fill tasklet. 
+ read_dependencies = [ + out_edge.dst for out_edge in state.out_edges(new_double_buffer) if out_edge.data.is_empty() + ] + assert len(read_dependencies) == 1 + assert read_dependencies[0] is fill_tasklet + + res = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + ref = {"A": np.full_like(res["A"], 2.0), "B": res["A"] + 10.0} + sdfg(**res) + for name in res: + assert np.allclose(res[name], ref[name]), f"Failed verification in '{name}'." + + +def test_local_double_buffering_no_apply(): + """Here it does not apply, because are all distinct.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + external_edges=True, + ) + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 0 + + +def test_local_double_buffering_already_buffered(): + """It is already buffered.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + sdfg.add_array( + "A", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + tsklt, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("A[__i0]")}, + external_edges=True, + ) + + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + me_to_tskl_edge = next(iter(state.out_edges(me))) + + state.add_edge(me, me_to_tskl_edge.src_conn, tmp, None, dace.Memlet("A[__i0]")) + state.add_edge(tmp, None, tsklt, "__in1", dace.Memlet("tmp[0]")) + 
state.remove_edge(me_to_tskl_edge)
+    sdfg.validate()
+
+    count = gtx_transformations.gt_create_local_double_buffering(sdfg)
+    assert count == 0
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
new file mode 100644
index 0000000000..1543a048ad
--- /dev/null
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
@@ -0,0 +1,84 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import pytest
+import numpy as np
+
+from gt4py.next.program_processors.runners.dace_fieldview import (
+    transformations as gtx_transformations,
+)
+
+from . import util
+
+
+dace = pytest.importorskip("dace")
+from dace.sdfg import nodes as dace_nodes
+import dace
+
+
+def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]:
+    sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg"))
+
+    for name in ["a", "b", "tmp"]:
+        sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False)
+    sdfg.arrays["tmp"].transient = True
+    sdfg.arrays["b"].shape = (100, 100)
+
+    state1: dace.SDFGState = sdfg.add_state(is_start_block=True)
+    state1.add_mapped_tasklet(
+        "computation",
+        map_ranges={"__i1": "0:10", "__i2": "0:10"},
+        inputs={"__in": dace.Memlet("a[__i1, __i2]")},
+        code="__out = __in + 10.0",
+        outputs={"__out": dace.Memlet("tmp[__i1, __i2]")},
+        external_edges=True,
+    )
+
+    state2 = sdfg.add_state_after(state1)
+    state2_tskl = state2.add_tasklet(
+        name="empty_blocker_tasklet",
+        inputs={},
+        code="pass",
+        outputs={"__out"},
+        side_effects=True,
+    )
+    state2.add_edge(
+        state2_tskl,
+        "__out",
+        state2.add_access("a"),
+        None,
+        dace.Memlet("a[0, 0]"),
+    )
+
+    state3 = sdfg.add_state_after(state2)
+    state3.add_edge(
+        state3.add_access("tmp"),
+        None,
+        state3.add_access("b"),
+        None,
+        dace.Memlet("tmp[0:10, 0:10] -> [11:21, 22:32]"),
+    )
+    sdfg.validate()
+    assert sdfg.number_of_nodes() == 3
+
+    return sdfg, state1
+
+
+def test_distributed_buffer_remover():
+    sdfg, state1 = _mk_distributed_buffer_sdfg()
+    assert state1.number_of_nodes() == 5
+    assert not any(dnode.data == "b" for dnode in state1.data_nodes())
+
+    res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
+    assert res is not None
+
+    # Because the final state has now become empty
+    assert sdfg.number_of_nodes() == 3
+    assert state1.number_of_nodes() == 6
+    assert any(dnode.data == "b" for dnode in state1.data_nodes())
+    assert any(dnode.data == "tmp" for dnode in state1.data_nodes())
diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py new file mode 100644 index 0000000000..4ca44d43eb --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py @@ -0,0 +1,148 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + """Generates an SDFG that contains the self copying pattern.""" + sdfg = dace.SDFG(util.unique_name("self_copy_sdfg")) + state = sdfg.add_state(is_start_block=True) + + for name in "GT": + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["G"].transient = False + g_read, tmp_node, g_write = (state.add_access(name) for name in "GTG") + + state.add_nedge(g_read, tmp_node, dace.Memlet("G[0:10, 0:10]")) + state.add_nedge(tmp_node, g_write, dace.Memlet("G[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state + + +def test_global_self_copy_elimination_only_pattern(): + """Contains only the pattern -> Total elimination.""" + sdfg, state = _make_self_copy_sdfg() + assert sdfg.number_of_nodes() == 1 + assert state.number_of_nodes() == 3 + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert state.number_of_edges() == 2 + + count = sdfg.apply_transformations_repeated( + 
gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True
+    )
+    assert count != 0
+
+    assert sdfg.number_of_nodes() == 1
+    assert (
+        state.number_of_nodes() == 0
+    ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there."
+
+
+def test_global_self_copy_elimination_g_downstream():
+    """`G` is read downstream.
+
+    Since we ignore reads to `G` downstream, this will not influence the
+    transformation.
+    """
+    sdfg, state1 = _make_self_copy_sdfg()
+
+    # Add a read to `G` downstream.
+    state2 = sdfg.add_state_after(state1)
+    sdfg.add_array(
+        "output",
+        shape=(10, 10),
+        dtype=dace.float64,
+        transient=False,
+    )
+
+    state2.add_mapped_tasklet(
+        "downstream_computation",
+        map_ranges={"__i0": "0:10", "__i1": "0:10"},
+        inputs={"__in": dace.Memlet("G[__i0, __i1]")},
+        code="__out = __in + 10.0",
+        outputs={"__out": dace.Memlet("output[__i0, __i1]")},
+        external_edges=True,
+    )
+    sdfg.validate()
+    assert state2.number_of_nodes() == 5
+
+    count = sdfg.apply_transformations_repeated(
+        gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True
+    )
+    assert count != 0
+
+    assert sdfg.number_of_nodes() == 2
+    assert (
+        state1.number_of_nodes() == 0
+    ), f"Expected that 0 access nodes remained, but {state1.number_of_nodes()} were there."
+    assert state2.number_of_nodes() == 5
+    assert util.count_nodes(state2, dace_nodes.AccessNode) == 2
+    assert util.count_nodes(state2, dace_nodes.MapEntry) == 1
+
+
+def test_global_self_copy_elimination_tmp_downstream():
+    """`T` is read downstream.
+
+    Because `T` is read downstream, the read to `G` will be retained, but the write
+    will be removed.
+    """
+    sdfg, state1 = _make_self_copy_sdfg()
+
+    # Add a read to `G` downstream.
+ state2 = sdfg.add_state_after(state1) + sdfg.add_array( + "output", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state2.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("T[__i0, __i1]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("output[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + assert state2.number_of_nodes() == 5 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 2 + assert state1.number_of_nodes() == 2 + assert util.count_nodes(state1, dace_nodes.AccessNode) == 2 + assert all(state1.degree(node) == 1 for node in state1.nodes()) + assert next(iter(state1.source_nodes())).data == "G" + assert next(iter(state1.sink_nodes())).data == "T" + + assert state2.number_of_nodes() == 5 + assert util.count_nodes(state2, dace_nodes.AccessNode) == 2 + assert util.count_nodes(state2, dace_nodes.MapEntry) == 1 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 30266d71d1..89f067e5a9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -24,11 +24,12 @@ def _get_trivial_gpu_promotable( tasklet_code: str, + trivial_map_range: str = "0", ) -> tuple[dace.SDFG, dace_nodes.MapEntry, dace_nodes.MapEntry]: - """Returns an SDFG that is suitable to test the `TrivialGPUMapPromoter` promoter. + """Returns an SDFG that is suitable to test the `TrivialGPUMapElimination` promoter. 
The first map is a trivial map (`Map[__trival_gpu_it=0]`) containing a Tasklet, - that does not have an output, but writes a scalar value into `tmp` (output + that does not have an input, but writes a scalar value into `tmp` (output connector `__out`), the body of this Tasklet can be controlled through the `tasklet_code` argument. The second map (`Map[__i0=0:N]`) contains a Tasklet that computes the sum of its @@ -41,6 +42,7 @@ def _get_trivial_gpu_promotable( Args: tasklet_code: The body of the Tasklet inside the trivial map. + trivial_map_range: Range of the trivial map, defaults to `"0"`. """ sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -57,11 +59,11 @@ def _get_trivial_gpu_promotable( _, trivial_map_entry, _ = state.add_mapped_tasklet( "trivail_top_tasklet", - map_ranges={"__trivial_gpu_it": "0"}, + map_ranges={"__trivial_gpu_it": trivial_map_range}, inputs={}, code=tasklet_code, outputs={"__out": dace.Memlet("tmp[0]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, schedule=schedule, ) @@ -74,15 +76,15 @@ def _get_trivial_gpu_promotable( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("b[__i0]")}, - input_nodes={"a": a, "tmp": tmp}, - output_nodes={"b": b}, + input_nodes={a, tmp}, + output_nodes={b}, external_edges=True, schedule=schedule, ) return sdfg, trivial_map_entry, second_map_entry -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_1(): """Tests if the GPU map promoter works. By using a body such as `__out = 3.0`, the transformation will apply. 
@@ -92,15 +94,15 @@ def test_trivial_gpu_map_promoter(): org_second_map_ranges = copy.deepcopy(second_map_entry.map.range) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) assert ( nb_runs == 1 - ), f"Expected that 'TrivialGPUMapPromoter' applies once but it applied {nb_runs}." + ), f"Expected that 'TrivialGPUMapElimination' applies once but it applied {nb_runs}." trivial_map_params = trivial_map_entry.map.params - trivial_map_ranges = trivial_map_ranges.map.range + trivial_map_ranges = trivial_map_entry.map.range second_map_params = second_map_entry.map.params second_map_ranges = second_map_entry.map.range @@ -119,32 +121,82 @@ def test_trivial_gpu_map_promoter(): assert sdfg.is_valid() -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_2(): """Test if the GPU promoter does not fuse a special trivial map. By using a body such as `__out = __trivial_gpu_it` inside the - Tasklet's body, the map parameter is now used, and thus can not be fused. + Tasklet's body, the map parameter must now be replaced inside + the Tasklet's body. 
""" sdfg, trivial_map_entry, second_map_entry = _get_trivial_gpu_promotable( - "__out = __trivial_gpu_it" + tasklet_code="__out = __trivial_gpu_it", + trivial_map_range="2", + ) + state: dace.SDFGStae = sdfg.nodes()[0] + trivial_tasklet: dace_nodes.Tasklet = next( + iter( + out_edge.dst + for out_edge in state.out_edges(trivial_map_entry) + if isinstance(out_edge.dst, dace_nodes.Tasklet) + ) ) - org_trivial_map_params = list(trivial_map_entry.map.params) - org_second_map_params = list(second_map_entry.map.params) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) - assert ( - nb_runs == 0 - ), f"Expected that 'TrivialGPUMapPromoter' does not apply but it applied {nb_runs}." - trivial_map_params = trivial_map_entry.map.params - second_map_params = second_map_entry.map.params - assert ( - trivial_map_params == org_trivial_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." - assert ( - second_map_params == org_second_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." 
- assert sdfg.is_valid() + assert nb_runs == 1 + + expected_trivial_code = "__out = 2" + assert trivial_tasklet.code == expected_trivial_code + + +def test_set_gpu_properties(): + """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()`.""" + sdfg = dace.SDFG("gpu_properties_test") + state = sdfg.add_state(is_start_block=True) + + map_entries: dict[int, dace_nodes.MapEntry] = {} + for dim in [1, 2, 3]: + shape = (10,) * dim + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + map_entries[dim] = me + + sdfg.apply_gpu_transformations() + sdfg.validate() + + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size=(10, "11", 12), + launch_factor_2d=2, + block_size_2d=(2, 2, 2), + launch_bounds_3d=200, + ) + + map1, map2, map3 = (map_entries[d].map for d in [1, 2, 3]) + + assert len(map1.params) == 1 + assert map1.gpu_block_size == [10, 1, 1] + assert map1.gpu_launch_bounds == "0" + + assert len(map2.params) == 2 + assert map2.gpu_block_size == [2, 2, 1] + assert map2.gpu_launch_bounds == "8" + + assert len(map3.params) == 3 + assert map3.gpu_block_size == [10, 11, 12] + assert map3.gpu_launch_bounds == "200" diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index c1e0ddd2f6..aac58eb32c 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -29,7 +29,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np The k blocking transformation can be applied to the SDFG, however no node can be taken out. This is because how it is constructed. However, applying - some simplistic transformations this can be done. + some simplistic transformations will enable the transformation. """ sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -136,6 +136,83 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n return sdfg, lambda a, b: (a + (2 * b.reshape((-1, 1)) + 3)) +def _get_sdfg_with_empty_memlet( + first_tasklet_independent: bool, + only_empty_memlets: bool, +) -> tuple[ + dace.SDFG, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.Tasklet +]: + """Generates an SDFG with an empty tasklet. + + The map contains two (serial) tasklets, connected through an access node. + The first tasklet has an empty memlet that connects it to the map entry. + Depending on `first_tasklet_independent` the tasklet is either independent + or not. The second tasklet has an additional in connector that accesses an array. + + If `only_empty_memlets` is given then the second memlet will only depend + on the input of the first tasklet. However, since it is connected to the + map exit, it will be classified as dependent. + + Returns: + The function returns the SDFG, the map entry and the first tasklet (that + is either dependent or independent), the access node between the tasklets + and the second tasklet that is always dependent. 
+ """ + sdfg = dace.SDFG(util.unique_name("empty_memlet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("M", dace.int32) + sdfg.add_array("b", ("N", "M"), dace.float64, transient=False) + b = state.add_access("b") + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + + if not only_empty_memlets: + sdfg.add_array("a", ("N", "M"), dace.float64, transient=False) + a = state.add_access("a") + + # This is the first tasklet. + task1 = state.add_tasklet( + "task1", + inputs={}, + outputs={"__out0"}, + code="__out0 = 1.0" if first_tasklet_independent else "__out0 = j", + ) + + if only_empty_memlets: + task2 = state.add_tasklet( + "task2", inputs={"__in0"}, outputs={"__out0"}, code="__out0 = __in0 + 1.0" + ) + else: + task2 = state.add_tasklet( + "task2", inputs={"__in0", "__in1"}, outputs={"__out0"}, code="__out0 = __in0 + __in1" + ) + + # Now create the map + mentry, mexit = state.add_map("map", ndrange={"i": "0:N", "j": "0:M"}) + + if not only_empty_memlets: + state.add_edge(a, None, mentry, "IN_a", dace.Memlet("a[0:N, 0:M]")) + state.add_edge(mentry, "OUT_a", task2, "__in1", dace.Memlet("a[i, j]")) + + state.add_edge(task2, "__out0", mexit, "IN_b", dace.Memlet("b[i, j]")) + state.add_edge(mexit, "OUT_b", b, None, dace.Memlet("b[0:N, 0:M]")) + + state.add_edge(mentry, None, task1, None, dace.Memlet()) + state.add_edge(task1, "__out0", tmp, None, dace.Memlet("tmp[0]")) + state.add_edge(tmp, None, task2, "__in0", dace.Memlet("tmp[0]")) + + if not only_empty_memlets: + mentry.add_in_connector("IN_a") + mentry.add_out_connector("OUT_a") + mexit.add_in_connector("IN_b") + mexit.add_out_connector("OUT_b") + + sdfg.validate() + + return sdfg, mentry, task1, tmp, task2 + + def test_only_dependent(): """Just applying the transformation to the SDFG. 
@@ -152,11 +229,12 @@ def test_only_dependent(): ref = reff(a, b) # Apply the transformation - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 assert len(sdfg.states()) == 1 state = sdfg.states()[0] @@ -216,11 +294,12 @@ def test_intermediate_access_node(): assert np.allclose(ref, c) # Apply the transformation. - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 # Inspect if the SDFG was modified correctly. # We only inspect `tmp` which now has to be between the two maps. @@ -254,12 +333,12 @@ def test_chained_access() -> None: c[:] = 0 # Apply the transformation. - ret = sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) - assert ret == 1, f"Expected that the transformation was applied 1 time, but it was {ret}." + assert count == 1 # Now run the SDFG to see if it is still the same sdfg(a=a, b=b, c=c, M=M, N=N) @@ -305,3 +384,422 @@ def test_chained_access() -> None: assert isinstance(inner_tasklet, dace_nodes.Tasklet) assert inner_tasklet not in first_level_tasklets + + +def test_direct_map_exit_connection() -> dace.SDFG: + """Generates a SDFG with a mapped independent tasklet connected to the map exit. + + Because the tasklet is connected to the map exit it can not be independent. 
+ """ + sdfg = dace.SDFG(util.unique_name("mapped_tasklet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_array("a", (10,), dace.float64, transient=False) + sdfg.add_array("b", (10, 30), dace.float64, transient=False) + tsklt, me, mx = state.add_mapped_tasklet( + name="comp", + map_ranges=dict(i=f"0:10", j=f"0:30"), + inputs=dict(__in0=dace.Memlet("a[i]")), + outputs=dict(__out=dace.Memlet("b[i, j]")), + code="__out = __in0 + 1", + external_edges=True, + ) + + assert all(out_edge.dst is tsklt for out_edge in state.out_edges(me)) + assert all(in_edge.src is tsklt for in_edge in state.in_edges(mx)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + assert all(isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(me)) + assert all(isinstance(in_edge.src, dace_nodes.MapExit) for in_edge in state.in_edges(mx)) + + +def test_empty_memlet_1(): + sdfg, mentry, itask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=True, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[itask] is mentry + assert scope_dict[tmp] is mentry + assert scope_dict[task2] is not mentry + assert scope_dict[task2] is not None + assert all( + isinstance(in_edge.src, dace_nodes.MapEntry) and in_edge.src is not mentry + for in_edge in state.in_edges(task2) + ) + + +def test_empty_memlet_2(): + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = 
sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # Find the inner map entry + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +def test_empty_memlet_3(): + # This is the only interesting case with only empty memlet. + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=True, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # The top map only has a single output, which is the empty edge, that is holding + # the inner map entry in the scope. + assert all(out_edge.data.is_empty() for out_edge in state.out_edges(mentry)) + assert state.in_degree(mentry) == 0 + assert state.out_degree(mentry) == 1 + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +def _make_loop_blocking_sdfg_with_inner_map( + add_independent_part: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: + """ + Generate the SDFGs with an inner map. 
+ + The SDFG has an inner map that is classified as dependent. If + `add_independent_part` is `True` then the SDFG has a part that is independent. + Note that everything is read from a single connector. + + Return: + The function will return the SDFG, the state and the map entry for the outer + and inner map. + """ + sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_map")) + state = sdfg.add_state(is_start_block=True) + + for name in "AB": + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + + me_out, mx_out = state.add_map("outer_map", ndrange={"__i0": "0:10"}) + me_in, mx_in = state.add_map("inner_map", ndrange={"__i1": "0:10"}) + A, B = (state.add_access(name) for name in "AB") + tskl = state.add_tasklet( + "computation", inputs={"__in1", "__in2"}, outputs={"__out"}, code="__out = __in1 + __in2" + ) + + if add_independent_part: + sdfg.add_array("C", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + sdfg.add_scalar("tmp2", dtype=dace.float64, transient=True) + tmp, tmp2, C = (state.add_access(name) for name in ("tmp", "tmp2", "C")) + tskli = state.add_tasklet( + "independent_comp", inputs={"__field"}, outputs={"__out"}, code="__out = __field[1, 1]" + ) + + # construct the inner map of the map. 
+ state.add_edge(A, None, me_out, "IN_A", dace.Memlet("A[0:10, 0:10]")) + me_out.add_in_connector("IN_A") + state.add_edge(me_out, "OUT_A", me_in, "IN_A", dace.Memlet("A[__i0, 0:10]")) + me_out.add_out_connector("OUT_A") + me_in.add_in_connector("IN_A") + state.add_edge(me_in, "OUT_A", tskl, "__in1", dace.Memlet("A[__i0, __i1]")) + me_in.add_out_connector("OUT_A") + + state.add_edge(me_out, "OUT_A", me_in, "IN_A1", dace.Memlet("A[__i0, 0:10]")) + me_in.add_in_connector("IN_A1") + state.add_edge(me_in, "OUT_A1", tskl, "__in2", dace.Memlet("A[__i0, 9 - __i1]")) + me_in.add_out_connector("OUT_A1") + + state.add_edge(tskl, "__out", mx_in, "IN_B", dace.Memlet("B[__i0, __i1]")) + mx_in.add_in_connector("IN_B") + state.add_edge(mx_in, "OUT_B", mx_out, "IN_B", dace.Memlet("B[__i0, 0:10]")) + mx_in.add_out_connector("OUT_B") + mx_out.add_in_connector("IN_B") + state.add_edge(mx_out, "OUT_B", B, None, dace.Memlet("B[0:10, 0:10]")) + mx_out.add_out_connector("OUT_B") + + # If requested add a part that is independent, i.e. is before the inner loop + if add_independent_part: + state.add_edge(me_out, "OUT_A", tskli, "__field", dace.Memlet("A[0:10, 0:10]")) + state.add_edge(tskli, "__out", tmp, None, dace.Memlet("tmp[0]")) + state.add_edge(tmp, None, tmp2, None, dace.Memlet("tmp2[0]")) + state.add_edge(tmp2, None, mx_out, "IN_tmp", dace.Memlet("C[__i0]")) + mx_out.add_in_connector("IN_tmp") + state.add_edge(mx_out, "OUT_tmp", C, None, dace.Memlet("C[0:10]")) + mx_out.add_out_connector("OUT_tmp") + + sdfg.validate() + return sdfg, state, me_out, me_in + + +def test_loop_blocking_inner_map(): + """ + Tests with an inner map, without an independent part. 
+ """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(False) + assert all(oedge.dst is inner_map for oedge in state.out_edges(outer_map)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + assert all( + oedge.dst is not inner_map and isinstance(oedge.dst, dace_nodes.MapEntry) + for oedge in state.out_edges(outer_map) + ) + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all(oedge.dst is inner_map for oedge in state.out_edges(inner_blocking_map)) + + +def test_loop_blocking_inner_map_with_independent_part(): + """ + Tests with an inner map with an independent part. + """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(True) + + # Find the parts that are independent. 
+ itskl: dace_nodes.Tasklet = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.Tasklet) + ) + assert itskl.label == "independent_comp" + i_access_node: dace_nodes.AccessNode = next(oedge.dst for oedge in state.out_edges(itskl)) + assert i_access_node.data == "tmp" + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all(oedge.dst in {inner_blocking_map, itskl} for oedge in state.out_edges(outer_map)) + assert state.scope_dict()[i_access_node] is outer_map + assert all(oedge.dst is inner_blocking_map for oedge in state.out_edges(i_access_node)) + + +def _make_mixed_memlet_sdfg( + tskl1_independent: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.Tasklet]: + """ + Generates the SDFGs for the mixed Memlet tests. + + The SDFG that is generated has the following structure: + - `tsklt2`, is always dependent, it has an incoming connection from the + map entry, and an incoming, but empty, connection with `tskl1`. + - `tskl1` is connected to the map entry, depending on `tskl1_independent` + it is independent or dependent, it has an empty connection to `tskl2`, + thus it is sequenced before. + - Both have connection to other nodes down stream, but they are dependent. + + Returns: + A tuple containing the following objects. + - The SDFG. + - The SDFG state. + - The outer map entry node. + - `tskl1`. + - `tskl2`. 
+ """ + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names_array = ["A", "B", "C"] + names_scalar = ["tmp1", "tmp2"] + for aname in names_array: + sdfg.add_array( + aname, + shape=((10,) if aname == "A" else (10, 10)), + dtype=dace.float64, + transient=False, + ) + for sname in names_scalar: + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=True, + ) + A, B, C, tmp1, tmp2 = (state.add_access(name) for name in names_array + names_scalar) + + me, mx = state.add_map("outer_map", ndrange={"i": "0:10", "j": "0:10"}) + tskl1 = state.add_tasklet( + "tskl1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1" if tskl1_independent else "__out = __in1 + j", + ) + tskl2 = state.add_tasklet( + "tskl2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 10.0", + ) + tskl3 = state.add_tasklet( + "tskl3", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = __in1 + __in2", + ) + + state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]")) + me.add_in_connector("IN_A") + state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[i]")) + me.add_out_connector("OUT_A") + state.add_edge(tskl1, "__out", tmp1, None, dace.Memlet("tmp1[0]")) + + state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10, 0:10]")) + me.add_in_connector("IN_B") + state.add_edge(me, "OUT_B", tskl2, "__in1", dace.Memlet("B[i, j]")) + me.add_out_connector("OUT_B") + state.add_edge(tskl2, "__out", tmp2, None, dace.Memlet("tmp2[0]")) + + # Add the empty Memlet that sequences `tskl1` before `tskl2`. 
+    state.add_edge(tskl1, None, tskl2, None, dace.Memlet())
+
+    state.add_edge(tmp1, None, tskl3, "__in1", dace.Memlet("tmp1[0]"))
+    state.add_edge(tmp2, None, tskl3, "__in2", dace.Memlet("tmp2[0]"))
+    state.add_edge(tskl3, "__out", mx, "IN_C", dace.Memlet("C[i, j]"))
+    mx.add_in_connector("IN_C")
+    state.add_edge(mx, "OUT_C", C, None, dace.Memlet("C[0:10, 0:10]"))
+    mx.add_out_connector("OUT_C")
+    sdfg.validate()
+
+    return (sdfg, state, me, tskl1, tskl2)
+
+
+def _apply_and_run_mixed_memlet_sdfg(sdfg: dace.SDFG) -> None:
+    ref = {
+        "A": np.array(np.random.rand(10), dtype=np.float64, copy=True),
+        "B": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True),
+        "C": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True),
+    }
+    res = copy.deepcopy(ref)
+    sdfg(**ref)
+
+    count = sdfg.apply_transformations_repeated(
+        gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="j"),
+        validate=True,
+        validate_all=True,
+    )
+    assert count == 1, f"Expected one application, but got {count}"
+    sdfg(**res)
+    assert all(np.allclose(ref[name], res[name]) for name in ref)
+
+
+def test_loop_blocking_mixked_memlets_1():
+    sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(True)
+    mx = state.exit_node(me)
+
+    _apply_and_run_mixed_memlet_sdfg(sdfg)
+    scope_dict = state.scope_dict()
+
+    # Ensure that `tskl1` is independent.
+    assert scope_dict[tskl1] is me
+
+    # The output of `tskl1`, which is `tmp1` should also be classified as independent.
+    tmp1 = next(iter(edge.dst for edge in state.out_edges(tskl1) if not edge.data.is_empty()))
+    assert scope_dict[tmp1] is me
+    assert isinstance(tmp1, dace_nodes.AccessNode)
+    assert tmp1.data == "tmp1"
+
+    # Find the inner map.
+ inner_map_entry: dace_nodes.MapEntry = scope_dict[tskl2] + assert inner_map_entry is not me and isinstance(inner_map_entry, dace_nodes.MapEntry) + inner_map_exit: dace_nodes.MapExit = state.exit_node(inner_map_entry) + + outer_scope = {tskl1, tmp1, inner_map_entry, inner_map_exit, mx} + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert node in outer_scope + else: + assert ( + (node is inner_map_exit) + or (isinstance(node, dace_nodes.AccessNode) and node.data == "tmp2") + or (isinstance(node, dace_nodes.Tasklet) and node.label in {"tskl2", "tskl3"}) + ) + + +def test_loop_blocking_mixked_memlets_2(): + sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(False) + mx = state.exit_node(me) + + _apply_and_run_mixed_memlet_sdfg(sdfg) + scope_dict = state.scope_dict() + + # Because `tskl1` is now dependent, everything is now dependent. 
+ inner_map_entry = scope_dict[tskl1] + assert isinstance(inner_map_entry, dace_nodes.MapEntry) + assert inner_map_entry is not me + + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert isinstance(node, dace_nodes.MapEntry) or (node is mx) + else: + assert scope_dict[node] is inner_map_entry diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py new file mode 100644 index 0000000000..1a4ce6d047 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -0,0 +1,264 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . 
import util + +import dace + + +def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: + return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} + + +def _make_test_sdfg( + output_name: str = "G", + input_name: str = "G", + tmp_name: str = "T", + array_size: int | str = 10, + tmp_size: int | str | None = None, + map_range: tuple[int | str, int | str] | None = None, + tmp_to_glob_memlet: str | None = None, + in_offset: str | None = None, + out_offset: str | None = None, +) -> dace.SDFG: + if isinstance(array_size, str): + array_size = sdfg.add_symbol(array_size, dace.int32, find_new_name=True) + if tmp_size is None: + tmp_size = array_size + if map_range is None: + map_range = (0, array_size) + if tmp_to_glob_memlet is None: + tmp_to_glob_memlet = f"{tmp_name}[0:{array_size}] -> [0:{array_size}]" + elif tmp_to_glob_memlet[0] == "[": + tmp_to_glob_memlet = tmp_name + tmp_to_glob_memlet + if in_offset is None: + in_offset = "0" + if out_offset is None: + out_offset = in_offset + + sdfg = dace.SDFG(util.unique_name("map_buffer")) + state = sdfg.add_state(is_start_block=True) + names = {input_name, tmp_name, output_name} + for name in names: + sdfg.add_array( + name, + shape=((array_size,) if name != tmp_name else (tmp_size,)), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays[tmp_name].transient = True + + input_ac = state.add_access(input_name) + tmp_ac = state.add_access(tmp_name) + output_ac = state.add_access(output_name) + + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": f"{map_range[0]}:{map_range[1]}"}, + inputs={"__in1": dace.Memlet(data=input_ac.data, subset=f"__i0 + {in_offset}")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet(data=tmp_ac.data, subset=f"__i0 + {out_offset}")}, + input_nodes={input_ac}, + output_nodes={tmp_ac}, + external_edges=True, + ) + state.add_edge( + tmp_ac, + None, + output_ac, + None, + dace.Memlet(tmp_to_glob_memlet), + ) + sdfg.validate() + 
return sdfg + + +def _perform_test( + sdfg: dace.SDFG, + xform: gtx_transformations.GT4PyMapBufferElimination, + exp_count: int, + array_size: int = 10, +) -> None: + ref = { + name: np.array(np.random.rand(array_size), dtype=np.float64, copy=True) + for name, desc in sdfg.arrays.items() + if not desc.transient + } + if "array_size" in sdfg.symbols: + ref["array_size"] = array_size + + res = copy.deepcopy(ref) + sdfg(**ref) + + count = sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) + assert count == exp_count, f"Expected {exp_count} applications, but got {count}" + + if count == 0: + return + + sdfg(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()), f"Failed for '{name}'." + + +def test_map_buffer_elimination_simple(): + sdfg = _make_test_sdfg() + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=True), + exp_count=1, + ) + + +def test_map_buffer_elimination_simple_2(): + sdfg = _make_test_sdfg() + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=0, + ) + + +def test_map_buffer_elimination_simple_3(): + sdfg = _make_test_sdfg(input_name="A", output_name="O") + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_1(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + tmp_to_glob_memlet="[2:8] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_2(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [0:6]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def 
test_map_buffer_elimination_offset_3(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_4(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[1:7] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=0, + ) + + +def test_map_buffer_elimination_offset_5(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + tmp_size=6, + in_offset="0", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_not_apply(): + """Indirect accessing, because of this the double buffer is needed.""" + sdfg = dace.SDFG(util.unique_name("map_buffer")) + state = sdfg.add_state(is_start_block=True) + + names = ["A", "tmp", "idx"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.int32 if name == "tmp" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + tmp = state.add_access("tmp") + state.add_mapped_tasklet( + "indirect_accessing", + map_ranges={"__i0": "0:10"}, + inputs={ + "__field": dace.Memlet("A[0:10]"), + "__idx": dace.Memlet("idx[__i0]"), + }, + code="__out = __field[__idx]", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + output_nodes={tmp}, + external_edges=True, + ) + state.add_nedge(tmp, state.add_access("A"), dace.Memlet("tmp[0:10] -> [0:10]")) + + # TODO(phimuell): Update the transformation such that we can specify + # `assume_pointwise=True` and the test would still pass. 
+ count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index c9d467ba80..b468b80b8e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -58,14 +58,14 @@ def _make_serial_sdfg_1( inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, code="__out = __in0 + 1.0", outputs={"__out": dace.Memlet("tmp[__i0, __i1]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="second_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp": tmp}, + input_nodes={tmp}, inputs={"__in0": dace.Memlet("tmp[__i0, __i1]")}, code="__out = __in0 + 3.0", outputs={"__out": dace.Memlet("b[__i0, __i1]")}, @@ -118,17 +118,14 @@ def _make_serial_sdfg_2( "__out0": dace.Memlet("tmp_1[__i0, __i1]"), "__out1": dace.Memlet("tmp_2[__i0, __i1]"), }, - output_nodes={ - "tmp_1": tmp_1, - "tmp_2": tmp_2, - }, + output_nodes={tmp_1, tmp_2}, external_edges=True, ) state.add_mapped_tasklet( name="first_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_1": tmp_1}, + input_nodes={tmp_1}, inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, code="__out = __in0 + 3.0", outputs={"__out": dace.Memlet("b[__i0, __i1]")}, @@ -137,7 +134,7 @@ def _make_serial_sdfg_2( state.add_mapped_tasklet( name="second_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_2": tmp_2}, + input_nodes={tmp_2}, 
inputs={"__in0": dace.Memlet("tmp_2[__i0, __i1]")}, code="__out = __in0 - 3.0", outputs={"__out": dace.Memlet("c[__i0, __i1]")}, @@ -194,14 +191,14 @@ def _make_serial_sdfg_3( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("tmp[__i0]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="indirect_access", map_ranges=[("__i0", f"0:{N_output}")], - input_nodes={"tmp": tmp}, + input_nodes={tmp}, inputs={ "__index": dace.Memlet("idx[__i0]"), "__array": dace.Memlet.simple("tmp", subset_str=f"0:{N_input}", num_accesses=1), @@ -220,19 +217,19 @@ def test_exclusive_itermediate(): sdfg = _make_serial_sdfg_1(N) # Now apply the optimizations. - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" not in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b"] ] assert len(intermediate_nodes) == 1 @@ -257,19 +254,19 @@ def test_shared_itermediate(): sdfg.arrays["tmp"].transient = False # Now apply the optimizations. 
- assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b", "tmp"] ] assert len(intermediate_nodes) == 1 @@ -291,21 +288,21 @@ def test_pure_output_node(): """Tests the path of a pure intermediate.""" N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 # The first fusion will only bring it down to two maps. 
sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -327,17 +324,17 @@ def test_array_intermediate(): """ N = 10 sdfg = _make_serial_sdfg_1(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations_repeated([dace_dataflow.MapExpansion]) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 4 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 4 # Now perform the fusion sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(only_toplevel_maps=True), + gtx_transformations.MapFusionSerial(only_toplevel_maps=True), validate=True, validate_all=True, ) - map_entries = util._count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) + map_entries = util.count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) scope = next(iter(sdfg.states())).scope_dict() assert len(map_entries) == 3 @@ -349,7 +346,7 @@ def test_array_intermediate(): # Find the access node that is the new intermediate node. 
inner_access_nodes: list[dace_nodes.AccessNode] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if scope[node] is not None ] assert len(inner_access_nodes) == 1 @@ -374,7 +371,7 @@ def test_interstate_transient(): """ N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 assert sdfg.number_of_nodes() == 1 # Now add the new state and the new output. @@ -393,15 +390,15 @@ def test_interstate_transient(): # Now apply the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) assert "tmp_1" in sdfg.arrays assert "tmp_2" not in sdfg.arrays assert sdfg.number_of_nodes() == 2 - assert util._count_nodes(head_state, dace_nodes.MapEntry) == 1 - assert util._count_nodes(new_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(head_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(new_state, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -430,7 +427,7 @@ def test_indirect_access(): c = np.empty(N_output) idx = np.random.randint(low=0, high=N_input, size=N_output, dtype=np.int32) sdfg = _make_serial_sdfg_3(N_input=N_input, N_output=N_output) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 def _ref(a, b, idx): tmp = a + b @@ -443,11 +440,11 @@ def _ref(a, b, idx): # Now "apply" the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 c[:] = -1.0 sdfg(a=a, b=b, idx=idx, c=c) @@ -455,5 +452,58 @@ def _ref(a, b, idx): def 
test_indirect_access_2(): - # TODO(phimuell): Index should be computed and that map should be fusable. - pass + """Indirect accesses, with non point wise input dependencies. + + Because `a` is used as input and output and `a` is indirectly accessed + the access to `a` can not be point wise so, fusing is not possible. + """ + sdfg = dace.SDFG(util.unique_name("indirect_access_sdfg_2")) + state = sdfg.add_state(is_start_block=True) + + names = ["a", "b", "idx", "tmp"] + + for name in names: + sdfg.add_array( + name=name, + shape=(10,), + dtype=dace.int32 if name == "idx" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + a_in, b, idx, tmp, a_out = (state.add_access(name) for name in (names + ["a"])) + + state.add_mapped_tasklet( + "indirect_access", + map_ranges={"__i0": "0:10"}, + inputs={ + "__idx": dace.Memlet("idx[__i0]"), + "__field": dace.Memlet("a[0:10]", volume=1), + }, + code="__out = __field[__idx]", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + input_nodes={a_in, idx}, + output_nodes={tmp}, + external_edges=True, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={ + "__in1": dace.Memlet("tmp[__i0]"), + "__in2": dace.Memlet("b[__i0]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("a[__i0]")}, + input_nodes={tmp, b}, + output_nodes={a_out}, + external_edges=True, + ) + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.MapFusionSerial(), + validate=True, + validate_all=True, + ) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py new file mode 100644 index 0000000000..72efc2fe34 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ 
-0,0 +1,100 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + + +def _perform_reorder_test( + sdfg: dace.SDFG, + leading_dim: list[str], + expected_order: list[str], +) -> None: + """Performs the reorder transformation and test it. + + If `expected_order` is the empty list, then the transformation should not apply. + """ + map_entries: list[dace.nodes.MapEntry] = util.count_nodes(sdfg, dace.nodes.MapEntry, True) + assert len(map_entries) == 1 + map_entry: dace.nodes.MapEntry = map_entries[0] + old_map_params = map_entry.map.params.copy() + + apply_count = sdfg.apply_transformations_repeated( + gtx_transformations.MapIterationOrder( + leading_dims=leading_dim, + ), + validate=True, + validate_all=True, + ) + new_map_params = map_entry.map.params.copy() + + if len(expected_order) == 0: + assert ( + apply_count == 0 + ), f"Expected that the transformation was not applied. New map order: {map_entry.map.params}" + return + else: + assert ( + apply_count > 0 + ), f"Expected that the transformation was applied. Old map order: {map_entry.map.params}; Expected order: {expected_order}" + assert len(expected_order) == len(new_map_params) + + assert ( + expected_order == new_map_params + ), f"Expected map order {expected_order} but got {new_map_params} instead." 
+ + +def _make_test_sdfg(map_params: list[str]) -> dace.SDFG: + """Generate an SDFG for the test.""" + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + state: dace.SDFGState = sdfg.add_state("state", is_start_block=True) + dim = len(map_params) + for aname in ["a", "b"]: + sdfg.add_array(aname, shape=((4,) * dim), dtype=dace.float64, transient=False) + + state.add_mapped_tasklet( + "mapped_tasklet", + map_ranges=[(map_param, "0:4") for map_param in map_params], + inputs={"__in": dace.Memlet("a[" + ",".join(map_params) + "]")}, + code="__out = __in + 1", + outputs={"__out": dace.Memlet("b[" + ",".join(map_params) + "]")}, + external_edges=True, + ) + sdfg.validate() + + return sdfg + + +def test_map_order_1(): + sdfg = _make_test_sdfg(["EDim", "KDim", "VDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim", "EDim"]) + + +def test_map_order_2(): + sdfg = _make_test_sdfg(["VDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim"]) + + +def test_map_order_3(): + sdfg = _make_test_sdfg(["EDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "EDim"]) + + +def test_map_order_4(): + sdfg = _make_test_sdfg(["CDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], []) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py new file mode 100644 index 0000000000..7b39bc4e1d --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py @@ -0,0 +1,164 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_movable_tasklet( + outer_tasklet_code: str, +) -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.MapEntry +]: + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + + sdfg.add_scalar("outer_scalar", dtype=dace.float64, transient=True) + for name in "AB": + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + A, B, outer_scalar = (state.add_access(name) for name in ["A", "B", "outer_scalar"]) + + outer_tasklet = state.add_tasklet( + name="outer_tasklet", + inputs=set(), + outputs={"__out"}, + code=f"__out = {outer_tasklet_code}", + ) + state.add_edge(outer_tasklet, "__out", outer_scalar, None, dace.Memlet("outer_scalar[0]")) + + _, me, _ = state.add_mapped_tasklet( + "map", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("outer_scalar[0]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + external_edges=True, + input_nodes={outer_scalar, A}, + output_nodes={B}, + ) + sdfg.validate() + + return sdfg, state, outer_tasklet, outer_scalar, me + + +def test_move_tasklet_inside_trivial_memlet_tree(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="1.2", + ) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = 
np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 1.2 + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_non_trivial_memlet_tree(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="1.2", + ) + # By expanding the maps, we the memlet tree is no longer trivial. + sdfg.apply_transformations_repeated(dace_dataflow.MapExpansion) + assert util.count_nodes(state, dace_nodes.MapEntry) == 2 + me = None + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 1.2 + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_two_inner_connector(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="32.2", + ) + mapped_tasklet = next( + iter(e.dst for e in state.out_edges(me) if isinstance(e.dst, dace_nodes.Tasklet)) + ) + + state.add_edge( + me, + f"OUT_{outer_scalar.data}", + mapped_tasklet, + "__in2", + dace.Memlet(f"{outer_scalar.data}[0]"), + ) + mapped_tasklet.add_in_connector("__in2") + mapped_tasklet.code.as_string = "__out = __in0 + __in1 + __in2" + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 2 * (32.2) + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_outer_scalar_used_outside(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="22.6", + ) + sdfg.add_array("C", shape=(1,), 
dtype=dace.float64, transient=False) + state.add_edge(outer_scalar, None, state.add_access("C"), None, dace.Memlet("C[0]")) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + C = np.array(np.random.rand(1), dtype=np.float64, copy=True) + ref_C = 22.6 + ref_B = A + ref_C + + csdfg = sdfg.compile() + csdfg(A=A, B=B, C=C) + assert np.allclose(B, ref_B) + assert np.allclose(C, ref_C) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py index 96584b8273..8626cb8e07 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py @@ -68,7 +68,7 @@ def test_serial_map_promotion(): external_edges=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 1 assert len(map_entry_1d.map.range) == 1 assert len(map_entry_2d.map.params) == 2 @@ -83,7 +83,7 @@ def test_serial_map_promotion(): validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 2 assert len(map_entry_1d.map.range) == 2 assert len(map_entry_2d.map.params) == 2 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py index ac88f4fef8..b82cecee98 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py @@ -14,7 +14,7 @@ @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[False], @@ -22,14 +22,14 @@ def _count_nodes( @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[True], ) -> list[dace_nodes.Node]: ... -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: bool = False, From 39fb949c2e7d0ff9ff4f1b9c3fff921ee8561086 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 5 Dec 2024 10:03:49 +0100 Subject: [PATCH 34/43] style[cartesian]: readability improvements and more type hints (#1752) This PR detaches a couple of cleanups in the dace backend from the in-progress gt4py/dace bridge: mostly readability improvements and some easy type hints. There's also the occasional unused variable / argument in here. 
--- src/gt4py/cartesian/backend/dace_backend.py | 16 +++--- .../gtc/dace/expansion_specification.py | 50 +++++++++---------- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 3 +- src/gt4py/cartesian/gtc/dace/utils.py | 2 +- .../test_code_generation.py | 2 +- 5 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index f49895a435..a6d28f5994 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -56,17 +56,17 @@ def _specialize_transient_strides(sdfg: dace.SDFG, layout_map): - repldict = replace_strides( + replacement_dictionary = replace_strides( [array for array in sdfg.arrays.values() if array.transient], layout_map ) - sdfg.replace_dict(repldict) + sdfg.replace_dict(replacement_dictionary) for state in sdfg.nodes(): for node in state.nodes(): if isinstance(node, dace.nodes.NestedSDFG): - for k, v in repldict.items(): + for k, v in replacement_dictionary.items(): if k in node.symbol_mapping: node.symbol_mapping[k] = v - for k in repldict.keys(): + for k in replacement_dictionary.keys(): if k in sdfg.symbols: sdfg.remove_symbol(k) @@ -143,7 +143,7 @@ def _to_device(sdfg: dace.SDFG, device: str) -> None: node.device = dace.DeviceType.GPU -def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): +def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): args_data = make_args_data_from_gtir(gtir_pipeline) # stencils without effect @@ -164,7 +164,7 @@ def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map) return sdfg -def _post_expand_trafos(sdfg: dace.SDFG): +def _post_expand_transformations(sdfg: dace.SDFG): # DaCe "standard" clean-up transformations sdfg.simplify(validate=False) @@ -355,7 +355,7 @@ def _unexpanded_sdfg(self): sdfg = OirSDFGBuilder().visit(oir_node) _to_device(sdfg, self.builder.backend.storage_info["device"]) - 
_pre_expand_trafos( + _pre_expand_transformations( self.builder.gtir_pipeline, sdfg, self.builder.backend.storage_info["layout_map"], @@ -371,7 +371,7 @@ def unexpanded_sdfg(self): def _expanded_sdfg(self): sdfg = self._unexpanded_sdfg() sdfg.expand_library_nodes() - _post_expand_trafos(sdfg) + _post_expand_transformations(sdfg) return sdfg def expanded_sdfg(self): diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index c716f1a103..af9a814843 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -107,7 +107,8 @@ def get_expansion_order_index(expansion_order, axis): for idx, item in enumerate(expansion_order): if isinstance(item, Iteration) and item.axis == axis: return idx - elif isinstance(item, Map): + + if isinstance(item, Map): for it in item.iterations: if it.kind == "contiguous" and it.axis == axis: return idx @@ -136,7 +137,9 @@ def _choose_loop_or_map(node, eo): return eo -def _order_as_spec(computation_node, expansion_order): +def _order_as_spec( + computation_node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]] +) -> List[ExpansionItem]: expansion_order = list(_choose_loop_or_map(computation_node, eo) for eo in expansion_order) expansion_specification = [] for item in expansion_order: @@ -170,7 +173,7 @@ def _order_as_spec(computation_node, expansion_order): return expansion_specification -def _populate_strides(node, expansion_specification): +def _populate_strides(node: StencilComputation, expansion_specification: List[ExpansionItem]): """Fill in `stride` attribute of `Iteration` and `Loop` dataclasses. For loops, stride is set to either -1 or 1, based on iteration order. 
@@ -185,10 +188,7 @@ def _populate_strides(node, expansion_specification): for it in iterations: if isinstance(it, Loop): if it.stride is None: - if node.oir_node.loop_order == common.LoopOrder.BACKWARD: - it.stride = -1 - else: - it.stride = 1 + it.stride = -1 if node.oir_node.loop_order == common.LoopOrder.BACKWARD else 1 else: if it.stride is None: if it.kind == "tiling": @@ -204,7 +204,7 @@ def _populate_strides(node, expansion_specification): it.stride = 1 -def _populate_storages(self, expansion_specification): +def _populate_storages(expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) innermost_axes = set(dcir.Axis.dims_3d()) tiled_axes = set() @@ -222,7 +222,7 @@ def _populate_storages(self, expansion_specification): tiled_axes.remove(it.axis) -def _populate_cpu_schedules(self, expansion_specification): +def _populate_cpu_schedules(expansion_specification: List[ExpansionItem]): is_outermost = True for es in expansion_specification: if isinstance(es, Map): @@ -234,7 +234,7 @@ def _populate_cpu_schedules(self, expansion_specification): es.schedule = dace.ScheduleType.Default -def _populate_gpu_schedules(self, expansion_specification): +def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]): # On GPU if any dimension is tiled and has a contiguous map in the same axis further in # pick those two maps as Device/ThreadBlock maps. 
If not, Make just device map with # default blocksizes @@ -267,16 +267,16 @@ def _populate_gpu_schedules(self, expansion_specification): es.schedule = dace.ScheduleType.Default -def _populate_schedules(self, expansion_specification): +def _populate_schedules(node: StencilComputation, expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - _populate_gpu_schedules(self, expansion_specification) + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + _populate_gpu_schedules(expansion_specification) else: - _populate_cpu_schedules(self, expansion_specification) + _populate_cpu_schedules(expansion_specification) -def _collapse_maps_gpu(self, expansion_specification): +def _collapse_maps_gpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: def _union_map_items(last_item, next_item): if last_item.schedule == next_item.schedule: return ( @@ -307,7 +307,7 @@ def _union_map_items(last_item, next_item): ), ) - res_items = [] + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if not res_items or not isinstance(res_items[-1], Map): @@ -324,8 +324,8 @@ def _union_map_items(last_item, next_item): return res_items -def _collapse_maps_cpu(self, expansion_specification): - res_items = [] +def _collapse_maps_cpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if ( @@ -360,12 +360,12 @@ def _collapse_maps_cpu(self, expansion_specification): return res_items -def _collapse_maps(self, expansion_specification): - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - res_items = _collapse_maps_gpu(self, expansion_specification) +def _collapse_maps(node: StencilComputation, expansion_specification: 
List[ExpansionItem]): + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + res_items = _collapse_maps_gpu(expansion_specification) else: - res_items = _collapse_maps_cpu(self, expansion_specification) + res_items = _collapse_maps_cpu(expansion_specification) expansion_specification.clear() expansion_specification.extend(res_items) @@ -387,7 +387,7 @@ def make_expansion_order( _populate_strides(node, expansion_specification) _populate_schedules(node, expansion_specification) _collapse_maps(node, expansion_specification) - _populate_storages(node, expansion_specification) + _populate_storages(expansion_specification) return expansion_specification diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index f12c13cd0e..14448bb08e 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -123,6 +123,7 @@ def visit_VerticalLoop( state.add_edge( access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset) ) + for field in access_collection.write_fields(): access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) library_node.add_out_connector("__out_" + field) @@ -131,8 +132,6 @@ def visit_VerticalLoop( library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset) ) - return - def visit_Stencil(self, node: oir.Stencil, **kwargs): ctx = OirSDFGBuilder.SDFGContext(stencil=node) for param in node.params: diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 517e80ceb3..bd65861a49 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -40,7 +40,7 @@ def array_dimensions(array: dace.data.Array): return dims -def replace_strides(arrays, get_layout_map): +def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]: symbol_mapping = {} for array in arrays: dims = array_dimensions(array) diff 
--git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index e51b3ef09d..4609184547 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -421,7 +421,7 @@ def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_] @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): - @gtscript.stencil(backend, externals={"mord": 5}) + @gtscript.stencil(backend) def stencil(outp: gtscript.Field[np.float_]): with computation(PARALLEL), interval(...): cond = True From 8b6abc22fe07da99157afc3a03d7c3911651bff8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Dec 2024 08:18:35 +0100 Subject: [PATCH 35/43] refactor[next]: remove use of Fencil in tracing (eliminate `closure`) (#1772) --- src/gt4py/next/iterator/embedded.py | 11 ++--- src/gt4py/next/iterator/runtime.py | 9 +--- src/gt4py/next/iterator/tracing.py | 44 +++---------------- .../program_processors/runners/roundtrip.py | 1 - .../iterator_tests/test_builtins.py | 4 +- tests/next_tests/unit_tests/conftest.py | 1 + .../iterator_tests/test_runtime_domain.py | 9 ++-- 7 files changed, 22 insertions(+), 57 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3c63ffef30..13c64e264e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1706,8 +1706,10 @@ def impl(*iters: ItIterator): return impl -def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: - return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} +def _dimension_to_tag( + domain: runtime.CartesianDomain | runtime.UnstructuredDomain, +) -> dict[Tag, range]: + return {k.value: v for k, v in domain.items()} def 
_validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: @@ -1828,7 +1830,7 @@ def impl(*args): # TODO(havogt): after updating all tests to use the new program, # we should get rid of closure and move the implementation to this function - closure(_dimension_to_tag(domain), fun, out, list(args)) + closure(domain, fun, out, list(args)) return out return impl @@ -1839,9 +1841,8 @@ def index(axis: common.Dimension) -> common.Field: return IndexField(axis) -@runtime.closure.register(EMBEDDED) def closure( - domain_: Domain, + domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, sten: Callable[..., Any], out, #: MutableLocatedField, ins: list[common.Field | Scalar | tuple[common.Field | Scalar | tuple, ...]], diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index d42f961202..e47a6886ad 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -26,7 +26,7 @@ # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend -__all__ = ["offset", "fundef", "fendef", "closure", "set_at", "if_stmt"] +__all__ = ["offset", "fundef", "fendef", "set_at", "if_stmt"] @dataclass(frozen=True) @@ -163,7 +163,7 @@ def impl(out, *inps): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider)) - closure(dom, self.fundef_dispatcher, out, [*inps]) + set_at(builtins.as_fieldop(self.fundef_dispatcher, dom)(*inps), dom, out) return impl @@ -208,11 +208,6 @@ def fundef(fun): return FundefDispatcher(fun) -@builtin_dispatch -def closure(*args): # TODO remove - return BackendNotSelectedError() - - @builtin_dispatch def set_at(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 6772d4b507..81e9551e5c 100644 
--- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -23,7 +23,6 @@ Lambda, NoneLiteral, OffsetLiteral, - StencilClosure, Sym, SymRef, ) @@ -202,9 +201,6 @@ def __bool__(self): class TracerContext: fundefs: ClassVar[List[FunctionDefinition]] = [] - closures: ClassVar[ - List[StencilClosure] - ] = [] # TODO(havogt): remove after refactoring to `Program` is complete, currently handles both programs and fencils body: ClassVar[List[itir.Stmt]] = [] @classmethod @@ -212,10 +208,6 @@ def add_fundef(cls, fun): if fun not in cls.fundefs: cls.fundefs.append(fun) - @classmethod - def add_closure(cls, closure): - cls.closures.append(closure) - @classmethod def add_stmt(cls, stmt): cls.body.append(stmt) @@ -225,23 +217,10 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): type(self).fundefs = [] - type(self).closures = [] type(self).body = [] iterator.builtins.builtin_dispatch.pop_key() -@iterator.runtime.closure.register(TRACING) -def closure(domain, stencil, output, inputs): - if hasattr(stencil, "__name__") and stencil.__name__ in iterator.builtins.__all__: - stencil = _s(stencil.__name__) - else: - stencil(*(_s(param) for param in inspect.signature(stencil).parameters)) - stencil = make_node(stencil) - TracerContext.add_closure( - StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - ) - - @iterator.runtime.set_at.register(TRACING) def set_at(expr: itir.Expr, domain: itir.Expr, target: itir.Expr) -> None: TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target)) @@ -328,19 +307,10 @@ def trace_fencil_definition( params = _make_fencil_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) - if TracerContext.closures: - return itir.FencilDefinition( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - closures=TracerContext.closures, - ) - else: - assert TracerContext.body - return itir.Program( - id=fun.__name__, 
- function_definitions=TracerContext.fundefs, - params=params, - declarations=[], # TODO - body=TracerContext.body, - ) + return itir.Program( + id=fun.__name__, + function_definitions=TracerContext.fundefs, + params=params, + declarations=[], # TODO + body=TracerContext.body, + ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 1dd568b95a..25eda5a2ed 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -46,7 +46,6 @@ class EmbeddedDSL(codegen.TemplatedGenerator): AxisLiteral = as_fmt("{value}") FunCall = as_fmt("{fun}({','.join(args)})") Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") - StencilClosure = as_mako("closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])") FunctionDefinition = as_mako( """ @fundef diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 5e3a2fcd14..c0a4cd166d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -18,6 +18,7 @@ from gt4py.next.iterator import builtins as it_builtins from gt4py.next.iterator.builtins import ( and_, + as_fieldop, bool, can_deref, cartesian_domain, @@ -45,9 +46,8 @@ plus, shift, xor_, - as_fieldop, ) -from gt4py.next.iterator.runtime import set_at, closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 99bc44efa7..8f6d5787d3 100644 --- 
a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -53,6 +53,7 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True), + (next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, True), (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 13e8637d1a..bf2df06bf2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -27,21 +27,20 @@ def foo(inp): dtype=None, ) +I = gtx.Dimension("I") + def test_deduce_domain(): assert isinstance(_deduce_domain({}, {}), CartesianDomain) assert isinstance(_deduce_domain(UnstructuredDomain(), {}), UnstructuredDomain) assert isinstance(_deduce_domain({}, {"foo": connectivity}), UnstructuredDomain) assert isinstance( - _deduce_domain(CartesianDomain([("I", range(1))]), {"foo": connectivity}), CartesianDomain + _deduce_domain(CartesianDomain([(I, range(1))]), {"foo": connectivity}), CartesianDomain ) -I = gtx.Dimension("I") - - def test_embedded_error_on_wrong_domain(): - dom = CartesianDomain([("I", range(1))]) + dom = CartesianDomain([(I, range(1))]) out = gtx.as_field([I], np.zeros(1)) with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): From 06813d54d9daec17bbac68aab32f6081c7f46b8e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Dec 2024 11:52:17 +0100 Subject: [PATCH 36/43] refactor[next]: remove all FencilDefinitions from tests (#1773) --- src/gt4py/next/iterator/ir.py | 10 +- .../iterator/transforms/symbol_ref_utils.py | 4 
+- .../ffront_tests/test_decorator.py | 6 +- .../iterator_tests/test_pretty_parser.py | 36 +---- .../iterator_tests/test_pretty_printer.py | 36 +---- .../iterator_tests/test_type_inference.py | 151 +++++++----------- .../transforms_tests/test_symbol_ref_utils.py | 23 ++- .../gtfn_tests/test_gtfn_module.py | 41 ++--- 8 files changed, 96 insertions(+), 211 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 7098e9fa2e..6efee29362 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -189,17 +189,11 @@ def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu "scan", "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } -# only used in `Program`` not `FencilDefinition` -# TODO(havogt): restructure after refactoring to GTIR -GTIR_BUILTINS = { - *BUILTINS, - "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) -} - class FencilDefinition(Node, ValidatedSymbolTableTrait): id: Coerced[SymbolName] @@ -243,7 +237,7 @@ class Program(Node, ValidatedSymbolTableTrait): implicit_domain: bool = False _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ - Sym(id=name) for name in sorted(GTIR_BUILTINS) + Sym(id=name) for name in sorted(BUILTINS) ] # sorted for serialization stability diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 05163a3630..1765259a81 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -140,6 +140,4 @@ def collect_symbol_refs( def get_user_defined_symbols(symtable: dict[eve.SymbolName, itir.Sym]) -> set[str]: 
- return {str(sym) for sym in symtable.keys()} - { - str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_ - } + return {str(sym) for sym in symtable.keys()} - {str(n.id) for n in itir.Program._NODE_SYMBOLS_} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 47419c278b..45bf7428a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -30,10 +30,8 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, (itir.Program, itir.FencilDefinition)) - assert isinstance( - testee.with_backend(cartesian_case.backend).itir, (itir.Program, itir.FencilDefinition) - ) + assert isinstance(testee.itir, itir.Program) + assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.Program) def test_frozen(cartesian_case): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index da4bea8874..bf47f997d6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -7,8 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.type_system import type_specifications as ts @@ -208,18 +208,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = "y ← (deref)(x) @ cartesian_domain();" - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - 
stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - actual = pparse(testee) - assert actual == expected - - def test_set_at(): testee = "y @ cartesian_domain() ← x;" expected = ir.SetAt( @@ -262,28 +250,6 @@ def test_if_stmt(): assert actual == expected -# TODO(havogt): remove after refactoring to GTIR -def test_fencil_definition(): - testee = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - expected = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pparse(testee) - assert actual == expected - - def test_program(): testee = "f(d, x, y) {\n g = λ(x) → x;\n tmp = temporary(domain=cartesian_domain(), dtype=float64);\n y @ cartesian_domain() ← x;\n}" expected = ir.Program( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 69a45cf128..11f50dbf6d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -7,8 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.type_system import type_specifications as ts @@ -313,18 +313,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - 
stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - expected = "y ← (deref)(x) @ cartesian_domain();" - actual = pformat(testee) - assert actual == expected - - def test_set_at(): testee = ir.SetAt( expr=ir.SymRef(id="x"), @@ -336,28 +324,6 @@ def test_set_at(): assert actual == expected -# TODO(havogt): remove after refactoring. -def test_fencil_definition(): - testee = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pformat(testee) - expected = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - assert actual == expected - - def test_program(): testee = ir.Program( id="f", diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 65a5b5888d..7eb4e86adb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -23,13 +23,12 @@ ) from gt4py.next.type_system import type_specifications as ts -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh - from next_tests.integration_tests.cases import ( C2E, E2V, V2E, E2VDim, + Edge, IDim, Ioff, JDim, @@ -37,11 +36,12 @@ Koff, V2EDim, Vertex, - Edge, - mesh_descriptor, exec_alloc_descriptor, + mesh_descriptor, unstructured_case, ) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh + bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) @@ -275,48 +275,35 @@ def 
test_cast_first_arg_inference(): assert result.type == float64_type -# TODO(tehrengruber): Rewrite tests to use itir.Program def test_cartesian_fencil_definition(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call(im.call("as_fieldop")(im.ref("deref"), cartesian_domain))( + im.ref("inp") + ), domain=cartesian_domain, - stencil=im.ref("deref"), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + domain_type = it_ts.DomainType(dims=[IDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_i_field + assert result.body[0].target.type == float_i_field def test_unstructured_fencil_definition(): @@ -326,44 +313,34 @@ def test_unstructured_fencil_definition(): im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", 
float_edge_k_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain + ) + )(im.ref("inp")), domain=unstructured_domain, - stencil=im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[Vertex, KDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[Vertex, KDim], - defined_dims=[Edge, KDim], - element_type=float64_type, - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_vertex_k_field, - inputs=[float_edge_k_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_edge_k_field, "out": float_vertex_k_field}, closures=[closure_type] + program_type = it_ts.ProgramType( + params={"inp": float_edge_k_field, "out": float_vertex_k_field} ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + assert result.type == program_type + domain_type = it_ts.DomainType(dims=[Vertex, KDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_vertex_k_field + assert result.body[0].target.type == float_vertex_k_field def test_function_definition(): @@ -371,45 +348,29 @@ def test_function_definition(): im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[ itir.FunctionDefinition(id="foo", params=[im.sym("it")], expr=im.deref("it")), itir.FunctionDefinition(id="bar", params=[im.sym("it")], expr=im.call("foo")("it")), ], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - 
itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=cartesian_domain, - stencil=im.ref("bar"), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.call(im.call("as_fieldop")(im.ref("bar"), cartesian_domain))(im.ref("inp")), + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + assert result.body[0].expr.type == float_i_field + assert result.body[0].target.type == float_i_field def test_fencil_with_nb_field_input(): @@ -419,24 +380,30 @@ def test_fencil_with_nb_field_input(): im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_vertex_v2e_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=unstructured_domain, - stencil=im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), + unstructured_domain, + ) + )(im.ref("inp")), + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, 
offset_provider_type=mesh.offset_provider_type) - assert result.closures[0].stencil.expr.args[0].type == float64_list_type - assert result.closures[0].stencil.type.returns == float64_type + stencil = result.body[0].expr.fun.args[0] + assert stencil.expr.args[0].type == float64_list_type + assert stencil.type.returns == float64_type def test_program_tuple_setat_short_target(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py index 0c118ff6dc..c162860c7c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py @@ -6,28 +6,23 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass -from typing import Optional -from gt4py import eve from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.symbol_ref_utils import ( - collect_symbol_refs, - get_user_defined_symbols, -) +from gt4py.next.iterator.transforms.symbol_ref_utils import get_user_defined_symbols def test_get_user_defined_symbols(): - ir = itir.FencilDefinition( + domain = itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]) + ir = itir.Program( id="foo", function_definitions=[], params=[itir.Sym(id="target_symbol")], - closures=[ - itir.StencilClosure( - domain=itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]), - stencil=itir.SymRef(id="deref"), - output=itir.SymRef(id="target_symbol"), - inputs=[], + declarations=[], + body=[ + itir.SetAt( + expr=itir.Lambda(params=[itir.Sym(id="foo")], expr=itir.SymRef(id="foo")), + domain=domain, + target=itir.SymRef(id="target_symbol"), ) ], ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py 
b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e64bd8a57d..0586d48703 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -6,11 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np -import pytest import copy -import diskcache +import diskcache +import numpy as np +import pytest import gt4py.next as gtx from gt4py.next.iterator import ir as itir @@ -19,18 +19,17 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation -from next_tests.integration_tests import cases -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import KDim +from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case - from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + KDim, exec_alloc_descriptor, ) @pytest.fixture -def fencil_example(): +def program_example(): IDim = gtx.Dimension("I") params = [gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14)] param_types = [type_translation.from_value(param) for param in params] @@ -48,7 +47,7 @@ def fencil_example(): ) ], ) - fencil = itir.FencilDefinition( + program = itir.Program( id="example", params=[im.sym(name, type_) for name, type_ in zip(("buf", "sc"), param_types)], function_definitions=[ @@ -58,20 +57,22 @@ def fencil_example(): expr=im.literal("1", "float32"), ) ], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call(im.call("as_fieldop")(itir.SymRef(id="stencil"), domain))( + itir.SymRef(id="buf"), itir.SymRef(id="sc") + ), domain=domain, - 
stencil=itir.SymRef(id="stencil"), - output=itir.SymRef(id="buf"), - inputs=[itir.SymRef(id="buf"), itir.SymRef(id="sc")], + target=itir.SymRef(id="buf"), ) ], ) - return fencil, params + return program, params -def test_codegen(fencil_example): - fencil, parameters = fencil_example +def test_codegen(program_example): + fencil, parameters = program_example module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, @@ -85,8 +86,8 @@ def test_codegen(fencil_example): assert module.language is languages.CPP -def test_hash_and_diskcache(fencil_example, tmp_path): - fencil, parameters = fencil_example +def test_hash_and_diskcache(program_example, tmp_path): + fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, args=arguments.CompileTimeArgs.from_concrete_no_size( @@ -129,8 +130,8 @@ def test_hash_and_diskcache(fencil_example, tmp_path): ) != gtfn.fingerprint_compilable_program(altered_program_column_axis) -def test_gtfn_file_cache(fencil_example): - fencil, parameters = fencil_example +def test_gtfn_file_cache(program_example): + fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, args=arguments.CompileTimeArgs.from_concrete_no_size( From 2c48858ff00f5f7ac2786f945bbf6bca60bfd4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:20:31 +0100 Subject: [PATCH 37/43] feat[dace]: Restirct Loop Blocking (#1775) Made it possible to disable loop blocking if there are no independent nodes. 
--- .../transformations/auto_optimize.py | 6 ++- .../transformations/loop_blocking.py | 24 ++++++++- .../test_loop_blocking.py | 49 +++++++++++++++++++ 3 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index bc1d21ca05..4a06d2f416 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -32,6 +32,7 @@ def gt_auto_optimize( gpu_block_size: Optional[Sequence[int | str] | str] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, + blocking_only_if_independent_nodes: Optional[bool] = None, reuse_transients: bool = False, gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, @@ -90,6 +91,9 @@ def gt_auto_optimize( one for all. blocking_dim: On which dimension blocking should be applied. blocking_size: How many elements each block should process. + blocking_only_if_independent_nodes: If `True` only apply loop blocking if + there are independent nodes in the Map, see the `require_independent_nodes` + option of the `LoopBlocking` transformation. reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` @@ -101,7 +105,6 @@ def gt_auto_optimize( validate: Perform validation during the steps. validate_all: Perform extensive validation. - Note: For identifying symbols that can be treated as compile time constants `gt_find_constant_arguments()` function can be used. 
@@ -227,6 +230,7 @@ def gt_auto_optimize( gtx_transformations.LoopBlocking( blocking_size=blocking_size, blocking_parameter=blocking_dim, + require_independent_nodes=blocking_only_if_independent_nodes, ), validate=validate, validate_all=validate_all, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index d401c06f15..27b6c68072 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -36,12 +36,16 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): What makes this transformation different from simple blocking, is that the inner map will not just be inserted right after the outer Map. Instead the transformation will first identify all nodes that does not depend - on the blocking parameter `I` and relocate them between the outer and inner map. - Thus these operations will only be performed once, per inner loop. + on the blocking parameter `I`, called independent nodes and relocate them + between the outer and inner map. Note that an independent node must be connected + to the MapEntry or another independent node. + Thus these operations will only be performed once, per outer loop iteration. Args: blocking_size: The size of the block, denoted as `B` above. blocking_parameter: On which parameter should we block. + require_independent_nodes: If `True` only apply loop blocking if the Map + actually contains independent nodes. Defaults to `False`. Todo: - Modify the inner map such that it always starts at zero. 
@@ -59,6 +63,12 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): desc="Name of the iteration variable on which to block (must be an exact match);" " 'I' in the above description.", ) + require_independent_nodes = dace_properties.Property( + dtype=bool, + default=False, + desc="If 'True' then blocking is only applied if there are independent nodes.", + ) + # Set of nodes that are independent of the blocking parameter. _independent_nodes: Optional[set[dace_nodes.AccessNode]] _dependent_nodes: Optional[set[dace_nodes.AccessNode]] @@ -69,6 +79,7 @@ def __init__( self, blocking_size: Optional[int] = None, blocking_parameter: Optional[Union[gtx_common.Dimension, str]] = None, + require_independent_nodes: Optional[bool] = None, ) -> None: super().__init__() if isinstance(blocking_parameter, gtx_common.Dimension): @@ -77,6 +88,8 @@ def __init__( self.blocking_parameter = blocking_parameter if blocking_size is not None: self.blocking_size = blocking_size + if require_independent_nodes is not None: + self.require_independent_nodes = require_independent_nodes self._independent_nodes = None self._dependent_nodes = None @@ -250,6 +263,9 @@ def partition_map_output( member variables are updated. If the partition does not exists the function will return `False` and the respective member variables will be `None`. + The function will honor `self.require_independent_nodes`. Thus if no independent + nodes were found the function behaves as if the partition does not exist. + Args: state: The state on which we operate. sdfg: The SDFG in which we operate on. @@ -295,6 +311,10 @@ def partition_map_output( if not found_new_independent_node: break + if self.require_independent_nodes and len(self._independent_nodes) == 0: + self._independent_nodes = None + return False + # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. 
self._dependent_nodes = { diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index aac58eb32c..67bec9c09f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -803,3 +803,52 @@ def test_loop_blocking_mixked_memlets_2(): assert isinstance(node, dace_nodes.MapEntry) or (node is mx) else: assert scope_dict[node] is inner_map_entry + + +def test_loop_blocking_no_independent_nodes(): + import dace + + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names = ["A", "B"] + for aname in names: + sdfg.add_array( + aname, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + state.add_mapped_tasklet( + "fully_dependent_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + + # Because there is nothing that is independent the transformation will + # not apply if `require_independent_nodes` is enabled. + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=True, + ), + validate=True, + validate_all=True, + ) + assert count == 0 + + # But it will apply once this requirement is lifted. 
+ count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=False, + ), + validate=True, + validate_all=True, + ) + assert count == 1 From 54f176f1e77536c4911d56ebaff35a53a7d37d6d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Dec 2024 15:21:19 +0100 Subject: [PATCH 38/43] refactor[next]: remove FencilDefinition definition (#1774) After this commit `FencilDefinition`s are completely removed. Next step could be to cleanup `itir` -> `gtir` everywhere. --- docs/user/next/advanced/HackTheToolchain.md | 2 +- src/gt4py/next/backend.py | 18 +- src/gt4py/next/ffront/decorator.py | 5 +- src/gt4py/next/ffront/foast_to_itir.py | 512 --------------- src/gt4py/next/ffront/past_to_itir.py | 115 +--- src/gt4py/next/iterator/ir.py | 37 +- src/gt4py/next/iterator/pretty_parser.py | 21 - src/gt4py/next/iterator/pretty_printer.py | 41 -- src/gt4py/next/iterator/tracing.py | 12 +- .../next/iterator/transforms/__init__.py | 4 +- .../iterator/transforms/collapse_tuple.py | 2 +- src/gt4py/next/iterator/transforms/cse.py | 4 +- .../iterator/transforms/fencil_to_program.py | 31 - .../next/iterator/transforms/pass_manager.py | 22 +- .../iterator/transforms/program_to_fencil.py | 31 - .../transforms/prune_closure_inputs.py | 44 -- .../iterator/transforms/symbol_ref_utils.py | 2 +- .../next/iterator/type_system/inference.py | 51 +- .../type_system/type_specifications.py | 24 - src/gt4py/next/otf/stages.py | 4 +- .../codegens/gtfn/gtfn_module.py | 14 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 2 +- .../program_processors/formatters/gtfn.py | 2 +- .../program_processors/formatters/lisp.py | 67 -- .../formatters/pretty_print.py | 2 +- .../program_processors/program_formatter.py | 8 +- .../runners/dace_fieldview/workflow.py | 2 +- .../next/program_processors/runners/gtfn.py | 2 +- .../program_processors/runners/roundtrip.py | 10 +- .../ffront_tests/test_decorator.py | 6 +- 
.../test_temporaries_with_sizes.py | 6 +- tests/next_tests/unit_tests/conftest.py | 1 - .../ffront_tests/test_foast_to_itir.py | 598 ------------------ .../ffront_tests/test_past_to_gtir.py | 11 +- .../ffront_tests/test_past_to_itir.py | 214 ------- .../transforms_tests/test_domain_inference.py | 21 +- .../test_prune_closure_inputs.py | 68 -- .../dace_tests/test_gtir_to_sdfg.py | 1 + 38 files changed, 91 insertions(+), 1926 deletions(-) delete mode 100644 src/gt4py/next/ffront/foast_to_itir.py delete mode 100644 src/gt4py/next/iterator/transforms/fencil_to_program.py delete mode 100644 src/gt4py/next/iterator/transforms/program_to_fencil.py delete mode 100644 src/gt4py/next/iterator/transforms/prune_closure_inputs.py delete mode 100644 src/gt4py/next/program_processors/formatters/lisp.py delete mode 100644 tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py delete mode 100644 tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py delete mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 029833cb7d..358f6e8d0d 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -15,7 +15,7 @@ from gt4py import eve ```python cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace( - past_to_itir=gtx.ffront.past_to_itir.past_to_itir_factory(cached=False) + past_to_itir=gtx.ffront.past_to_itir.past_to_gtir_factory(cached=False) ) ``` diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e223d7771c..e075422ca3 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -16,7 +16,6 @@ from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( foast_to_gtir, - foast_to_itir, foast_to_past, func_to_foast, func_to_past, @@ -41,7 +40,7 @@ ARGS: typing.TypeAlias = arguments.JITArgs CARG: 
typing.TypeAlias = arguments.CompileTimeArgs -IT_PRG: typing.TypeAlias = itir.FencilDefinition | itir.Program +IT_PRG: typing.TypeAlias = itir.Program INPUT_DATA: typing.TypeAlias = DSL_FOP | FOP | DSL_PRG | PRG | IT_PRG @@ -93,7 +92,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) past_to_itir: workflow.Workflow[AOT_PRG, stages.CompilableProgram] = dataclasses.field( - default_factory=past_to_itir.past_to_itir_factory + default_factory=past_to_itir.past_to_gtir_factory ) def step_order(self, inp: INPUT_PAIR) -> list[str]: @@ -126,7 +125,7 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: ) case PRG(): steps.extend(["past_lint", "field_view_prog_args_transform", "past_to_itir"]) - case itir.FencilDefinition() | itir.Program(): + case itir.Program(): pass case _: raise ValueError("Unexpected input.") @@ -135,17 +134,6 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() -# FIXME[#1582](havogt): remove after refactoring to GTIR -# note: this step is deliberately placed here, such that the cache is shared -_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) -LEGACY_TRANSFORMS: Transforms = Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), - foast_to_itir=_foast_to_itir_step, - field_view_op_to_prog=foast_to_past.operator_to_program_factory( - foast_to_itir_step=_foast_to_itir_step - ), -) - # TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. 
Maybe: # `Backend` -> `Toolchain` diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 61756f30c9..d187095019 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -34,7 +34,6 @@ from gt4py.next.ffront import ( field_operator_ast as foast, foast_to_gtir, - foast_to_itir, past_process_args, signature, stages as ffront_stages, @@ -186,7 +185,7 @@ def _all_closure_vars(self) -> dict[str, Any]: return transform_utils._get_closure_vars_recursively(self.past_stage.closure_vars) @functools.cached_property - def itir(self) -> itir.FencilDefinition: + def gtir(self) -> itir.Program: no_args_past = toolchain.CompilableProgram( data=ffront_stages.PastProgramDefinition( past_node=self.past_stage.past_node, @@ -561,7 +560,7 @@ def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: # a different backend than the one of the program that calls this field operator. Just use # the hard-coded lowering until this is cleaned up. def __gt_itir__(self) -> itir.FunctionDefinition: - return foast_to_itir.foast_to_itir(self.foast_stage) + return foast_to_gtir.foast_to_gtir(self.foast_stage) # FIXME[#1582](tehrengruber): remove after refactoring to GTIR def __gt_gtir__(self) -> itir.FunctionDefinition: diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py deleted file mode 100644 index 538b0f3ddb..0000000000 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ /dev/null @@ -1,512 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -# FIXME[#1582](havogt): remove after refactoring to GTIR - -import dataclasses -from typing import Any, Callable, Optional - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Never -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.ffront import ( - dialect_ast_enums, - fbuiltins, - field_operator_ast as foast, - lowering_utils, - stages as ffront_stages, - type_specifications as ts_ffront, -) -from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES -from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.ffront.stages import AOT_FOP, FOP -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts - - -def foast_to_itir(inp: FOP) -> itir.Expr: - """ - Lower a FOAST field operator node to Iterator IR. - - See the docstring of `FieldOperatorLowering` for details. 
- """ - return FieldOperatorLowering.apply(inp.foast_node) - - -def foast_to_itir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Expr]: - """Wrap `foast_to_itir` into a chainable and, optionally, cached workflow step.""" - wf = foast_to_itir - if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) - return wf - - -def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, itir.Expr]: - """Wrap the `foast_to_itir` workflow step into an adapter to fit into backend transform workflows.""" - return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) - - -def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node_type): - return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) - return lambda x: x - - -@dataclasses.dataclass -class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): - """ - Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). - - The strategy is to lower every expression to lifted stencils, - i.e. taking iterators and returning iterator. - - Examples - -------- - >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser - >>> from gt4py.next import Field, Dimension, float64 - >>> - >>> IDim = Dimension("IDim") - >>> def fieldop(inp: Field[[IDim], "float64"]): - ... 
return inp - >>> - >>> parsed = FieldOperatorParser.apply_to_function(fieldop) - >>> lowered = FieldOperatorLowering.apply(parsed) - >>> type(lowered) - - >>> lowered.id - SymbolName('fieldop') - >>> lowered.params # doctest: +ELLIPSIS - [Sym(id=SymbolName('inp'))] - """ - - uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) - - @classmethod - def apply(cls, node: foast.LocatedNode) -> itir.Expr: - return cls().visit(node) - - def visit_FunctionDefinition( - self, node: foast.FunctionDefinition, **kwargs: Any - ) -> itir.FunctionDefinition: - params = self.visit(node.params) - return itir.FunctionDefinition( - id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) - ) # `expr` is a lifted stencil - - def visit_FieldOperator( - self, node: foast.FieldOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - - new_body = func_definition.expr - - return itir.FunctionDefinition( - id=func_definition.id, params=func_definition.params, expr=new_body - ) - - def visit_ScanOperator( - self, node: foast.ScanOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - # note: we don't need the axis here as this is handled by the program - # decorator - assert isinstance(node.type, ts_ffront.ScanOperatorType) - - # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. - # In iterator IR we didn't properly specify if this is legal, - # however after lift-inlining the expressions are transformed back to literals. 
- forward = im.deref(self.visit(node.forward, **kwargs)) - init = lowering_utils.process_elements( - im.deref, self.visit(node.init, **kwargs), node.init.type - ) - - # lower definition function - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = im.let( - func_definition.params[0].id, - # promote carry to iterator of tuples - # (this is the only place in the lowering were a variable is captured in a lifted lambda) - lowering_utils.to_tuples_of_iterator( - im.promote_to_const_iterator(func_definition.params[0].id), - [*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 [unnecessary-iterable-allocation-for-first-element] - ), - )( - # the function itself returns a tuple of iterators, deref element-wise - lowering_utils.process_elements( - im.deref, func_definition.expr, node.type.definition.returns - ) - ) - - stencil_args: list[itir.Expr] = [] - assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args - for param, arg_type in zip( - func_definition.params[1:], - [*node.type.definition.pos_or_kw_args.values()][1:], - strict=True, - ): - if isinstance(arg_type, ts.TupleType): - # convert into iterator of tuples - stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) - - new_body = im.let( - param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) - )(new_body) - else: - stencil_args.append(im.ref(param.id)) - - definition = itir.Lambda(params=func_definition.params, expr=new_body) - - body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) - - return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) - - def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: - raise AssertionError("Statements must always be visited in the context of a function.") - - def visit_Return( - self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return self.visit(node.value, 
**kwargs) - - def visit_BlockStmt( - self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - for stmt in reversed(node.stmts): - inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) - assert inner_expr - return inner_expr - - def visit_IfStmt( - self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) - assert ( - isinstance(node.condition.type, ts.ScalarType) - and node.condition.type.kind == ts.ScalarKind.BOOL - ) - - cond = self.visit(node.condition, **kwargs) - - return_kind: StmtReturnKind = deduce_stmt_return_kind(node) - - common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols - - if return_kind is StmtReturnKind.NO_RETURN: - # pack the common symbols into a tuple - common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) - - # apply both branches and extract the common symbols through the prepared tuple - true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) - - # unpack the common symbols' tuple for `inner_expr` - for i, sym in enumerate(common_symbols.keys()): - inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) - - # here we assume neither branch returns - return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( - inner_expr - ) - elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: - common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) - common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) - - # wrap the inner expression in a lambda function. note that this increases the - # operation count if both branches are evaluated. 
- inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") - inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) - inner_expr = im.call(inner_expr_name)(*common_symrefs) - - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.let(inner_expr_name, inner_expr_evaluator)( - im.if_(im.deref(cond), true_branch, false_branch) - ) - - assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN - - # note that we do not duplicate `inner_expr` here since if both branches - # return, `inner_expr` is ignored. - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.if_(im.deref(cond), true_branch, false_branch) - - def visit_Assign( - self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( - inner_expr - ) - - def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: - return im.sym(node.id) - - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: - return im.ref(node.id) - - def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: - return im.tuple_get(node.index, self.visit(node.value, **kwargs)) - - def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: - return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - - def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: - # TODO(tehrengruber): extend iterator ir to support unary operators - dtype = type_info.extract_dtype(node.type) - if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: - if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - 
return self._lower_and_map("not_", node.operand) - - return self._lower_and_map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) - - def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: - op = "if_" - args = (node.condition, node.true_expr, node.false_expr) - lowered_args: list[itir.Expr] = [ - lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) - for arg in args - ] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [ - promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args) - ] - op = im.call("map_")(op) - - return lowering_utils.to_tuples_of_iterator( - im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type - ) - - def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - current_expr = self.visit(node.func, **kwargs) - - for arg in node.args: - match arg: - # `field(Off[idx])` - case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): - current_expr = im.lift( - im.lambda_("it")(im.deref(im.shift(offset_name, offset_index)("it"))) - )(current_expr) - # `field(Dim + idx)` - case foast.BinOp( - op=dialect_ast_enums.BinaryOperator.ADD - | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name(id=dimension), - right=foast.Constant(value=offset_index), - ): - if arg.op == dialect_ast_enums.BinaryOperator.SUB: - offset_index *= -1 - current_expr = im.lift( - # TODO(SF-N): we rely on the naming-convention that the cartesian dimensions - # are passed suffixed with `off`, e.g. the `K` is passed as `Koff` in the - # offset provider. 
This is a rather unclean solution and should be - # improved. - im.lambda_("it")( - im.deref( - im.shift( - common.dimension_to_implicit_offset(dimension), offset_index - )("it") - ) - ) - )(current_expr) - # `field(Off)` - case foast.Name(id=offset_name): - # only a single unstructured shift is supported so returning here is fine even though we - # are in a loop. - assert len(node.args) == 1 and len(arg.type.target) > 1 # type: ignore[attr-defined] # ensured by pattern - return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - # `field(as_offset(Off, offset_field))` - case foast.Call(func=foast.Name(id="as_offset")): - func_args = arg - # TODO(tehrengruber): Use type system to deduce the offset dimension instead of - # (e.g. to allow aliasing) - offset_dim = func_args.args[0] - assert isinstance(offset_dim, foast.Name) - offset_it = self.visit(func_args.args[1], **kwargs) - current_expr = im.lift( - im.lambda_("it", "offset")( - im.deref(im.shift(offset_dim.id, im.deref("offset"))("it")) - ) - )(current_expr, offset_it) - case _: - raise FieldOperatorLoweringError("Unexpected shift arguments!") - - return current_expr - - def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - if type_info.type_class(node.func.type) is ts.FieldType: - return self._visit_shift(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: - return self._visit_math_built_in(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in ( - FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES - ): - visitor = getattr(self, f"_visit_{node.func.id}") - return visitor(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: - return self._visit_type_constr(node, **kwargs) - elif isinstance( - node.func.type, - (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), - ): - # ITIR has no support for keyword arguments. 
Instead, we concatenate both positional - # and keyword arguments and use the unique order as given in the function signature. - lowered_args, lowered_kwargs = type_info.canonicalize_arguments( - node.func.type, - self.visit(node.args, **kwargs), - self.visit(node.kwargs, **kwargs), - use_signature_ordering=True, - ) - result = im.call(self.visit(node.func, **kwargs))( - *lowered_args, *lowered_kwargs.values() - ) - - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - result = lowering_utils.to_tuples_of_iterator( - result, node.func.type.definition.returns - ) - - return result - - raise AssertionError( - f"Call to object of type '{type(node.func.type).__name__}' not understood." - ) - - def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, new_type = node.args[0], node.args[1].id - return lowering_utils.process_elements( - lambda x: im.promote_to_lifted_stencil( - im.lambda_("it")(im.call("cast_")("it", str(new_type))) - )(x), - self.visit(obj, **kwargs), - obj.type, - ) - - def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - condition, true_value, false_value = node.args - - lowered_condition = self.visit(condition, **kwargs) - return lowering_utils.process_elements( - lambda tv, fv, types: _map( - "if_", (lowered_condition, tv, fv), (condition.type, *types) - ), - [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], - node.type, - (node.args[1].type, node.args[2].type), - ) - - _visit_concat_where = _visit_where - - def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) - - def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) - - def _make_reduction_expr( - self, node: 
foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any - ) -> itir.Expr: - # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) - it = self.visit(node.args[0], **kwargs) - assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, im.deref(init_expr))) - return im.promote_to_lifted_stencil(val)(it) - - def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) - - def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - min_value, _ = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(min_value), dtype) - return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) - - def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - _, max_value = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(max_value), dtype) - return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) - - def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - el = node.args[0] - node_kind = self.visit(node.type).kind.name.lower() - source_type = {**fbuiltins.BUILTINS, "string": str}[el.type.__str__().lower()] - target_type = fbuiltins.BUILTINS[node_kind] - - if isinstance(el, foast.Constant): - val = source_type(el.value) - elif isinstance(el, foast.UnaryOp) and isinstance(el.operand, foast.Constant): - operand = source_type(el.operand.value) - val = eval(f"lambda arg: {el.op}arg")(operand) - else: - raise FieldOperatorLoweringError( - f"Type cast only supports literal arguments, {node.type} not supported." 
- ) - val = target_type(val) - - return im.promote_to_const_iterator(im.literal(str(val), node_kind)) - - def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: - # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; - # the following constructs work if they are removed by inlining. - if isinstance(type_, ts.TupleType): - return im.make_tuple( - *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) - ) - elif isinstance(type_, ts.ScalarType): - typename = type_.kind.name.lower() - return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type '{type_}'.") - - def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: - return self._make_literal(node.value, node.type) - - def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - return _map( - op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) - ) - - -def _map( - op: itir.Expr | str, - lowered_args: tuple, - original_arg_types: tuple[ts.TypeSpec, ...], -) -> itir.FunCall: - """ - Mapping includes making the operation an lifted stencil (first kind of mapping), but also `itir.map_`ing lists. - """ - if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): - lowered_args = tuple( - promote_to_list(arg_type)(larg) - for arg_type, larg in zip(original_arg_types, lowered_args) - ) - op = im.call("map_")(op) - - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - - -class FieldOperatorLoweringError(Exception): ... 
diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c0348bb5c6..4ec12bb76b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import functools from typing import Any, Optional, cast import devtools @@ -19,7 +18,6 @@ from gt4py.next.ffront import ( fbuiltins, gtcallable, - lowering_utils, program_ast as past, stages as ffront_stages, transform_utils, @@ -32,10 +30,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove # dependency as soon as column axis can be deduced from ITIR in consumers of the CompilableProgram. -def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgram: +def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: """ Lower a PAST program definition to Iterator IR. @@ -59,7 +56,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ... column_axis=None, ... ) - >>> itir_copy = past_to_itir( + >>> itir_copy = past_to_gtir( ... toolchain.CompilableProgram(copy_program.past_stage, compile_time_args) ... ) @@ -67,7 +64,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra copy_program >>> print(type(itir_copy.data)) - + """ all_closure_vars = transform_utils._get_closure_vars_recursively(inp.data.closure_vars) offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( @@ -88,13 +85,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra # making this step aware of the toolchain it is called by (it can be part of multiple). 
lowered_funcs = [] for gt_callable in gt_callables: - if to_gtir: - lowered_funcs.append(gt_callable.__gt_gtir__()) - else: - lowered_funcs.append(gt_callable.__gt_itir__()) + lowered_funcs.append(gt_callable.__gt_gtir__()) itir_program = ProgramLowering.apply( - inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir + inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) if config.DEBUG or inp.data.debug: @@ -106,11 +100,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ) -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR -def past_to_itir_factory( - cached: bool = True, to_gtir: bool = True +def past_to_gtir_factory( + cached: bool = True, ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: - wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) + wf = workflow.make_step(past_to_gtir) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) return wf @@ -190,7 +183,7 @@ class ProgramLowering( ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN ... ) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP - + >>> lowered.id # doctest: +SKIP SymbolName('program') >>> lowered.params # doctest: +SKIP @@ -198,7 +191,6 @@ class ProgramLowering( """ grid_type: common.GridType - to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons # the above doctest fails when executed using `pytest --doctest-modules`. 
@@ -209,11 +201,8 @@ def apply( node: past.Program, function_definitions: list[itir.FunctionDefinition], grid_type: common.GridType, - to_gtir: bool = False, # FIXME[#1582](havogt): remove after refactoring to GTIR - ) -> itir.FencilDefinition: - return cls(grid_type=grid_type, to_gtir=to_gtir).visit( - node, function_definitions=function_definitions - ) + ) -> itir.Program: + return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: """Generate symbols for each field param and dimension.""" @@ -246,7 +235,7 @@ def visit_Program( *, function_definitions: list[itir.FunctionDefinition], **kwargs: Any, - ) -> itir.FencilDefinition | itir.Program: + ) -> itir.Program: # The ITIR does not support dynamically getting the size of a field. As # a workaround we add additional arguments to the fencil definition # containing the size of all fields. The caller of a program is (e.g. @@ -259,27 +248,17 @@ def visit_Program( params = params + self._gen_size_params_from_program(node) implicit_domain = True - if self.to_gtir: - set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body] - return itir.Program( - id=node.id, - function_definitions=function_definitions, - params=params, - declarations=[], - body=set_ats, - implicit_domain=implicit_domain, - ) - else: - closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body] - return itir.FencilDefinition( - id=node.id, - function_definitions=function_definitions, - params=params, - closures=closures, - implicit_domain=implicit_domain, - ) + set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] + return itir.Program( + id=node.id, + function_definitions=function_definitions, + params=params, + declarations=[], + body=set_ats, + implicit_domain=implicit_domain, + ) - def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> 
itir.SetAt: + def _visit_field_operator_call(self, node: past.Call, **kwargs: Any) -> itir.SetAt: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) @@ -303,56 +282,6 @@ def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir. target=output, ) - # FIXME[#1582](havogt): remove after refactoring to GTIR - def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: - assert isinstance(node.kwargs["out"].type, ts.TypeSpec) - assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) - - node_kwargs = {**node.kwargs} - domain = node_kwargs.pop("domain", None) - output, lowered_domain = self._visit_stencil_call_out_arg( - node_kwargs.pop("out"), domain, **kwargs - ) - - assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) - - args, node_kwargs = type_info.canonicalize_arguments( - node.func.type, node.args, node_kwargs, use_signature_ordering=True - ) - - lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) - - stencil_params = [] - stencil_args: list[itir.Expr] = [] - for i, arg in enumerate([*args, *node_kwargs]): - stencil_params.append(f"__stencil_arg{i}") - if isinstance(arg.type, ts.TupleType): - # convert into tuple of iterators - stencil_args.append( - lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) - ) - else: - stencil_args.append(im.ref(f"__stencil_arg{i}")) - - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # scan operators return an iterator of tuples, just deref directly - stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) - else: - # field operators return a tuple of iterators, deref element-wise - stencil_body = lowering_utils.process_elements( - im.deref, - im.call(node.func.id)(*stencil_args), - node.func.type.definition.returns, - ) - - return itir.StencilClosure( 
- domain=lowered_domain, - stencil=im.lambda_(*stencil_params)(stencil_body), - inputs=[*lowered_args, *lowered_kwargs.values()], - output=output, - location=node.location, - ) - def _visit_slice_bound( self, slice_bound: Optional[past.Constant], diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 6efee29362..e875709631 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -9,7 +9,7 @@ from typing import ClassVar, List, Optional, Union import gt4py.eve as eve -from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve import Coerced, SymbolName, SymbolRef from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @@ -19,10 +19,6 @@ DimensionKind = common.DimensionKind -# TODO(havogt): -# After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere. -# During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards. 
- @noninstantiable class Node(eve.Node): @@ -97,23 +93,6 @@ class FunctionDefinition(Node, SymbolTableTrait): expr: Expr -class StencilClosure(Node): - domain: FunCall - stencil: Expr - output: Union[SymRef, FunCall] - inputs: List[Union[SymRef, FunCall]] - - @datamodels.validator("output") - def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): - raise ValueError("Only FunCall to 'make_tuple' allowed.") - - @datamodels.validator("inputs") - def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value): - raise ValueError("Only FunCall to 'index' allowed.") - - UNARY_MATH_NUMBER_BUILTINS = {"abs"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -195,18 +174,6 @@ def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu } -class FencilDefinition(Node, ValidatedSymbolTableTrait): - id: Coerced[SymbolName] - function_definitions: List[FunctionDefinition] - params: List[Sym] - closures: List[StencilClosure] - implicit_domain: bool = False - - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ - Sym(id=name) for name in sorted(BUILTINS) - ] # sorted for serialization stability - - class Stmt(Node): ... 
@@ -252,8 +219,6 @@ class Program(Node, ValidatedSymbolTableTrait): Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign] FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign] FunctionDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] -StencilClosure.__hash__ = Node.__hash__ # type: ignore[method-assign] -FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index b4a673772f..29b30beae1 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -216,10 +216,6 @@ def function_definition(self, *args: ir.Node) -> ir.FunctionDefinition: fid, *params, expr = args return ir.FunctionDefinition(id=fid, params=params, expr=expr) - def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: - output, stencil, *inputs, domain = args - return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - def if_stmt(self, cond: ir.Expr, *args): found_else_seperator = False true_branch = [] @@ -249,23 +245,6 @@ def set_at(self, *args: ir.Expr) -> ir.SetAt: target, domain, expr = args return ir.SetAt(expr=expr, domain=domain, target=target) - # TODO(havogt): remove after refactoring. 
- def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: - params = [] - function_definitions = [] - closures = [] - for arg in args: - if isinstance(arg, ir.Sym): - params.append(arg) - elif isinstance(arg, ir.FunctionDefinition): - function_definitions.append(arg) - else: - assert isinstance(arg, ir.StencilClosure) - closures.append(arg) - return ir.FencilDefinition( - id=fid, function_definitions=function_definitions, params=params, closures=closures - ) - def program(self, fid: str, *args: ir.Node) -> ir.Program: params = [] function_definitions = [] diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 99287f8a11..a25f99356c 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -248,28 +248,6 @@ def visit_FunctionDefinition(self, node: ir.FunctionDefinition, prec: int) -> li vbody = self._vmerge(params, self._indent(expr)) return self._optimum(hbody, vbody) - def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[str]: - assert prec == 0 - domain = self.visit(node.domain, prec=0) - stencil = self.visit(node.stencil, prec=0) - output = self.visit(node.output, prec=0) - inputs = self.visit(node.inputs, prec=0) - - hinputs = self._hmerge(["("], *self._hinterleave(inputs, ", "), [")"]) - vinputs = self._vmerge(["("], *self._hinterleave(inputs, ",", indent=True), [")"]) - inputs = self._optimum(hinputs, vinputs) - - head = self._hmerge(output, [" ← "]) - foot = self._hmerge(inputs, [" @ "], domain, [";"]) - - h = self._hmerge(head, ["("], stencil, [")"], foot) - v = self._vmerge( - self._hmerge(head, ["("]), - self._indent(self._indent(stencil)), - self._indent(self._hmerge([")"], foot)), - ) - return self._optimum(h, v) - def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: start, end = [node.id + " = temporary("], [");"] args = [] @@ -312,25 +290,6 @@ def visit_IfStmt(self, node: ir.IfStmt, *, 
prec: int) -> list[str]: head, self._indent(true_branch), ["} else {"], self._indent(false_branch), ["}"] ) - def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: - assert prec == 0 - function_definitions = self.visit(node.function_definitions, prec=0) - closures = self.visit(node.closures, prec=0) - params = self.visit(node.params, prec=0) - - hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) - vparams = self._vmerge( - [node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"] - ) - params = self._optimum(hparams, vparams) - - function_definitions = self._vmerge(*function_definitions) - closures = self._vmerge(*closures) - - return self._vmerge( - params, self._indent(function_definitions), self._indent(closures), ["}"] - ) - def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 81e9551e5c..12c86680b5 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -258,7 +258,7 @@ def _contains_tuple_dtype_field(arg): return isinstance(arg, common.Field) and any(dim is None for dim in arg.domain.dims) -def _make_fencil_params(fun, args) -> list[Sym]: +def _make_program_params(fun, args) -> list[Sym]: params: list[Sym] = [] param_infos = list(inspect.signature(fun).parameters.values()) @@ -293,18 +293,16 @@ def _make_fencil_params(fun, args) -> list[Sym]: return params -def trace_fencil_definition( - fun: typing.Callable, args: typing.Iterable -) -> itir.FencilDefinition | itir.Program: +def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> itir.Program: """ - Transform fencil given as a callable into `itir.FencilDefinition` using tracing. + Transform fencil given as a callable into `itir.Program` using tracing. 
Arguments: - fun: The fencil / callable to trace. + fun: The program / callable to trace. args: A list of arguments, e.g. fields, scalars, composites thereof, or directly a type. """ with TracerContext() as _: - params = _make_fencil_params(fun, args) + params = _make_program_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) return itir.Program( diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index aeccb5f26d..d0afc610e7 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -7,10 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator.transforms.pass_manager import ( - ITIRTransform, + GTIRTransform, apply_common_transforms, apply_fieldview_transforms, ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "ITIRTransform"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "GTIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index e71a24127f..b64886f729 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -128,7 +128,7 @@ def apply( flags = flags or cls.flags offset_provider_type = offset_provider_type or {} - if isinstance(node, (ir.Program, ir.FencilDefinition)): + if isinstance(node, ir.Program): within_stencil = False assert within_stencil in [ True, diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 824adfdd8d..4f3fcbfdd5 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -376,7 +376,7 @@ def extract_subexpression( return _NodeReplacer(expr_map).visit(node), extracted, ignored_children -ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr) 
+ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr) @dataclasses.dataclass(frozen=True) @@ -413,7 +413,7 @@ def apply( within_stencil: bool | None = None, offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: - is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) + is_program = isinstance(node, itir.Program) if is_program: assert within_stencil is None within_stencil = False diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py deleted file mode 100644 index 4ad91645d4..0000000000 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py import eve -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im - - -class FencilToProgram(eve.NodeTranslator): - @classmethod - def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program: - return cls().visit(node) - - def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - as_fieldop = im.call(im.call("as_fieldop")(node.stencil, node.domain))(*node.inputs) - return itir.SetAt(expr=as_fieldop, domain=node.domain, target=node.output) - - def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: - return itir.Program( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - declarations=[], - body=self.visit(node.closures), - implicit_domain=node.implicit_domain, - ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec6f89685a..ec4207d726 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ 
b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,13 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Callable, Optional, Protocol +from typing import Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( - fencil_to_program, fuse_as_fieldop, global_tmps, infer_domain, @@ -32,16 +31,16 @@ from gt4py.next.iterator.type_system.inference import infer -class ITIRTransform(Protocol): +class GTIRTransform(Protocol): def __call__( - self, _: itir.Program | itir.FencilDefinition, *, offset_provider: common.OffsetProvider + self, _: itir.Program, *, offset_provider: common.OffsetProvider ) -> itir.Program: ... # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, *, offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, @@ -49,10 +48,6 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. 
symbolic_domain_sizes: Optional[dict[str, str]] = None, @@ -62,9 +57,6 @@ def apply_common_transforms( if offset_provider_type is None: offset_provider_type = common.offset_provider_to_type(offset_provider) - # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this - if isinstance(ir, itir.FencilDefinition): - ir = fencil_to_program.FencilToProgram.apply(ir) assert isinstance(ir, itir.Program) tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") @@ -73,7 +65,7 @@ def apply_common_transforms( ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program + ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = NormalizeShifts().visit(ir) # note: this increases the size of the tree @@ -82,7 +74,7 @@ def apply_common_transforms( # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( - ir, # type: ignore[arg-type] # always an itir.Program + ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) @@ -119,7 +111,7 @@ def apply_common_transforms( if extract_temporaries: ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/transforms/program_to_fencil.py b/src/gt4py/next/iterator/transforms/program_to_fencil.py deleted file 
mode 100644 index 4411dda74f..0000000000 --- a/src/gt4py/next/iterator/transforms/program_to_fencil.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - - -def program_to_fencil(node: itir.Program) -> itir.FencilDefinition: - assert not node.declarations - closures = [] - for stmt in node.body: - assert isinstance(stmt, itir.SetAt) - assert isinstance(stmt.expr, itir.FunCall) and cpm.is_call_to(stmt.expr.fun, "as_fieldop") - stencil, domain = stmt.expr.fun.args - inputs = stmt.expr.args - assert all(isinstance(inp, itir.SymRef) for inp in inputs) - closures.append( - itir.StencilClosure(domain=domain, stencil=stencil, output=stmt.target, inputs=inputs) - ) - - return itir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - closures=closures, - ) diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py deleted file mode 100644 index 5058a91216..0000000000 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir - - -class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): - """Removes all unused input arguments from a stencil closure.""" - - def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: - if not isinstance(node.stencil, ir.Lambda): - return node - - unused: set[str] = {p.id for p in node.stencil.params} - expr = self.visit(node.stencil.expr, unused=unused, shadowed=set[str]()) - params = [] - inputs = [] - for param, inp in zip(node.stencil.params, node.inputs): - if param.id not in unused: - params.append(param) - inputs.append(inp) - - return ir.StencilClosure( - domain=node.domain, - stencil=ir.Lambda(params=params, expr=expr), - output=node.output, - inputs=inputs, - ) - - def visit_SymRef(self, node: ir.SymRef, *, unused: set[str], shadowed: set[str]) -> ir.SymRef: - if node.id not in shadowed: - unused.discard(node.id) - return node - - def visit_Lambda(self, node: ir.Lambda, *, unused: set[str], shadowed: set[str]) -> ir.Lambda: - return self.generic_visit( - node, unused=unused, shadowed=shadowed | {p.id for p in node.params} - ) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1765259a81..2903201083 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -69,7 +69,7 @@ def apply( Counter({SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) """ if ignore_builtins: - inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} + inactive_refs = {str(n.id) for n in itir.Program._NODE_SYMBOLS_} else: inactive_refs = set() diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index ffca6cc7a7..1b980783fa 100644 --- 
a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -352,7 +352,7 @@ def apply( Preconditions: - All parameters in :class:`itir.Program` and :class:`itir.FencilDefinition` must have a type + All parameters in :class:`itir.Program` must have a type defined, as they are the starting point for type propagation. Design decisions: @@ -401,9 +401,9 @@ def apply( # parts of a program. node = SanitizeTypes().visit(node) - if isinstance(node, (itir.FencilDefinition, itir.Program)): + if isinstance(node, itir.Program): assert all(isinstance(param.type, ts.DataType) for param in node.params), ( - "All parameters in 'itir.Program' and 'itir.FencilDefinition' must have a type " + "All parameters in 'itir.Program' must have a type " "defined, as they are the starting point for type propagation.", ) @@ -460,20 +460,6 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return result - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.FencilType: - params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - - function_definitions: dict[str, type_synthesizer.TypeSynthesizer] = {} - for fun_def in node.function_definitions: - function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) - - closures = self.visit(node.closures, ctx=ctx | params | function_definitions) - return it_ts.FencilType(params=params, closures=closures) - def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: @@ -532,37 +518,6 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: and target_type.dtype == expr_type.dtype ) - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def 
visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: - domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) - inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) - output: ts.FieldType = self.visit(node.output, ctx=ctx) - - assert isinstance(domain, it_ts.DomainType) - for output_el in type_info.primitive_constituents(output): - assert isinstance(output_el, ts.FieldType) - - stencil_type_synthesizer = self.visit(node.stencil, ctx=ctx) - stencil_args = [ - type_synthesizer._convert_as_fieldop_input_to_iterator(domain, input_) - for input_ in inputs - ] - stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider_type=self.offset_provider_type - ) - - return it_ts.StencilClosureType( - domain=domain, - stencil=ts.FunctionType( - pos_only_args=stencil_args, - pos_or_kw_args={}, - kw_only_args={}, - returns=stencil_returns, - ), - output=output, - inputs=inputs, - ) - def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: assert ( node.value in self.dimensions diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index edb56f5659..eef8c75d0f 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -43,30 +43,6 @@ class IteratorType(ts.DataType, ts.CallableType): element_type: ts.DataType -@dataclasses.dataclass(frozen=True) -class StencilClosureType(ts.TypeSpec): - domain: DomainType - stencil: ts.FunctionType - output: ts.FieldType | ts.TupleType - inputs: list[ts.FieldType] - - def __post_init__(self): - # local import to avoid importing type_info from a type_specification module - from gt4py.next.type_system import type_info - - for i, el_type in enumerate(type_info.primitive_constituents(self.output)): - assert isinstance( - el_type, ts.FieldType - ), f"All constituent types must be field types, but the 
{i}-th element is of type '{el_type}'." - - -# TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere -@dataclasses.dataclass(frozen=True) -class FencilType(ts.TypeSpec): - params: dict[str, ts.DataType] - closures: list[StencilClosureType] - - @dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 85838d9c76..22326c7e87 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -26,9 +26,7 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) -CompilableProgram: TypeAlias = toolchain.CompilableProgram[ - itir.FencilDefinition | itir.Program, arguments.CompileTimeArgs -] +CompilableProgram: TypeAlias = toolchain.CompilableProgram[itir.Program, arguments.CompileTimeArgs] @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index f1649112a7..020b1f55ea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Callable, Final, Optional +from typing import Any, Final, Optional import factory import numpy as np @@ -53,9 +53,6 @@ class GTFNTranslationStep( use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -80,7 +77,7 @@ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSetting def _process_regular_arguments( self, - program: 
itir.FencilDefinition | itir.Program, + program: itir.Program, arg_types: tuple[ts.TypeSpec, ...], offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: @@ -157,7 +154,7 @@ def _process_connectivity_args( def _preprocess_program( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( @@ -167,7 +164,6 @@ def _preprocess_program( # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, ) new_program = apply_common_transforms( @@ -186,7 +182,7 @@ def _preprocess_program( def generate_stencil_source( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: @@ -214,7 +210,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data # handle regular parameters and arguments of the program (i.e. 
what the user defined in # the program) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index dc0012b041..d5b34fd5b9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -108,7 +108,7 @@ def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( - f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported." + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." ) return grid_types.pop() diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index db1242e2a4..5f32eaa2bb 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. gtfn_translation = gtfn.GTFNBackendFactory().executor.translation assert isinstance(gtfn_translation, GTFNTranslationStep) diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py deleted file mode 100644 index 0a8253595e..0000000000 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ /dev/null @@ -1,67 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from typing import Any - -from gt4py.eve.codegen import FormatTemplate as as_fmt, TemplatedGenerator -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import apply_common_transforms -from gt4py.next.program_processors import program_formatter - - -class ToLispLike(TemplatedGenerator): - Sym = as_fmt("{id}") - FunCall = as_fmt("({fun} {' '.join(args)})") - Literal = as_fmt("{value}") - OffsetLiteral = as_fmt("{value}") - SymRef = as_fmt("{id}") - StencilClosure = as_fmt( - """( - :domain {domain} - :stencil {stencil} - :output {output} - :inputs {' '.join(inputs)} - ) - """ - ) - FencilDefinition = as_fmt( - """ - ({' '.join(function_definitions)}) - (defen {id}({' '.join(params)}) - {''.join(closures)}) - """ - ) - FunctionDefinition = as_fmt( - """(defun {id}({' '.join(params)}) - {expr} - ) - -""" - ) - Lambda = as_fmt( - """(lambda ({' '.join(params)}) - {expr} - )""" - ) - - @classmethod - def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms(root, offset_provider=kwargs["offset_provider"]) - generated_code = super().apply(transformed, **kwargs) - try: - from yasi import indent_code - - indented = indent_code(generated_code, "--dialect lisp") - return "".join(indented["indented_code"]) - except ImportError: - return generated_code - - -@program_formatter.program_formatter -def format_lisp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return ToLispLike.apply(program, **kwargs) diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index f14ac5653f..cbf9fd1978 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_itir_and_check(program: itir.FencilDefinition, 
*args: Any, **kwargs: Any) -> str: +def format_itir_and_check(program: itir.Program, *args: Any, **kwargs: Any) -> str: pretty = pretty_printer.pformat(program) parsed = pretty_parser.pparse(pretty) assert parsed == program diff --git a/src/gt4py/next/program_processors/program_formatter.py b/src/gt4py/next/program_processors/program_formatter.py index f77e7f32ee..321c09668c 100644 --- a/src/gt4py/next/program_processors/program_formatter.py +++ b/src/gt4py/next/program_processors/program_formatter.py @@ -10,7 +10,7 @@ Interface for program processors. Program processors are functions which operate on a program paired with the input -arguments for the program. Programs are represented by an ``iterator.ir.itir.FencilDefinition`` +arguments for the program. Programs are represented by an ``iterator.ir.Program`` node. Program processors that execute the program with the given arguments (possibly by generating code along the way) are program executors. Those that generate any kind of string based on the program and (optionally) input values are program formatters. @@ -30,14 +30,14 @@ class ProgramFormatter: @abc.abstractmethod - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: ... + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: ... @dataclasses.dataclass(frozen=True) class WrappedProgramFormatter(ProgramFormatter): formatter: Callable[..., str] - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: return self.formatter(program, *args, **kwargs) @@ -47,7 +47,7 @@ def program_formatter(func: Callable[..., str]) -> ProgramFormatter: Examples: >>> @program_formatter - ... def format_foo(fencil: itir.FencilDefinition, *args, **kwargs) -> str: + ... def format_foo(fencil: itir.Program, *args, **kwargs) -> str: ... '''A very useless fencil formatter.''' ... 
return "foo" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 40d44f5ab0..a38a50d886 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -72,7 +72,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the GTIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data assert isinstance(program, itir.Program) sdfg = self.generate_sdfg( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 55f479c665..c0a9be9168 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -125,7 +125,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: Generates a unique hash string for a stencil source program representing the program, sorted offset_provider, and column_axis. 
""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 25eda5a2ed..32c3f7a360 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -90,11 +90,11 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, - transforms: itir_transforms.ITIRTransform, + transforms: itir_transforms.GTIRTransform, ) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -197,7 +197,7 @@ class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgr debug: Optional[bool] = None use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None - transforms: itir_transforms.ITIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` + transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug @@ -265,10 +265,10 @@ def decorated_fencil( gtir = next_backend.Backend( name="roundtrip_gtir", - executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # on purpose doesn't support `FencilDefintion` will resolve itself later... 
+ executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # don't understand why mypy complains allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), + past_to_itir=past_to_itir.past_to_gtir_factory(), foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), field_view_op_to_prog=foast_to_past.operator_to_program_factory( foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 45bf7428a6..9e80dba53b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -21,7 +21,7 @@ ) -def test_program_itir_regression(cartesian_case): +def test_program_gtir_regression(cartesian_case): @gtx.field_operator(backend=None) def testee_op(a: cases.IField) -> cases.IField: return a @@ -30,8 +30,8 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, itir.Program) - assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.Program) + assert isinstance(testee.gtir, itir.Program) + assert isinstance(testee.with_backend(cartesian_case.backend).gtir, itir.Program) def test_frozen(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 66c56c4827..7d2eec772c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -107,12 +107,12 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - itir_with_tmp = apply_common_transforms( - testee.itir, + gtir_with_tmp = apply_common_transforms( + testee.gtir, extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) params = ["num_vertices", "num_edges", "num_cells"] for param in params: - assert any([param == str(p) for p in itir_with_tmp.params]) + assert any([param == str(p) for p in gtir_with_tmp.params]) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 8f6d5787d3..03662f8dcc 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -58,7 +58,6 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation - (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), ], diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py deleted file mode 100644 index c102df9d57..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ /dev/null @@ -1,598 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -# TODO(tehrengruber): The style of the tests in this file is not optimal as a single change in the -# lowering can (and often does) make all of them fail. Once we have embedded field view we want to -# switch to executing the different cases here; once with a regular backend (i.e. including -# parsing) and then with embedded field view (i.e. no parsing). If the results match the lowering -# should be correct. - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -import gt4py.next as gtx -from gt4py.next import float32, float64, int32, int64, neighbor_sum -from gt4py.next.ffront import type_specifications as ts_ffront -from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering -from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts, type_translation -from gt4py.next.iterator.type_system import type_specifications as it_ts - - -IDim = gtx.Dimension("IDim") -Edge = gtx.Dimension("Edge") -Vertex = gtx.Dimension("Vertex") -Cell = gtx.Dimension("Cell") -V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. 
- - -def debug_itir(tree): - """Compare tree snippets while debugging.""" - from devtools import debug - - from gt4py.eve.codegen import format_python_source - from gt4py.next.program_processors import EmbeddedDSL - - debug(format_python_source(EmbeddedDSL.apply(tree))) - - -def test_copy(): - def copy_field(inp: gtx.Field[[TDim], float64]): - return inp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - assert lowered.id == "copy_field" - assert lowered.expr == im.ref("inp") - - -def test_scalar_arg(): - def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: - return alpha * bar - - parsed = FieldOperatorParser.apply_to_function(scalar_arg) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")( - "alpha", "bar" - ) # no difference to non-scalar arg - - assert lowered.expr == reference - - -def test_multicopy(): - def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1, inp2 - - parsed = FieldOperatorParser.apply_to_function(multicopy) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple("inp1", "inp2") - - assert lowered.expr == reference - - -def test_arithmetic(): - def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1 + inp2 - - parsed = FieldOperatorParser.apply_to_function(arithmetic) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") - - assert lowered.expr == reference - - -def test_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") - - assert 
lowered.expr == reference - - -def test_negative_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[-1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_temp_assignment(): - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp = inp - inp = tmp - tmp2 = inp - return tmp2 - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")( - im.let( - ssa.unique_name("inp", 0), - ssa.unique_name("tmp", 0), - )( - im.let( - ssa.unique_name("tmp2", 0), - ssa.unique_name("inp", 0), - )(ssa.unique_name("tmp2", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_unary_ops(): - def unary(inp: gtx.Field[[TDim], float64]): - tmp = +inp - tmp = -tmp - return tmp - - parsed = FieldOperatorParser.apply_to_function(unary) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("0", "float64")), "inp" - ), - )( - im.let( - ssa.unique_name("tmp", 1), - im.promote_to_lifted_stencil("minus")( - im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) - ), - )(ssa.unique_name("tmp", 1)) - ) - - assert lowered.expr == reference - - -@pytest.mark.parametrize("var, var_type", [("-1.0", "float64"), ("True", "bool")]) -def test_unary_op_type_conversion(var, var_type): - def unary_float(): - return float(-1) - - def unary_bool(): - return bool(-1) - - fun = unary_bool if var_type == "bool" else unary_float - parsed = FieldOperatorParser.apply_to_function(fun) - lowered = FieldOperatorLowering.apply(parsed) - reference = 
im.promote_to_const_iterator(im.literal(var, var_type)) - - assert lowered.expr == reference - - -def test_unpacking(): - """Unpacking assigns should get separated.""" - - def unpacking( - inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] - ) -> gtx.Field[[TDim], float64]: - tmp1, tmp2 = inp1, inp2 # noqa - return tmp1 - - parsed = FieldOperatorParser.apply_to_function(unpacking) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("inp1", "inp2") - tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") - tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") - - reference = im.let("__tuple_tmp_0", tuple_expr)( - im.let( - ssa.unique_name("tmp1", 0), - tuple_access_0, - )( - im.let( - ssa.unique_name("tmp2", 0), - tuple_access_1, - )(ssa.unique_name("tmp1", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_annotated_assignment(): - pytest.xfail("Annotated assignments are not properly supported at the moment.") - - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp: gtx.Field[[TDim], float64] = inp - return tmp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_call(): - # create something that appears to the lowering like a field operator. - # we could also create an actual field operator, but we want to avoid - # using such heavy constructs for testing the lowering. 
- field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) - identity = SimpleNamespace( - __gt_type__=lambda: ts_ffront.FieldOperatorType( - definition=ts.FunctionType( - pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type - ) - ) - ) - - def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: - return identity(inp) - - parsed = FieldOperatorParser.apply_to_function(call) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.call("identity")("inp") - - assert lowered.expr == reference - - -def test_temp_tuple(): - """Returning a temp tuple should work.""" - - def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): - tmp = a, b - return tmp - - parsed = FieldOperatorParser.apply_to_function(temp_tuple) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("a", "b") - reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_unary_not(): - def unary_not(cond: gtx.Field[[TDim], "bool"]): - return not cond - - parsed = FieldOperatorParser.apply_to_function(unary_not) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("not_")("cond") - - assert lowered.expr == reference - - -def test_binary_plus(): - def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a + b - - parsed = FieldOperatorParser.apply_to_function(plus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("a", "b") - - assert lowered.expr == reference - - -def test_add_scalar_literal_to_field(): - def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: - return 2.0 + a - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")( - 
im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" - ) - - assert lowered.expr == reference - - -def test_add_scalar_literals(): - def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: - tmp = int32(1) + int32("1") - return a + tmp - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - ), - )(im.promote_to_lifted_stencil("plus")("a", ssa.unique_name("tmp", 0))) - - assert lowered.expr == reference - - -def test_binary_mult(): - def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a * b - - parsed = FieldOperatorParser.apply_to_function(mult) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")("a", "b") - - assert lowered.expr == reference - - -def test_binary_minus(): - def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a - b - - parsed = FieldOperatorParser.apply_to_function(minus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("minus")("a", "b") - - assert lowered.expr == reference - - -def test_binary_div(): - def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a / b - - parsed = FieldOperatorParser.apply_to_function(division) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("divides")("a", "b") - - assert lowered.expr == reference - - -def test_binary_and(): - def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a & b - - parsed = FieldOperatorParser.apply_to_function(bit_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")("a", "b") - 
- assert lowered.expr == reference - - -def test_scalar_and(): - def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: - return a & False - - parsed = FieldOperatorParser.apply_to_function(scalar_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - "a", im.promote_to_const_iterator(im.literal("False", "bool")) - ) - - assert lowered.expr == reference - - -def test_binary_or(): - def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a | b - - parsed = FieldOperatorParser.apply_to_function(bit_or) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("or_")("a", "b") - - assert lowered.expr == reference - - -def test_compare_scalars(): - def comp_scalars() -> bool: - return 3 > 4 - - parsed = FieldOperatorParser.apply_to_function(comp_scalars) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")( - im.promote_to_const_iterator(im.literal("3", "int32")), - im.promote_to_const_iterator(im.literal("4", "int32")), - ) - - assert lowered.expr == reference - - -def test_compare_gt(): - def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a > b - - parsed = FieldOperatorParser.apply_to_function(comp_gt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")("a", "b") - - assert lowered.expr == reference - - -def test_compare_lt(): - def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a < b - - parsed = FieldOperatorParser.apply_to_function(comp_lt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("less")("a", "b") - - assert lowered.expr == reference - - -def test_compare_eq(): - def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): - return a == b - - parsed = FieldOperatorParser.apply_to_function(comp_eq) - 
lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("eq")("a", "b") - - assert lowered.expr == reference - - -def test_compare_chain(): - def compare_chain( - a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] - ) -> gtx.Field[[IDim], bool]: - return a > b > c - - parsed = FieldOperatorParser.apply_to_function(compare_chain) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - im.promote_to_lifted_stencil("greater")("a", "b"), - im.promote_to_lifted_stencil("greater")("b", "c"), - ) - - assert lowered.expr == reference - - -def test_reduction_lowering_simple(): - def reduction(edge_f: gtx.Field[[Edge], float64]): - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), - ) - ) - )(im.lifted_neighbors("V2E", "edge_f")) - - assert lowered.expr == reference - - -def test_reduction_lowering_expr(): - def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): - e1_nbh = e1(V2E) - return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( - im.promote_to_lifted_stencil("make_const_list")( - im.promote_to_const_iterator(im.literal("1.1", "float64")) - ), - im.promote_to_lifted_stencil(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), - ) - - reference = im.let( - ssa.unique_name("e1_nbh", 0), - im.lifted_neighbors("V2E", "e1"), - )( - im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref( - im.promote_to_const_iterator(im.literal(value="0", 
typename="float64")) - ), - ) - ) - )(mapped) - ) - - assert lowered.expr == reference - - -def test_builtin_int_constructors(): - def int_constrs() -> tuple[int32, int32, int64, int32, int64]: - return 1, int32(1), int64(1), int32("1"), int64("1") - - parsed = FieldOperatorParser.apply_to_function(int_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - ) - - assert lowered.expr == reference - - -def test_builtin_float_constructors(): - def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: - return ( - 0.1, - float(0.1), - float32(0.1), - float64(0.1), - float(".1"), - float32(".1"), - float64(".1"), - ) - - parsed = FieldOperatorParser.apply_to_function(float_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - ) - - assert lowered.expr == reference - - -def test_builtin_bool_constructors(): - def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: - return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") - - parsed = FieldOperatorParser.apply_to_function(bool_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - 
im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(0)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(5)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("True")), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("False")), "bool")), - ) - - assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index a6231c22a7..c813285bd0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -46,7 +46,6 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -93,7 +92,6 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -149,9 +147,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2[1:])) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail( @@ -166,9 +162,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2)) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, 
function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail @@ -194,7 +188,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): ProgramParser.apply_to_function(invalid_call_sig_program_def), function_definitions=[], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) assert exc_info.match("Invalid call to 'identity'") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py deleted file mode 100644 index fefd3c653b..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ /dev/null @@ -1,214 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import re - -import pytest - -import gt4py.eve as eve -import gt4py.next as gtx -from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import errors -from gt4py.next.ffront.func_to_past import ProgramParser -from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir -from gt4py.next.type_system import type_specifications as ts - -from next_tests.past_common_fixtures import ( - IDim, - copy_program_def, - copy_restrict_program_def, - float64, - identity_def, - invalid_call_sig_program_def, -) - - -@pytest.fixture -def itir_identity_fundef(): - return itir.FunctionDefinition( - id="identity", - params=[itir.Sym(id="x")], - expr=itir.FunCall(fun=itir.SymRef(id="deref"), args=[itir.SymRef(id="x")]), - ) - - -def test_copy_lowering(copy_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, 
id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), - P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), - ], - ) - ], - ), - stencil=P( - itir.Lambda, - params=[P(itir.Sym, id=eve.SymbolName("__stencil_arg0"))], - expr=P( - itir.FunCall, - fun=P( - itir.Lambda, - params=[P(itir.Sym)], - expr=P(itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("deref"))), - ), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("identity")), - args=[P(itir.SymRef, id=eve.SymbolRef("__stencil_arg0"))], - ) - ], - ), - ), - inputs=[P(itir.SymRef, id=eve.SymbolRef("in_field"))], - output=P(itir.SymRef, id=eve.SymbolRef("out")), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_restrict_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P( - itir.Literal, - value="1", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - P( - itir.Literal, - value="2", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, 
itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - ], - ) - ], - ), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_restrict_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2[1:])) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail( - reason="slicing is only allowed if all fields are sliced in the same way." -) # see ADR 10 -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2)) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail -def test_inout_prohibited(identity_def): - identity = gtx.field_operator(identity_def) - - def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): - identity(inout_field, out=inout_field) - - with pytest.raises( - ValueError, match=(r"Call to function with field as input and output not allowed.") - ): - ProgramLowering.apply( - ProgramParser.apply_to_function(inout_field_program), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - -def test_invalid_call_sig_program(invalid_call_sig_program_def): - with pytest.raises(errors.DSLError) as exc_info: - 
ProgramLowering.apply( - ProgramParser.apply_to_function(invalid_call_sig_program_def), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - assert exc_info.match("Invalid call to 'identity'") - # TODO(tehrengruber): re-enable again when call signature check doesn't return - # immediately after missing `out` argument - # assert ( - # re.search( - # "Function takes 1 arguments, but 2 were given.", exc_info.value.__cause__.args[0] - # ) - # is not None - # ) - assert ( - re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) - is not None - ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 817c06e8f0..2492fc446d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,21 +8,24 @@ # TODO(SF-N): test scan operator -import pytest +from typing import Iterable, Literal, Optional, Union + import numpy as np -from typing import Iterable, Optional, Literal, Union +import pytest from gt4py import eve -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next import constructors +from gt4py.next import common, constructors, utils +from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.iterator.ir_utils import domain_utils -from gt4py.next.common import Dimension -from gt4py.next import common -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next import utils +from gt4py.next.type_system import 
type_specifications as ts + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py deleted file mode 100644 index 407ccad924..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py +++ /dev/null @@ -1,68 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs - - -def test_simple(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected - - -def test_shadowing(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - 
output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index b1ba4ccf22..03b8e3bc15 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -36,6 +36,7 @@ from . import pytestmark + dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace_fieldview") From ae6296546d91f41e40451403c3560b1744d467cc Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 21:15:55 +0100 Subject: [PATCH 39/43] feat[next]: Inline dynamic shifts (#1738) Dynamic shifts are not supported in the domain inference. In order to make them work nonetheless this PR aggressively inlines all arguments to `as_fieldop` until they contain only references to `itir.Program` params. Additionally the domain inference is extended to tolerate such `as_fieldop` by introducing a special domain marker that signifies a domain is unknown. 
--------- Co-authored-by: Hannes Vogt Co-authored-by: Edoardo Paone --- .../iterator/transforms/fuse_as_fieldop.py | 209 ++++++++------ .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/infer_domain.py | 272 +++++++++++------- .../transforms/inline_dynamic_shifts.py | 73 +++++ .../next/iterator/transforms/pass_manager.py | 7 + tests/next_tests/definitions.py | 1 - .../test_inline_dynamic_shifts.py | 48 ++++ .../transforms_tests/test_domain_inference.py | 115 +++++--- 8 files changed, 492 insertions(+), 237 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 9076bf2d3f..e8a221b814 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -53,7 +53,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr) + type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -68,6 +68,107 @@ def _is_tuple_expr_of_literals(expr: itir.Expr): return isinstance(expr, itir.Literal) +def _inline_as_fieldop_arg( + arg: itir.Expr, *, uids: eve_utils.UIDGenerator +) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, 
inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + +def fuse_as_fieldop( + expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator +) -> itir.Expr: + assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + args: list[itir.Expr] = expr.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True): + if eligible: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif _is_tuple_expr_of_literals(arg): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = _inline_as_fieldop_arg(arg, 
uids=uids) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + # just a safety check if typing information is available + if arg.type and not isinstance(arg.type, ts.DeferredType): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + assert not isinstance(dtype, it_ts.ListType) + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = inline_lifts.InlineLifts().visit(new_node) + + type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) + + return new_node + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -98,38 +199,6 @@ class FuseAsFieldOp(eve.NodeTranslator): uids: eve_utils.UIDGenerator - def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: - assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) - - stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` - inner_args: list[itir.Expr] = arg.args - extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg - - stencil_params: list[itir.Sym] = [] - stencil_body: itir.Expr = stencil.expr - - for inner_param, inner_arg in 
zip(stencil.params, inner_args, strict=True): - if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) - extracted_args[inner_arg.id] = inner_arg - elif isinstance(inner_arg, itir.Literal): - # note: only literals, not all scalar expressions are required as it doesn't make sense - # for them to be computed per grid point. - stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( - stencil_body - ) - else: - # a scalar expression, a previously not inlined `as_fieldop` call or an opaque - # expression e.g. containing a tuple - stencil_params.append(inner_param) - new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") - extracted_args[new_outer_stencil_param] = inner_arg - - return im.lift(im.lambda_(*stencil_params)(stencil_body))( - *extracted_args.keys() - ), extracted_args - @classmethod def apply( cls, @@ -158,72 +227,26 @@ def visit_FunCall(self, node: itir.FunCall): if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): stencil: itir.Lambda = node.fun.args[0] - domain = node.fun.args[1] if len(node.fun.args) > 1 else None - - shifts = trace_shifts.trace_stencil(stencil) - args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil) - new_args: dict[str, itir.Expr] = {} - new_stencil_body: itir.Expr = stencil.expr - - for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + eligible_args = [] + for arg, arg_shifts in zip(args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = _is_tuple_expr_of_literals(arg) or ( - isinstance(arg, itir.FunCall) - and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - or cpm.is_call_to(arg, "if_") + eligible_args.append( + _is_tuple_expr_of_literals(arg) + or ( + isinstance(arg, itir.FunCall) + 
and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - if should_inline: - if cpm.is_applied_as_fieldop(arg): - pass - elif cpm.is_call_to(arg, "if_"): - # TODO(tehrengruber): revisit if we want to inline if_ - type_ = arg.type - arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ - elif _is_tuple_expr_of_literals(arg): - arg = im.op_as_fieldop(im.lambda_()(arg))() - else: - raise NotImplementedError() - - inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) - - new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) - - new_args = _merge_arguments(new_args, extracted_args) - else: - assert not isinstance(dtype, it_ts.ListType) - new_param: str - if isinstance( - arg, itir.SymRef - ): # use name from outer scope (optional, just to get a nice IR) - new_param = arg.id - new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) - else: - new_param = stencil_param.id - new_args = _merge_arguments(new_args, {new_param: arg}) - - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) - - # simplify stencil directly to keep the tree small - new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node - ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True - ) - new_node = inline_lifts.InlineLifts().visit(new_node) - - type_inference.copy_type(from_=node, to=new_node) - return new_node + return fuse_as_fieldop(node, eligible_args, uids=self.uids) return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a6d39883e3..334fb330d7 100644 --- 
a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -74,7 +74,7 @@ def _transform_by_pattern( # or a tuple thereof) # - one `SetAt` statement that materializes the expression into the temporary for tmp_sym, tmp_expr in extracted_fields.items(): - domain = tmp_expr.annex.domain + domain: infer_domain.DomainAccess = tmp_expr.annex.domain # TODO(tehrengruber): Implement. This happens when the expression is a combination # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are @@ -186,7 +186,7 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. """ - program = infer_domain.infer_program(program, offset_provider) + program = infer_domain.infer_program(program, offset_provider=offset_provider) program = type_inference.infer( program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 6852b47a7a..f26d3f9ec2 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,10 +10,10 @@ import itertools import typing -from typing import Callable, Optional, TypeAlias from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( @@ -25,8 +25,35 @@ from gt4py.next.utils import flatten_nested_tuple, tree_map -DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] -ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class DomainAccessDescriptor(eve.StrEnum): + """ + Descriptor for domains that could not be inferred. + """ + + # TODO(tehrengruber): Revisit this concept. 
It is strange that we don't have a descriptor + # `KNOWN`, but since we don't need it, it wasn't added. + + #: The access is unknown because of a dynamic shift whose extent is not known. + #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)` + UNKNOWN = "unknown" + #: The domain is never accessed. + #: E.g.: `{in_field1, in_field2}[0]` + NEVER = "never" + + +NonTupleDomainAccess: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor +#: The domain can also be a tuple of domains, usually this only occurs for scan operators returning +#: a tuple since other occurrences for tuples are removed before domain inference. This is +#: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just +#: fine to a tuple of a vertex and an edge domain. +DomainAccess: TypeAlias = NonTupleDomainAccess | tuple["DomainAccess", ...] +AccessedDomains: TypeAlias = dict[str, DomainAccess] + + +class InferenceOptions(typing.TypedDict): + offset_provider: common.OffsetProvider + symbolic_domain_sizes: Optional[dict[str, str]] + allow_uninferred: bool class DomainAnnexDebugger(eve.NodeVisitor): @@ -57,43 +84,58 @@ def _split_dict_by_key(pred: Callable, d: dict): # TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. 
-def _domain_union_with_none( - *domains: domain_utils.SymbolicDomain | None, -) -> domain_utils.SymbolicDomain | None: - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] +def _domain_union( + *domains: domain_utils.SymbolicDomain | DomainAccessDescriptor, +) -> domain_utils.SymbolicDomain | DomainAccessDescriptor: + if any(d == DomainAccessDescriptor.UNKNOWN for d in domains): + return DomainAccessDescriptor.UNKNOWN + + filtered_domains: list[domain_utils.SymbolicDomain] = [ + d # type: ignore[misc] # domain can never be unknown as these cases are filtered above + for d in domains + if d != DomainAccessDescriptor.NEVER + ] if len(filtered_domains) == 0: - return None + return DomainAccessDescriptor.NEVER return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: +def _canonicalize_domain_structure( + d1: DomainAccess, d2: DomainAccess +) -> tuple[DomainAccess, DomainAccess]: """ Given two domains or composites thereof, canonicalize their structure. If one of the arguments is a tuple the other one will be promoted to a tuple of same structure - unless it already is a tuple. Missing values are replaced by None, meaning no domain is - specified. + unless it already is a tuple. Missing values are filled by :ref:`DomainAccessDescriptor.NEVER`. >>> domain = im.domain(common.GridType.CARTESIAN, {}) >>> _canonicalize_domain_structure((domain,), (domain, domain)) == ( - ... (domain, None), + ... (domain, DomainAccessDescriptor.NEVER), ... (domain, domain), ... ) True - >>> _canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + >>> _canonicalize_domain_structure( + ... (domain, DomainAccessDescriptor.NEVER), DomainAccessDescriptor.NEVER + ... ) == ( + ... (domain, DomainAccessDescriptor.NEVER), + ... (DomainAccessDescriptor.NEVER, DomainAccessDescriptor.NEVER), + ... 
) True """ - if d1 is None and isinstance(d2, tuple): - return _canonicalize_domain_structure((None,) * len(d2), d2) - if d2 is None and isinstance(d1, tuple): - return _canonicalize_domain_structure(d1, (None,) * len(d1)) + if d1 is DomainAccessDescriptor.NEVER and isinstance(d2, tuple): + return _canonicalize_domain_structure((DomainAccessDescriptor.NEVER,) * len(d2), d2) + if d2 is DomainAccessDescriptor.NEVER and isinstance(d1, tuple): + return _canonicalize_domain_structure(d1, (DomainAccessDescriptor.NEVER,) * len(d1)) if isinstance(d1, tuple) and isinstance(d2, tuple): return tuple( zip( *( _canonicalize_domain_structure(el1, el2) - for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + for el1, el2 in itertools.zip_longest( + d1, d2, fillvalue=DomainAccessDescriptor.NEVER + ) ) ) ) # type: ignore[return-value] # mypy not smart enough @@ -101,16 +143,16 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA def _merge_domains( - original_domains: ACCESSED_DOMAINS, - additional_domains: ACCESSED_DOMAINS, -) -> ACCESSED_DOMAINS: + original_domains: AccessedDomains, + additional_domains: AccessedDomains, +) -> AccessedDomains: new_domains = {**original_domains} for key, domain in additional_domains.items(): original_domain, domain = _canonicalize_domain_structure( - original_domains.get(key, None), domain + original_domains.get(key, DomainAccessDescriptor.NEVER), domain ) - new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(_domain_union)(original_domain, domain) return new_domains @@ -118,44 +160,52 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: domain_utils.SymbolicDomain, + target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} +) -> 
dict[str, NonTupleDomainAccess]: + accessed_domains: dict[str, NonTupleDomainAccess] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): + # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use + # special `UNKNOWN` marker for them until we have implemented a proper solution. + if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): + accessed_domains[in_field_id] = DomainAccessDescriptor.UNKNOWN + continue + new_domains = [ domain_utils.SymbolicDomain.translate( target_domain, shift, offset_provider, symbolic_domain_sizes ) + if not isinstance(target_domain, DomainAccessDescriptor) + else target_domain for shift in shifts_list ] - # `None` means field is never accessed - accessed_domains[in_field_id] = _domain_union_with_none( - accessed_domains.get(in_field_id, None), *new_domains + accessed_domains[in_field_id] = _domain_union( + accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) - return typing.cast(ACCESSED_DOMAINS, accessed_domains) + return accessed_domains def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: DOMAIN, + target_domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + allow_uninferred: bool, +) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") - if target_domain is None: - raise ValueError("'target_domain' cannot be 'None'.") + if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: + raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. 
if isinstance(target_domain, tuple): - target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) - if not isinstance(target_domain, domain_utils.SymbolicDomain): - raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") + target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough + assert isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)) # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -177,22 +227,29 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NonTupleDomainAccess] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s - accessed_domains: ACCESSED_DOMAINS = {} + accessed_domains: AccessedDomains = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes + in_field, + inputs_accessed_domains[in_field_id], + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) transformed_inputs.append(transformed_input) accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + if not isinstance(target_domain, DomainAccessDescriptor): + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + else: + target_domain_expr = None transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) 
accessed_domains_without_tmp = { @@ -206,17 +263,15 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, - input_domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + input_domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy - transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes - ) - let_params = {param_sym.id for param_sym in let_expr.fun.params} + + transformed_calls_expr, accessed_domains = infer_expr(let_expr.fun.expr, input_domain, **kwargs) + accessed_domains_let_args, accessed_domains_outer = _split_dict_by_key( lambda k: k in let_params, accessed_domains ) @@ -227,10 +282,9 @@ def _infer_let( arg, accessed_domains_let_args.get( param.id, - None, + DomainAccessDescriptor.NEVER, ), - offset_provider, - symbolic_domain_sizes, + **kwargs, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -247,13 +301,12 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} if not isinstance(domain, tuple): # promote domain to a tuple of domains such that it has the same structure as # the expression @@ -261,13 +314,12 @@ def _infer_make_tuple( # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; domain = (domain,) * len(expr.args) assert 
len(expr.args) >= len(domain) - # There may be less domains than tuple args, pad the domain with `None` in that case. - # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` - domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + # There may be fewer domains than tuple args, pad the domain with `NEVER` + # in that case. + # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(DomainAccessDescriptor.NEVER for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain[i], offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) @@ -276,19 +328,18 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} idx_expr, tuple_arg = expr.args assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) - tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr( - tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + tuple_domain = tuple( + DomainAccessDescriptor.NEVER if i != idx else domain for i in range(idx + 1) ) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, **kwargs) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -297,18 +348,15 @@ 
def _infer_tuple_get( def _infer_if( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain, offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -317,24 +365,23 @@ def _infer_if( def _infer_expr( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_as_fieldop(expr, domain, **kwargs) elif cpm.is_let(expr): - return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_let(expr, domain, **kwargs) elif cpm.is_call_to(expr, "make_tuple"): - return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_make_tuple(expr, domain, **kwargs) elif cpm.is_call_to(expr, "tuple_get"): - return _infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_tuple_get(expr, domain, **kwargs) elif 
cpm.is_call_to(expr, "if_"): - return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_if(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -347,10 +394,12 @@ def _infer_expr( def infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + allow_uninferred: bool = False, +) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -362,30 +411,35 @@ def infer_expr( - domain: The domain `expr` is read at. - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. + - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. + because of a dynamic shift) or never accessed. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to domain they are accessed at. 
""" - # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) + expr, accessed_domains = _infer_expr( + expr, + domain, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) expr.annex.domain = domain + return expr, accessed_domains def _infer_stmt( stmt: itir.Stmt, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], + **kwargs: Unpack[InferenceOptions], ): if isinstance(stmt, itir.SetAt): - transformed_call, _unused_domain = infer_expr( - stmt.expr, - domain_utils.SymbolicDomain.from_expr(stmt.domain), - offset_provider, - symbolic_domain_sizes, + transformed_call, _ = infer_expr( + stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs ) + return itir.SetAt( expr=transformed_call, domain=stmt.domain, @@ -394,20 +448,18 @@ def _infer_stmt( elif isinstance(stmt, itir.IfStmt): return itir.IfStmt( cond=stmt.cond, - true_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch - ], - false_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch - ], + true_branch=[_infer_stmt(c, **kwargs) for c in stmt.true_branch], + false_branch=[_infer_stmt(c, **kwargs) for c in stmt.false_branch], ) raise ValueError(f"Unsupported stmt: {stmt}") def infer_program( program: itir.Program, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. 
@@ -423,5 +475,13 @@ def infer_program( function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], + body=[ + _infer_stmt( + stmt, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) + for stmt in program.body + ], ) diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py new file mode 100644 index 0000000000..0af9d9dab9 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts +from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs + + +def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: + if not cpm.is_applied_as_fieldop(node): + return None + params_shifts = trace_shifts.trace_stencil( + node.fun.args[0], # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + num_args=len(node.args), + save_to_annex=True, + ) + dynamic_shifts = [ + any(trace_shifts.Sentinel.VALUE in shifts for shifts in param_shifts) + for param_shifts in params_shifts + ] + return dynamic_shifts + + +@dataclasses.dataclass +class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + uids: eve_utils.UIDGenerator + + @classmethod + def apply(cls, node: itir.Program, uids: 
Optional[eve_utils.UIDGenerator] = None): + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + node = self.generic_visit(node, **kwargs) + + if cpm.is_let(node) and ( + dynamic_shift_args := _dynamic_shift_args(let_body := node.fun.expr) # type: ignore[attr-defined] # ensured by is_let + ): + inline_let_params = {p.id: False for p in node.fun.params} # type: ignore[attr-defined] # ensured by is_let + + for inp, is_dynamic_shift_arg in zip(let_body.args, dynamic_shift_args, strict=True): + for ref in collect_symbol_refs(inp): + if ref in inline_let_params and is_dynamic_shift_arg: + inline_let_params[ref] = True + + if any(inline_let_params): + node = inline_lambdas.inline_lambda( + node, eligible_params=list(inline_let_params.values()) + ) + + if dynamic_shift_args := _dynamic_shift_args(node): + assert len(node.fun.args) in [1, 2] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop in _dynamic_shift_args + fuse_args = [ + not isinstance(inp, itir.SymRef) and dynamic_shift_arg + for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) + ] + if any(fuse_args): + return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + + return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec4207d726..d967c8fbb8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -15,6 +15,7 @@ fuse_as_fieldop, global_tmps, infer_domain, + inline_dynamic_shifts, inline_fundefs, inline_lifts, ) @@ -73,6 +74,9 @@ def apply_common_transforms( ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. 
when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -158,5 +162,8 @@ def apply_fieldview_transforms( ir = CollapseTuple.apply( ir, offset_provider_type=common.offset_provider_to_type(offset_provider) ) # type: ignore[assignment] # type is still `itir.Program` + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index d7413f32d7..bed6e89a52 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -130,7 +130,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py new file mode 100644 index 0000000000..ff7a761c5a --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -0,0 +1,48 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_dynamic_shifts +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_inline_dynamic_shift_as_fieldop_arg(): + testee = im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + im.as_fieldop("deref")("inp"), "offset_field" + ) + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected + + +def test_inline_dynamic_shift_let_var(): + testee = im.let("tmp", im.as_fieldop("deref")("inp"))( + im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + "tmp", "offset_field" + ) + ) + + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 2492fc446d..779ab738cb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -76,7 +76,7 @@ def setup_test_as_fieldop( def run_test_program( testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = 
infer_domain.infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider=offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected @@ -89,12 +89,14 @@ def run_test_expr( expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ): actual_call, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, - symbolic_domain_sizes, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -104,10 +106,8 @@ def run_test_expr( def canonicalize_domain(d): if isinstance(d, dict): return im.domain(grid_type, d) - elif isinstance(d, itir.FunCall): + elif isinstance(d, (itir.FunCall, infer_domain.DomainAccessDescriptor)): return d - elif d is None: - return None raise AssertionError() expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} @@ -128,10 +128,12 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( - domains: infer_domain.ACCESSED_DOMAINS, -) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: domain_utils.SymbolicDomain | None): - if domain is None: + domains: infer_domain.AccessedDomains, +) -> infer_domain.AccessedDomains: + def fold_domain( + domain: domain_utils.SymbolicDomain | Literal[infer_domain.DomainAccessDescriptor.NEVER], + ): + if isinstance(domain, infer_domain.DomainAccessDescriptor): return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -154,7 +156,7 @@ def translate_domain( shift_list = [item for sublist in shift_tuples 
for item in sublist] translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( - shift_list, offset_provider + shift_list, offset_provider=offset_provider ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -340,7 +342,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -384,7 +386,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -397,7 +399,10 @@ def test_unused_input(offset_provider): stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + expected_domains = { + "in_field1": {IDim: (0, 11)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } testee, expected = setup_test_as_fieldop( stencil, domain, @@ -409,7 +414,7 @@ def test_let_unused_field(offset_provider): testee = im.let("a", "c")("b") domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.let("a", "c")("b") - expected_domains = {"b": {IDim: (0, 11)}, "c": None} + expected_domains = {"b": {IDim: (0, 11)}, "c": infer_domain.DomainAccessDescriptor.NEVER} run_test_expr(testee, expected, domain, expected_domains, offset_provider) @@ -522,7 +527,7 @@ def test_cond(offset_provider): expected 
= im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -579,7 +584,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -803,7 +808,7 @@ def test_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -815,13 +820,13 @@ def test_tuple_get_1_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) expected_domains = { - "a": None, + "a": infer_domain.DomainAccessDescriptor.NEVER, "b": im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), - "c": None, + "c": infer_domain.DomainAccessDescriptor.NEVER, } actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -833,7 +838,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), 
im.make_tuple(im.ref("b"), im.ref("c")))) - expected_domains = {"a": None, "b": domain1, "c": domain2} + expected_domains = {"a": infer_domain.DomainAccessDescriptor.NEVER, "b": domain1, "c": domain2} actual, actual_domains = infer_domain.infer_expr( testee, @@ -841,7 +846,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -852,14 +857,18 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) - expected_domains = {"b": None, "c": None, "d": (None, domain)} + expected_domains = { + "b": infer_domain.DomainAccessDescriptor.NEVER, + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": (infer_domain.DomainAccessDescriptor.NEVER, domain), + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr( im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -870,12 +879,16 @@ def test_tuple_get_let_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) - expected_domains = {"c": None, "d": domain, "b": None} + expected_domains = { + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": domain, + "b": infer_domain.DomainAccessDescriptor.NEVER, + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, + offset_provider=offset_provider, ) assert expected 
== actual @@ -903,7 +916,7 @@ def test_nested_make_tuple(offset_provider): ), domain_utils.SymbolicDomain.from_expr(domain3), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -914,10 +927,10 @@ def test_tuple_get_1(offset_provider): testee = im.tuple_get(1, im.ref("a")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.ref("a")) - expected_domains = {"a": (None, domain)} + expected_domains = {"a": (infer_domain.DomainAccessDescriptor.NEVER, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -937,7 +950,7 @@ def test_domain_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -953,7 +966,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -973,7 +986,7 @@ def test_make_tuple_2tuple_get(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -990,7 +1003,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ 
-1004,7 +1017,7 @@ def test_arithmetic_builtin(offset_provider): expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) @@ -1048,3 +1061,35 @@ def test_symbolic_domain_sizes(unstructured_offset_provider): unstructured_offset_provider, symbolic_domain_sizes, ) + + +def test_unknown_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref(im.shift("Ioff", im.deref("arg1"))("arg0"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": infer_domain.DomainAccessDescriptor.UNKNOWN, + "in_field2": {IDim: (0, 10)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain_tuple(offset_provider): + testee = im.tuple_get(0, im.make_tuple("in_field1", "in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + run_test_expr(testee, testee, domain, expected_domains, offset_provider) From 29b6af23c15955910f413ed12e5d1a463e7b5b4b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 9 Dec 2024 16:44:28 +0100 Subject: [PATCH 40/43] build: fix min version of filelock (#1777) ... 
and fix linting after ruff update. --- .pre-commit-config.yaml | 10 ++-- constraints.txt | 48 +++++++++---------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 48 +++++++++---------- src/gt4py/__init__.py | 2 +- src/gt4py/cartesian/__init__.py | 4 +- src/gt4py/cartesian/backend/__init__.py | 2 +- src/gt4py/cartesian/cli.py | 2 +- src/gt4py/cartesian/frontend/__init__.py | 2 +- src/gt4py/cartesian/gtscript.py | 6 +-- src/gt4py/cartesian/testing/__init__.py | 2 +- src/gt4py/cartesian/utils/__init__.py | 2 +- src/gt4py/cartesian/utils/base.py | 6 +-- src/gt4py/eve/__init__.py | 2 +- src/gt4py/eve/datamodels/validators.py | 2 +- src/gt4py/next/errors/__init__.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/runtime.py | 2 +- .../next/iterator/transforms/__init__.py | 2 +- .../iterator/transforms/fuse_as_fieldop.py | 6 ++- .../transformations/__init__.py | 14 +++--- src/gt4py/storage/__init__.py | 6 +-- 24 files changed, 88 insertions(+), 90 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e1870c67f..e383112310 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.4 + rev: v0.8.2 ##[[[end]]] hooks: # Run the linter. 
@@ -96,7 +96,7 @@ repos: - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - - cmake==3.31.0.1 + - cmake==3.31.1 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 @@ -108,9 +108,9 @@ repos: - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 - - mako==1.3.6 - - nanobind==2.2.0 - - ninja==1.11.1.1 + - mako==1.3.8 + - nanobind==2.4.0 + - ninja==1.11.1.2 - numpy==1.24.4 - packaging==24.2 - pybind11==2.13.6 diff --git a/constraints.txt b/constraints.txt index f039fa2125..fbdfb6e267 100644 --- a/constraints.txt +++ b/constraints.txt @@ -23,9 +23,9 @@ certifi==2024.8.30 # via requests cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.4.0 # via requests -clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) +clang-format==19.1.4 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.0.1 # via gt4py (pyproject.toml) +cmake==3.31.1 # via gt4py (pyproject.toml) cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==1.0.0 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.8 # via ipykernel +debugpy==1.8.9 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) @@ -47,11 +47,11 @@ exceptiongroup==1.2.2 # via hypothesis, pytest execnet==2.1.1 # via pytest-cache, pytest-xdist executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==33.0.0 # via factory-boy -fastjsonschema==2.20.0 # via nbformat +faker==33.1.0 # via factory-boy +fastjsonschema==2.21.1 # via nbformat filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.0 # via matplotlib -fparser==0.1.4 # via dace 
+fonttools==4.55.2 # via matplotlib +fparser==0.2.0 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach @@ -75,7 +75,7 @@ jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbformat jupytext==1.16.4 # via -r requirements-dev.in kiwisolver==1.4.7 # via matplotlib lark==1.2.2 # via gt4py (pyproject.toml) -mako==1.3.6 # via gt4py (pyproject.toml) +mako==1.3.8 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via jinja2, mako matplotlib==3.7.5 # via -r requirements-dev.in @@ -85,13 +85,13 @@ mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy mypy==1.13.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==2.2.0 # via gt4py (pyproject.toml) +nanobind==2.4.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.10.4 # via jupytext, nbclient, nbmake nbmake==1.5.4 # via -r requirements-dev.in nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace, tach -ninja==1.11.1.1 # via gt4py (pyproject.toml) +ninja==1.11.1.2 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff @@ -102,7 +102,7 @@ pexpect==4.9.0 # via ipython pickleshare==0.7.5 # via ipython pillow==10.4.0 # via matplotlib pip-tools==7.4.1 # via -r requirements-dev.in -pipdeptree==2.23.4 # via -r requirements-dev.in +pipdeptree==2.24.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==4.3.6 # via black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via pytest, tox @@ -113,15 +113,15 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.10.0 # via bump-my-version, pydantic-settings -pydantic-core==2.27.0 # via pydantic 
+pydantic==2.10.3 # via bump-my-version, pydantic-settings +pydantic-core==2.27.1 # via pydantic pydantic-settings==2.6.1 # via bump-my-version -pydot==3.0.2 # via tach +pydot==3.0.3 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via matplotlib, pydot pyproject-api==1.8.0 # via tox pyproject-hooks==1.2.0 # via build, pip-tools -pytest==8.3.3 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist +pytest==8.3.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==5.0.0 # via -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -r requirements-dev.in @@ -137,12 +137,12 @@ questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach -rich-click==1.8.4 # via bump-my-version +rich-click==1.8.5 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.4 # via -r requirements-dev.in +ruff==0.8.2 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser -six==1.16.0 # via asttokens, astunparse, python-dateutil +six==1.17.0 # via asttokens, astunparse, python-dateutil smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis @@ -159,21 +159,21 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.4 # via -r requirements-dev.in -tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox 
+tach==0.16.5 # via -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version toolz==1.0.0 # via cytoolz -tornado==6.4.1 # via ipykernel, jupyter-client +tornado==6.4.2 # via ipykernel, jupyter-client tox==4.23.2 # via -r requirements-dev.in traitlets==5.14.3 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -r requirements-dev.in +types-tabulate==0.9.0.20241207 # via -r requirements-dev.in typing-extensions==4.12.2 # via annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via requests -virtualenv==20.27.1 # via pre-commit, tox +virtualenv==20.28.0 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit -wheel==0.45.0 # via astunparse, pip-tools +wheel==0.45.1 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index d7679a1f0f..6d75415181 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,7 +67,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 -filelock==3.0.0 +filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index cf505e88d6..991b7a6941 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,7 +63,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 -filelock==3.0.0 +filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/pyproject.toml b/pyproject.toml index 
e859c9b4f7..d086363ec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'devtools>=0.6', 'diskcache>=5.6.3', 'factory-boy>=3.3.0', - 'filelock>=3.0.0', + 'filelock>=3.16.1', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", diff --git a/requirements-dev.txt b/requirements-dev.txt index 6542be36f1..40554cef13 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -23,9 +23,9 @@ certifi==2024.8.30 # via -c constraints.txt, requests cfgv==3.4.0 # via -c constraints.txt, pre-commit chardet==5.2.0 # via -c constraints.txt, tox charset-normalizer==3.4.0 # via -c constraints.txt, requests -clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +clang-format==19.1.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.31.1 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.8 # via -c constraints.txt, ipykernel +debugpy==1.8.9 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) @@ -47,11 +47,11 @@ exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist executing==2.1.0 # via -c 
constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==33.0.0 # via -c constraints.txt, factory-boy -fastjsonschema==2.20.0 # via -c constraints.txt, nbformat +faker==33.1.0 # via -c constraints.txt, factory-boy +fastjsonschema==2.21.1 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.0 # via -c constraints.txt, matplotlib -fparser==0.1.4 # via -c constraints.txt, dace +fonttools==4.55.2 # via -c constraints.txt, matplotlib +fparser==0.2.0 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach @@ -75,7 +75,7 @@ jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, n jupytext==1.16.4 # via -c constraints.txt, -r requirements-dev.in kiwisolver==1.4.7 # via -c constraints.txt, matplotlib lark==1.2.2 # via -c constraints.txt, gt4py (pyproject.toml) -mako==1.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +mako==1.3.8 # via -c constraints.txt, gt4py (pyproject.toml) markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via -c constraints.txt, jinja2, mako matplotlib==3.7.5 # via -c constraints.txt, -r requirements-dev.in @@ -85,13 +85,13 @@ mdurl==0.1.2 # via -c constraints.txt, markdown-it-py mpmath==1.3.0 # via -c constraints.txt, sympy mypy==1.13.0 # via -c constraints.txt, -r requirements-dev.in mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy -nanobind==2.2.0 # via -c constraints.txt, gt4py (pyproject.toml) +nanobind==2.4.0 # via -c constraints.txt, gt4py (pyproject.toml) nbclient==0.6.8 # via -c constraints.txt, nbmake nbformat==5.10.4 # via -c constraints.txt, jupytext, nbclient, nbmake nbmake==1.5.4 # via -c constraints.txt, -r requirements-dev.in nest-asyncio==1.6.0 # via -c 
constraints.txt, ipykernel, nbclient networkx==3.1 # via -c constraints.txt, dace, tach -ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) +ninja==1.11.1.2 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff @@ -102,7 +102,7 @@ pexpect==4.9.0 # via -c constraints.txt, ipython pickleshare==0.7.5 # via -c constraints.txt, ipython pillow==10.4.0 # via -c constraints.txt, matplotlib pip-tools==7.4.1 # via -c constraints.txt, -r requirements-dev.in -pipdeptree==2.23.4 # via -c constraints.txt, -r requirements-dev.in +pipdeptree==2.24.0 # via -c constraints.txt, -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via -c constraints.txt, jsonschema platformdirs==4.3.6 # via -c constraints.txt, black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via -c constraints.txt, pytest, tox @@ -113,15 +113,15 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.27.0 # via -c constraints.txt, pydantic +pydantic==2.10.3 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.1 # via -c constraints.txt, pydantic pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version -pydot==3.0.2 # via -c constraints.txt, tach +pydot==3.0.3 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via -c constraints.txt, matplotlib, pydot pyproject-api==1.8.0 # via -c constraints.txt, tox pyproject-hooks==1.2.0 # via -c constraints.txt, build, pip-tools -pytest==8.3.3 # via -c 
constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist +pytest==8.3.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in pytest-cov==5.0.0 # via -c constraints.txt, -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -c constraints.txt, -r requirements-dev.in @@ -137,11 +137,11 @@ questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach -rich-click==1.8.4 # via -c constraints.txt, bump-my-version +rich-click==1.8.5 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.4 # via -c constraints.txt, -r requirements-dev.in +ruff==0.8.2 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser -six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil +six==1.17.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb snowballstemmer==2.2.0 # via -c constraints.txt, sphinx sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis @@ -158,21 +158,21 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in -tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, 
pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tach==0.16.5 # via -c constraints.txt, -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version toolz==1.0.0 # via -c constraints.txt, cytoolz -tornado==6.4.1 # via -c constraints.txt, ipykernel, jupyter-client +tornado==6.4.2 # via -c constraints.txt, ipykernel, jupyter-client tox==4.23.2 # via -c constraints.txt, -r requirements-dev.in traitlets==5.14.3 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -c constraints.txt, -r requirements-dev.in +types-tabulate==0.9.0.20241207 # via -c constraints.txt, -r requirements-dev.in typing-extensions==4.12.2 # via -c constraints.txt, annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via -c constraints.txt, requests -virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox +virtualenv==20.28.0 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools +wheel==0.45.1 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 1b88285475..c0bf9580b3 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -27,6 +27,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . 
import next # noqa: A004 shadowing a Python builtin __all__ += ["next"] diff --git a/src/gt4py/cartesian/__init__.py b/src/gt4py/cartesian/__init__.py index c03ef15105..90df315d5c 100644 --- a/src/gt4py/cartesian/__init__.py +++ b/src/gt4py/cartesian/__init__.py @@ -27,7 +27,7 @@ __all__ = [ - "typing", + "StencilObject", "caching", "cli", "config", @@ -39,5 +39,5 @@ "stencil_builder", "stencil_object", "type_hints", - "StencilObject", + "typing", ] diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index e58c7a01a7..4296e3b389 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -32,9 +32,9 @@ "BasePyExtBackend", "CLIBackendMixin", "CudaBackend", - "GTGpuBackend", "GTCpuIfirstBackend", "GTCpuKfirstBackend", + "GTGpuBackend", "NumpyBackend", "PurePythonBackendCLIMixin", "from_name", diff --git a/src/gt4py/cartesian/cli.py b/src/gt4py/cartesian/cli.py index 91daed9e98..4ea5e44074 100644 --- a/src/gt4py/cartesian/cli.py +++ b/src/gt4py/cartesian/cli.py @@ -90,7 +90,7 @@ def backend_table(cls) -> str: ", ".join(backend.languages["bindings"]) if backend and backend.languages else "?" 
for backend in backends ] - enabled = [backend is not None and "Yes" or "No" for backend in backends] + enabled = [(backend is not None and "Yes") or "No" for backend in backends] data = zip(names, comp_langs, binding_langs, enabled) return tabulate.tabulate(data, headers=headers) diff --git a/src/gt4py/cartesian/frontend/__init__.py b/src/gt4py/cartesian/frontend/__init__.py index 6988fb6aab..f1e0f9a775 100644 --- a/src/gt4py/cartesian/frontend/__init__.py +++ b/src/gt4py/cartesian/frontend/__init__.py @@ -10,4 +10,4 @@ from .base import REGISTRY, Frontend, from_name, register -__all__ = ["gtscript_frontend", "REGISTRY", "Frontend", "from_name", "register"] +__all__ = ["REGISTRY", "Frontend", "from_name", "gtscript_frontend", "register"] diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 643ecba010..59f3ef37c2 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -657,10 +657,8 @@ def __str__(self) -> str: class _FieldDescriptorMaker: @staticmethod def _is_axes_spec(spec) -> bool: - return ( - isinstance(spec, Axis) - or isinstance(spec, collections.abc.Collection) - and all(isinstance(i, Axis) for i in spec) + return isinstance(spec, Axis) or ( + isinstance(spec, collections.abc.Collection) and all(isinstance(i, Axis) for i in spec) ) def __getitem__(self, field_spec): diff --git a/src/gt4py/cartesian/testing/__init__.py b/src/gt4py/cartesian/testing/__init__.py index 288d7b1d2d..0753b4175e 100644 --- a/src/gt4py/cartesian/testing/__init__.py +++ b/src/gt4py/cartesian/testing/__init__.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -__all__ = ["field", "global_name", "none", "parameter", "StencilTestSuite"] +__all__ = ["StencilTestSuite", "field", "global_name", "none", "parameter"] try: from .input_strategies import field, global_name, none, parameter from .suites import StencilTestSuite diff --git a/src/gt4py/cartesian/utils/__init__.py b/src/gt4py/cartesian/utils/__init__.py index 3c0bdb3fc3..626d29b167 100644 --- a/src/gt4py/cartesian/utils/__init__.py +++ b/src/gt4py/cartesian/utils/__init__.py @@ -37,7 +37,7 @@ ) -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # Modules "attrib", "meta", diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index d5d43a4103..35184a3f7b 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -63,10 +63,8 @@ def flatten_iter(nested_iterables, filter_none=False, *, skip_types=(str, bytes) def get_member(instance, item_name): try: - if ( - isinstance(instance, collections.abc.Mapping) - or isinstance(instance, collections.abc.Sequence) - and isinstance(item_name, int) + if isinstance(instance, collections.abc.Mapping) or ( + isinstance(instance, collections.abc.Sequence) and isinstance(item_name, int) ): return instance[item_name] else: diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 5adac47da3..e6044f15ef 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -71,7 +71,7 @@ from .visitors import NodeTranslator, NodeVisitor -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # version "__version__", "__version_info__", diff --git a/src/gt4py/eve/datamodels/validators.py b/src/gt4py/eve/datamodels/validators.py index 119410460c..4ce6f94c5e 100644 --- a/src/gt4py/eve/datamodels/validators.py +++ b/src/gt4py/eve/datamodels/validators.py @@ -42,7 +42,7 @@ from .core import DataModelTP, FieldValidator -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # reexported from attrs 
"and_", "deep_iterable", diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 89f78a45e4..9febe098a4 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -23,9 +23,9 @@ __all__ = [ "DSLError", "InvalidParameterAnnotationError", + "MissingArgumentError", "MissingAttributeError", "MissingParameterAnnotationError", - "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", ] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index b60fa63f95..1210e96efc 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -10,7 +10,7 @@ import functools import inspect import math -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index e47a6886ad..c9a5b15de7 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -26,7 +26,7 @@ # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend -__all__ = ["offset", "fundef", "fendef", "set_at", "if_stmt"] +__all__ = ["fendef", "fundef", "if_stmt", "offset", "set_at"] @dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index d0afc610e7..1d91254ee8 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -13,4 +13,4 @@ ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "GTIRTransform"] +__all__ = ["GTIRTransform", "apply_common_transforms", "apply_fieldview_transforms"] diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py 
b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index e8a221b814..661b456608 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -240,8 +240,10 @@ def visit_FunCall(self, node: itir.FunCall): or ( isinstance(arg, itir.FunCall) and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) + ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + ) or cpm.is_call_to(arg, "if_") ) and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 2232bcef01..4f3efb19b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -43,25 +43,25 @@ "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", "GT4PyGlobalSelfCopyElimination", - "GT4PyMoveTaskletIntoMap", "GT4PyMapBufferElimination", + "GT4PyMoveTaskletIntoMap", "LoopBlocking", - "MapIterationOrder", "MapFusionParallel", "MapFusionSerial", + "MapIterationOrder", "SerialMapPromoter", "SerialMapPromoterGPU", "gt_auto_optimize", "gt_change_transient_strides", "gt_create_local_double_buffering", + "gt_find_constant_arguments", + "gt_gpu_transform_non_standard_memlet", "gt_gpu_transformation", "gt_inline_nested_sdfg", - "gt_set_iteration_order", - "gt_set_gpu_blocksize", - "gt_simplify", "gt_make_transients_persistent", "gt_reduce_distributed_buffering", - "gt_find_constant_arguments", + "gt_set_gpu_blocksize", + "gt_set_iteration_order", + "gt_simplify", "gt_substitute_compiletime_symbols", - "gt_gpu_transform_non_standard_memlet", ] diff --git a/src/gt4py/storage/__init__.py b/src/gt4py/storage/__init__.py index 4866cd480c..5986baa65e 100644 --- 
a/src/gt4py/storage/__init__.py +++ b/src/gt4py/storage/__init__.py @@ -16,12 +16,12 @@ __all__ = [ "cartesian", - "layout", "empty", "from_array", + "from_name", "full", + "layout", "ones", - "zeros", - "from_name", "register", + "zeros", ] From 98889056c914886912d9131793deb67b5f947602 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 10 Dec 2024 10:02:22 +0100 Subject: [PATCH 41/43] feat[next]: Change interval syntax in ITIR pretty printer (#1766) We currently use `)` in the pretty printer to express an open interval. This is quite cumbersome when debugging the IR because it breaks matching parenthesis in the editor of functions and calls, e.g. when does a function start and end. This PR simply uses `[` instead. --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +++--- src/gt4py/next/iterator/pretty_parser.py | 2 +- src/gt4py/next/iterator/pretty_printer.py | 4 +++- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 6 +++--- src/gt4py/next/iterator/transforms/inline_fundefs.py | 2 +- .../unit_tests/iterator_tests/test_pretty_parser.py | 4 ++-- .../unit_tests/iterator_tests/test_pretty_printer.py | 2 +- 7 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index a4e111e785..0839e95b5b 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -423,11 +423,11 @@ def domain( ... }, ... ) ... 
) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' >>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)})) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' >>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)})) - 'u⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' """ if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 29b30beae1..a077b39911 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -84,7 +84,7 @@ else_branch_seperator: "else" if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" - named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")" + named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 "[" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? 
")" "→" prec0 ";" declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index a25f99356c..7acbf5d23d 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -190,7 +190,9 @@ def visit_FunCall(self, node: ir.FunCall, *, prec: int) -> list[str]: if fun_name == "named_range" and len(node.args) == 3: # named_range(dim, start, stop) → dim: [star, stop) dim, start, end = self.visit(node.args, prec=0) - res = self._hmerge(dim, [": ["], start, [", "], end, [")"]) + res = self._hmerge( + dim, [": ["], start, [", "], end, ["["] + ) # to get matching parenthesis of functions return self._prec_parens(res, prec, PRECEDENCE["__call__"]) if fun_name == "cartesian_domain" and len(node.args) >= 1: # cartesian_domain(x, y, ...) → c{ x × y × ... } # noqa: RUF003 [ambiguous-unicode-character-comment] diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 661b456608..b7087472e0 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -186,15 +186,15 @@ class FuseAsFieldOp(eve.NodeTranslator): ... im.ref("inp3", field_type), ... ) >>> print(nested_as_fieldop) - as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( - as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 + as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)( + as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2), inp3 ) >>> print( ... FuseAsFieldOp.apply( ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True ... ) ... 
) - as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) + as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3) """ # noqa: RUF002 # ignore ambiguous multiplication character uids: eve_utils.UIDGenerator diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index a2188030a1..e4cae978da 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -59,7 +59,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: >>> print(prune_unreferenced_fundefs(program)) testee(inp, out) { fun1 = λ(a) → ·a; - out @ c⟨ IDimₕ: [0, 10) ⟩ ← fun1(inp); + out @ c⟨ IDimₕ: [0, 10[ ⟩ ← fun1(inp); } """ fun_names = [fun.id for fun in program.function_definitions] diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index bf47f997d6..af9084f407 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -127,7 +127,7 @@ def test_make_tuple(): def test_named_range_horizontal(): - testee = "IDimₕ: [x, y)" + testee = "IDimₕ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], @@ -137,7 +137,7 @@ def test_named_range_horizontal(): def test_named_range_vertical(): - testee = "IDimᵥ: [x, y)" + testee = "IDimᵥ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 11f50dbf6d..6b45f470b7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ 
b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -233,7 +233,7 @@ def test_named_range_horizontal(): fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], ) - expected = "IDimₕ: [x, y)" + expected = "IDimₕ: [x, y[" actual = pformat(testee) assert actual == expected From 06b398af7c5a4235d2c595bbbac93ec70f31a5a6 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Dec 2024 15:32:16 +0100 Subject: [PATCH 42/43] refact[next][dace]: split handling of let-statement lambdas from stencil body (#1781) This is a refactoring of the code to lower lambda nodes: it splits the lowering of let-statements from the lowering of stencil expressions. --- .../gtir_builtin_translators.py | 43 ++--- .../runners/dace_fieldview/gtir_dataflow.py | 165 +++++++++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 5 +- 3 files changed, 143 insertions(+), 70 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ff011c4193..cffbd74c90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins @@ -30,7 +30,7 @@ gtir_python_codegen, utility as dace_gtir_utils, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: @@ -39,7 +39,7 @@ def _get_domain_indices( dims: Sequence[gtx_common.Dimension], offsets: 
Optional[Sequence[dace.symbolic.SymExpr]] = None -) -> sbs.Indices: +) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying an optional offset in each dimension as start index. @@ -55,9 +55,9 @@ def _get_domain_indices( """ index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] if offsets is None: - return sbs.Indices(index_variables) + return dace_subsets.Indices(index_variables) else: - return sbs.Indices( + return dace_subsets.Indices( [ index - offset if offset != 0 else index for index, offset in zip(index_variables, offsets, strict=True) @@ -96,7 +96,7 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): @@ -263,7 +263,7 @@ def _create_field_operator( dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - field_subset = sbs.Range.from_indices(field_indices) + field_subset = dace_subsets.Range.from_indices(field_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): assert output_edge.result.gt_dtype == node_type.dtype assert isinstance(dataflow_output_desc, dace.data.Scalar) @@ -280,7 +280,7 @@ def _create_field_operator( field_dims.append(output_edge.result.gt_dtype.offset_type) field_shape.extend(dataflow_output_desc.shape) field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) @@ -366,36 +366,37 @@ def translate_as_fieldop( """ assert 
isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) fun_node = node.fun assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args + fieldop_expr, domain_expr = fun_node.args - if isinstance(stencil_expr, gtir.Lambda): - # Default case, handled below: the argument expression is a lambda function - # representing the stencil operation to be computed over the field domain. - pass - elif cpm.is_ref_to(stencil_expr, "deref"): + assert isinstance(node.type, ts.FieldType) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] + stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr else: raise NotImplementedError( - f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." + f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." 
) # parse the domain of the field operator domain = extract_domain(domain_expr) # visit the list of arguments to be passed to the lambda expression - stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) + input_edges, output_edge = gtir_dataflow.visit_lambda( + sdfg, state, sdfg_builder, stencil_expr, fieldop_args + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge @@ -654,7 +655,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) + assert ti.is_integral(node.args[0].type) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cfba4d61e5..a3653fb519 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,10 +10,22 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union +from typing import ( + Any, + Dict, + Final, + List, + Optional, + Protocol, + Sequence, + Set, + Tuple, + TypeAlias, + Union, +) import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py import eve from gt4py.next import common as gtx_common @@ -68,7 +80,7 @@ class MemletExpr: dc_node: 
dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Indices | sbs.Range + subset: dace_subsets.Range @dataclasses.dataclass(frozen=True) @@ -104,7 +116,7 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] - def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -117,7 +129,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: assert len(field_desc.shape) == len(self.field_domain) field_domain = self.field_domain - return sbs.Range.from_string( + return dace_subsets.Range.from_string( ",".join( str(self.indices[dim].value - offset) # type: ignore[union-attr] if dim in self.indices @@ -152,7 +164,7 @@ class MemletInputEdge(DataflowInputEdge): state: dace.SDFGState source: dace.nodes.AccessNode - subset: sbs.Range + subset: dace_subsets.Range dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] @@ -202,7 +214,7 @@ def connect( self, mx: dace.nodes.MapExit, dest: dace.nodes.AccessNode, - subset: sbs.Range, + subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src @@ -256,10 +268,12 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +@dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ - Translates an `ir.Lambda` expression to a dataflow graph. + Visitor class to translate a `Lambda` expression to a dataflow graph. + This visitor should be applied by calling `apply()` method on a `Lambda` IR. The dataflow graph generated here typically represents the stencil function of a field operator. 
It only computes single elements or pure local fields, in case of neighbor values. In case of local fields, the dataflow contains @@ -275,25 +289,15 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - subgraph_builder: gtir_sdfg.DataflowBuilder, - ): - self.sdfg = sdfg - self.state = state - self.subgraph_builder = subgraph_builder - self.input_edges = [] - self.symbol_map = {} + input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( + default_factory=lambda: {} + ) def _add_input_data_edge( self, src: dace.nodes.AccessNode, - src_subset: sbs.Range, + src_subset: dace_subsets.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, src_offset: Optional[list[dace.symbolic.SymExpr]] = None, @@ -301,7 +305,7 @@ def _add_input_data_edge( input_subset = ( src_subset if src_offset is None - else sbs.Range( + else dace_subsets.Range( (start - off, stop - off, step) for (start, stop, step), off in zip(src_subset, src_offset, strict=True) ) @@ -512,7 +516,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, - sbs.Range.from_array(field_desc), + dace_subsets.Range.from_array(field_desc), deref_node, "field", src_offset=[offset for (_, offset) in arg_expr.field_domain], @@ -580,7 +584,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=it.field, gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( ",".join( str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain @@ -596,7 +600,7 @@ def 
_visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=self.state.add_access(connectivity), gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_index.value}, 0:{offset_provider.max_neighbors}" ), ) @@ -758,7 +762,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=itir_ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) @@ -908,7 +912,9 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), nsdfg_node, "neighbor_indices", ) @@ -1081,7 +1087,7 @@ def _make_dynamic_neighbor_offset( ) self._add_input_data_edge( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, "table", ) @@ -1127,7 +1133,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=sbs.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1264,39 +1270,39 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: elif cpm.is_applied_shift(node): return self._visit_shift(node) + elif isinstance(node.fun, gtir.Lambda): + # Lambda node should be visited with 'visit_let()' method. 
+ raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") + elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda( - self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - for p, arg in zip(node.params, args, strict=True): - self.symbol_map[str(p.id)] = arg - output_expr: DataExpr = self.visit(node.expr) - if isinstance(output_expr, ValueExpr): - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: + result: DataExpr = self.visit(node.expr) + + if isinstance(result, ValueExpr): + return DataflowOutputEdge(self.state, result) - if isinstance(output_expr, MemletExpr): + if isinstance(result, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + output_dtype = result.dc_node.desc(self.sdfg).dtype tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_data_edge( - output_expr.dc_node, - output_expr.subset, + result.dc_node, + result.subset, tasklet_node, "__inp", ) else: - assert isinstance(output_expr, SymbolExpr) # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") + output_dtype = result.dc_dtype + tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + return DataflowOutputEdge(self.state, output_expr) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = 
dace_utils.as_dace_type(node.type) @@ -1309,3 +1315,68 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE # if not in the lambda symbol map, this must be a symref to a builtin function assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) + + def visit_let( + self, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], + ) -> DataflowOutputEdge: + """ + Maps lambda arguments to internal parameters. + + This method is responsible to recognize the usage of the `Lambda` node, + which can be either a let-statement or the stencil expression in local view. + The usage of a `Lambda` as let-statement corresponds to computing some results + and making them available inside the lambda scope, represented as a nested SDFG. + All let-statements, if any, are supposed to be encountered before the stencil + expression. In other words, the `Lambda` node representing the stencil expression + is always the innermost node. + Therefore, the lowering of let-statements results in recursive calls to + `visit_let()` until the stencil expression is found. At that point, it falls + back to the `visit()` function. + """ + + # lambda arguments are mapped to symbols defined in lambda scope. 
+ for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg + + if cpm.is_let(node.expr): + let_node = node.expr + let_args = [self.visit(arg) for arg in let_node.args] + assert isinstance(let_node.fun, gtir.Lambda) + return self.visit_let(let_node.fun, args=let_args) + else: + # this lambda node is not a let-statement, but a stencil expression + return self.visit(node) + + +def visit_lambda( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], +) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point to visit a `Lambda` node and lower it to a dataflow graph, + that can be instantiated inside a map scope implementing the field operator. + + It calls `LambdaToDataflow.visit_let()` to map the lambda arguments to internal + parameters and visit the let-statements (if any), which always appear as outermost + nodes. Finally, the visitor returns the output edge of the dataflow. + + Args: + sdfg: The SDFG where the dataflow graph will be instantiated. + state: The SDFG state where the dataflow graph will be instantiated. + sdfg_builder: Helper class to build the SDFG. + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output data connection. 
+ """ + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) + output_edge = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 6b5e164458..9bd40f75f8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -602,7 +602,7 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - args: list[gtir_builtin_translators.FieldopResult], + args: Sequence[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ Translates a `Lambda` node to a nested SDFG in the current state. @@ -679,7 +679,7 @@ def get_field_domain_offset( self.offset_provider_type, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) - nstate = nsdfg.add_state("lambda") + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) # add sdfg storage for the symbols that need to be passed as input parameters lambda_params = [ @@ -690,6 +690,7 @@ def get_field_domain_offset( nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols ) + nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( node.expr, sdfg=nsdfg, From 77cad7c8862c6164dff5f9e192ffef8fc9a2b1af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:53:40 +0100 Subject: [PATCH 43/43] feat[dace][next]: Fixing strides in optimization (#1782) Added functionality to properly handle changes of strides. During the implementation of the scan we found that the strides were not handled properly. Most importantly a change on one level was not propagated into the next levels, i.e. they were still using the old strides. 
This PR Solves most of the problems, but there are still some issues that are unsolved: - Views are not adjusted yet (Fixed in [PR@1784](https://github.com/GridTools/gt4py/pull/1784)). - It is not properly checked if the symbols of the propagated strides are safe to introduce into the nested SDFG. The initial functionality of this PR was done by Edoardo Paone (@edopao). --------- Co-authored-by: edopao --- .../transformations/__init__.py | 12 +- .../transformations/gpu_utils.py | 2 +- .../transformations/simplify.py | 5 +- .../dace_fieldview/transformations/strides.py | 611 +++++++++++++++++- .../test_map_buffer_elimination.py | 93 ++- .../transformation_tests/test_strides.py | 541 ++++++++++++++++ 6 files changed, 1238 insertions(+), 26 deletions(-) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4f3efb19b0..0902bd665a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,13 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides +from .strides import ( + gt_change_transient_strides, + gt_map_strides_to_dst_nested_sdfg, + gt_map_strides_to_src_nested_sdfg, + gt_propagate_strides_from_access_node, + gt_propagate_strides_of, +) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,6 +65,10 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", + "gt_map_strides_to_dst_nested_sdfg", + "gt_map_strides_to_src_nested_sdfg", + "gt_propagate_strides_from_access_node", + "gt_propagate_strides_of", "gt_reduce_distributed_buffering", 
"gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 2cd3020180..7b14144ead 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -95,7 +95,7 @@ def gt_gpu_transformation( if try_removing_trivial_maps: # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on - # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So # we might end up with lots of these trivial Maps, each requiring a separate # kernel launch. To prevent this we will combine these trivial maps, if # possible, with their downstream maps. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 6b7bd1b6d5..4339a761fa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -950,7 +950,7 @@ def _perform_pointwise_test( def apply( self, - graph: dace.SDFGState | dace.SDFG, + graph: dace.SDFGState, sdfg: dace.SDFG, ) -> None: # Removal @@ -971,6 +971,9 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None + # Recursively visit the nested SDFGs for mapping of strides from inner to outer array + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. 
new_map_to_glob_edge = graph.add_edge( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 4e254f2880..980b2a8fdf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,14 +6,30 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import Optional, TypeAlias + import dace from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, ) +PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG] +"""Record of a stride that has been propagated into a NestedSDFG. + +The type combines the NestedSDFG into which the strides were already propagated +and the data within that NestedSDFG to which we have propagated the strides, +which is the connector name on the NestedSDFG. +We need the NestedSDFG because we have to know what was already processed, +however, we also need the inner array name because of aliasing, i.e. a data +descriptor on the outside could be mapped to multiple data descriptors +inside the NestedSDFG. +""" + + def gt_change_transient_strides( sdfg: dace.SDFG, gpu: bool, @@ -24,6 +40,11 @@ def gt_change_transient_strides( transients in the optimal way. The function should run after all maps have been created. + After the strides have been adjusted the function will also propagate + the strides into nested SDFG. This propagation will happen with + `ignore_symbol_mapping` set to `True`, see `gt_propagate_strides_of()` + for more. + Args: sdfg: The SDFG to process. gpu: If the SDFG is supposed to run on the GPU. 
@@ -35,8 +56,6 @@ def gt_change_transient_strides( Todo: - Implement the estimation correctly. - - Handle the case of nested SDFGs correctly; on the outside a transient, - but on the inside a non transient. """ # TODO(phimeull): Implement this function correctly. @@ -46,54 +65,608 @@ def gt_change_transient_strides( return sdfg for nsdfg in sdfg.all_sdfgs_recursive(): - # TODO(phimuell): Handle the case when transient goes into nested SDFG - # on the inside it is a non transient, so it is ignored. _gt_change_transient_strides_non_recursive_impl(nsdfg) def _gt_change_transient_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Essentially this function just changes the stride to FORTRAN order.""" - for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + """Set optimal strides of all transients in the SDFG. + + The function will look for all top level transients, see `_gt_find_toplevel_data_accesses()` + and set their strides such that the access is optimal, see Note. The function + will also run `gt_propagate_strides_of()` to propagate the strides into nested SDFGs. + + This function should never be called directly but always through + `gt_change_transient_strides()`! + + Note: + Currently the function just reverses the strides of the data descriptor + it processes. Since DaCe generates `C` order by default this leads to + FORTRAN order, which is (for now) sufficient to optimize the memory + layout to GPU. + + Todo: + Make this function more intelligent to analyse the access pattern and then + figure out the best order. + """ + # NOTE: Processing the transient here is enough. If we are inside a + # NestedSDFG then they were handled before on the level above us.
+ top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( + sdfg=sdfg, + only_transients=True, + only_arrays=True, + ) + for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): desc: dace_data.Array = sdfg.arrays[top_level_transient] + + # Setting the strides only makes sense if we have more than one dimension ndim = len(desc.shape) if ndim <= 1: continue + # We assume that everything is in C order initially, to get FORTRAN order # we simply have to reverse the order. + # TODO(phimuell): Improve this. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) + # Now we have to propagate the changed strides. Because we already have + # collected all the AccessNodes we are using the + # `gt_propagate_strides_from_access_node()` function, but we have to + # create `processed_nsdfg` set already outside here. + # Furthermore, the same comment as above applies here, we do not have to + # propagate the non-transients, because they either come from outside, + # or they were already handled in the levels above, where they were + # defined and then propagated down. + # TODO(phimuell): Updated the functions such that only one scan is needed. + processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + for state, access_node in accesses: + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=access_node, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=True, + ) + + +def gt_propagate_strides_of( + sdfg: dace.SDFG, + data_name: str, + ignore_symbol_mapping: bool = True, +) -> None: + """Propagates the strides of `data_name` within the whole SDFG. + + This function will call `gt_propagate_strides_from_access_node()` for every + AccessNode that refers to `data_name`. It will also make sure that a descriptor + inside a NestedSDFG is only processed once. + + Args: + sdfg: The SDFG on which we operate. + data_name: Name of the data descriptor that should be handled.
+ ignore_symbol_mapping: If `False` (default is `True`) try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + """ + + # Defining it here ensures that we will not enter a NestedSDFG multiple times. + processed_nsdfgs: set[PropagatedStrideRecord] = set() + + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode.data != data_name: + continue + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=dnode, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_propagate_strides_from_access_node( + sdfg: dace.SDFG, + state: dace.SDFGState, + outer_node: dace_nodes.AccessNode, + ignore_symbol_mapping: bool = True, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the stride of `outer_node` to any adjacent NestedSDFG. + + The function will propagate the strides of the data descriptor `outer_node` + refers to along all adjacent edges of `outer_node`. If one of these edges + leads to a NestedSDFG then the function will modify the strides of data + descriptor within to match the strides on the outside. The function will then + recursively process NestedSDFG. + + It is important that this function will only handle the NestedSDFGs that are + reachable from `outer_node`. To fully propagate the strides the + `gt_propagate_strides_of()` should be used. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False` (default is `True`), try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. + Only specify when you know what you are doing.
+ """ + if processed_nsdfgs is None: + # For preventing the case that nested SDFGs are handled multiple time. + processed_nsdfgs = set() + + for in_edge in state.in_edges(outer_node): + gt_map_strides_to_src_nested_sdfg( + sdfg=sdfg, + state=state, + edge=in_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + for out_edge in state.out_edges(outer_node): + gt_map_strides_to_dst_nested_sdfg( + sdfg=sdfg, + state=state, + edge=out_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_map_strides_to_dst_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = True, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the strides of `outer_node` along `edge` in the dataflow direction. + + In this context "along the dataflow direction" means that `edge` is an outgoing + edge of `outer_node` and the strides are propagated into all NestedSDFGs that + are downstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when + you know what your are doing. 
+ """ + assert edge.src is outer_node + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=True, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_map_strides_to_src_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = False, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the strides of `outer_node` along `edge` in the opposite direction of the dataflow + + In this context "in the opposite direction of the dataflow" means that `edge` + is an incoming edge of `outer_node` and the strides are propagated into all + NestedSDFGs that are upstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when + you know what your are doing. 
+ """ + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=False, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def _gt_map_strides_to_nested_sdfg_src_dst( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]], + propagate_along_dataflow: bool, + ignore_symbol_mapping: bool = False, +) -> None: + """Propagates the stride of `outer_node` along `edge`. + + The function will follow `edge`, the direction depends on the value of + `propagate_along_dataflow` and propagate the strides of `outer_node` + into every NestedSDFG that is reachable by following `edge`. + + When the function encounters a NestedSDFG it will determine what data + the `outer_node` is mapped to on the inside of the NestedSDFG. + It will then replace the stride of the inner descriptor with the ones + of the outside. Afterwards it will recursively propagate the strides + inside the NestedSDFG. + During this propagation the function will follow any edges. + + If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs` + then it will be skipped. NestedSDFGs that have been processed will be added + to the `processed_nsdfgs`. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + propagate_along_dataflow: Determine the direction of propagation. If `True` the + function follows the dataflow. 
+ ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + + Note: + A user should not use this function directly, instead `gt_propagate_strides_of()`, + `gt_map_strides_to_src_nested_sdfg()` (`propagate_along_dataflow == `False`) + or `gt_map_strides_to_dst_nested_sdfg()` (`propagate_along_dataflow == `True`) + should be used. + + Todo: + Try using `MemletTree` for the propagation. + """ + # If `processed_nsdfg` is `None` then this is the first call. We will now + # allocate the `set` and pass it as argument to all recursive calls, this + # ensures that the `set` is the same everywhere. + if processed_nsdfgs is None: + processed_nsdfgs = set() + + if propagate_along_dataflow: + # Propagate along the dataflow or forward, so we are interested at the `dst` of the edge. + ScopeNode = dace_nodes.MapEntry + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.dst + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.dst_conn + + def get_subset( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.get_src_subset(edge, state) -def _find_toplevel_transients( + def next_edges_by_connector( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"): + return [] + return list(state.out_edges_by_connector(edge.dst, "OUT_" + edge.dst_conn[3:])) + + else: + # Propagate against the dataflow or backward, so we are interested at the `src` of the edge. 
+ ScopeNode = dace_nodes.MapExit + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.src + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.src_conn + + def get_subset( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.get_dst_subset(edge, state) + + def next_edges_by_connector( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:])) + + if isinstance(get_node(edge), ScopeNode): + for next_edge in next_edges_by_connector(state, edge): + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=next_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=propagate_along_dataflow, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + elif isinstance(get_node(edge), dace.nodes.NestedSDFG): + nsdfg_node = get_node(edge) + inner_data = get_inner_data(edge) + process_record = (inner_data, nsdfg_node) + + if process_record in processed_nsdfgs: + # We already handled this NestedSDFG and the inner data. + return + + # Mark this nested SDFG as processed. + processed_nsdfgs.add(process_record) + + # Now set the stride of the data descriptor inside the nested SDFG to + # the ones it has outside. + _gt_map_strides_into_nested_sdfg( + sdfg=sdfg, + nsdfg_node=nsdfg_node, + inner_data=inner_data, + outer_subset=get_subset(state, edge), + outer_desc=outer_node.desc(sdfg), + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + # Since the function call above is not recursive we have now to propagate + # the change into the NestedSDFGs. Using `_gt_find_toplevel_data_accesses()` + # is a bit overkill, but allows for a more uniform processing. 
+ # TODO(phimuell): Instead of scanning every level for every data we modify + # we should scan the whole SDFG once and then reuse this information. + accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses( + sdfg=nsdfg_node.sdfg, + only_transients=False, # Because on the nested levels they are globals. + only_arrays=True, + ) + for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()): + # We have to use `gt_propagate_strides_from_access_node()` here because we + # have to handle its entirety. We could wait until the other branch processes + # the nested SDFG, but this might not work, so let's do it fully now. + gt_propagate_strides_from_access_node( + sdfg=nsdfg_node.sdfg, + state=nested_state, + outer_node=nested_access, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def _gt_map_strides_into_nested_sdfg( sdfg: dace.SDFG, + nsdfg_node: dace.nodes.NestedSDFG, + inner_data: str, + outer_subset: dace.subsets.Subset, + outer_desc: dace_data.Data, + ignore_symbol_mapping: bool, +) -> None: + """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. + + `inner_data` is the name of a data descriptor inside the NestedSDFG. + The function will then modify the strides of `inner_data`, assuming this + is an array, to match the ones of `outer_desc`. + + Args: + sdfg: The SDFG containing the NestedSDFG. + nsdfg_node: The node in the parent SDFG that contains the NestedSDFG. + inner_data: The name of the data descriptor that should be processed + inside the NestedSDFG (by construction also a connector name). + outer_subset: The subset that describes what part of the outer data is + mapped into the NestedSDFG. + outer_desc: The data descriptor of the data on the outside. + ignore_symbol_mapping: If possible the function will perform the renaming + through the `symbol_mapping` of the nested SDFG. If `True` then + the function will always perform the renaming. 
+        Note that setting this value to `False` might have negative side effects.
+
+    Todo:
+        - Handle explicit dimensions of size 1.
+        - What should we do if the stride symbol is used somewhere else, creating an
+            alias is probably not the right thing?
+        - Handle the case if the outer stride symbol is already used in another
+            context inside the Nested SDFG.
+    """
+    # We need to compute the new strides. In the following we assume that the
+    # relative order of the dimensions does not change, but we support the case
+    # where some dimensions of the outer data descriptor are not present on the
+    # inside. For example this happens for the Memlet `a[__i0, 0:__a_size1]`. We
+    # detect this case by checking if the Memlet subset in that dimension has size 1.
+    # TODO(phimuell): Handle the case where some additional size 1 dimensions are added.
+    inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data]
+    inner_shape = inner_desc.shape
+    inner_strides_init = inner_desc.strides
+
+    outer_strides = outer_desc.strides
+    outer_inflow = outer_subset.size()
+
+    new_strides: list = []
+    for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True):
+        if dim_oinflow == 1:
+            # This is the case of implicit slicing along one dimension.
+            pass
+        else:
+            # There is inflow into the SDFG, so we need the stride.
+            new_strides.append(dim_ostride)
+    assert len(new_strides) <= len(inner_shape)
+
+    # If we have a scalar on the inside, then there is nothing to adjust.
+    # We could have performed the test above, but doing it here, gives us
+    # the chance of validating it.
+    if isinstance(inner_desc, dace_data.Scalar):
+        if len(new_strides) != 0:
+            raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.")
+        return
+
+    if not isinstance(inner_desc, dace_data.Array):
+        raise TypeError(
+            f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'."
+ ) + + if len(new_strides) != len(inner_shape): + raise ValueError("Failed to compute the inner strides.") + + # Now we actually replace the strides, there are two ways of doing it. + # The first is to create an alias in the `symbol_mapping`, however, + # this is only possible if the current strides are singular symbols, + # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` + # or literal values. Furthermore, this would change the meaning of the + # old stride symbol in any context and not only in the one of the stride + # of a single and isolated data descriptor. + # The second way would be to replace `strides` attribute of the + # inner data descriptor. In case the new stride consists of expressions + # such as `value1 - value2` we have to make them available inside the + # NestedSDFG. However, it could be that the strides is used somewhere else. + # We will do the following, if `ignore_symbol_mapping` is `False` and + # the strides of the inner descriptors are symbols, we will use the + # symbol mapping. Otherwise, we will replace the `strides` attribute + # of the inner descriptor, in addition we will install a remapping, + # for those values that were a symbol. + if (not ignore_symbol_mapping) and all( + isinstance(inner_stride, dace.symbol) for inner_stride in inner_strides_init + ): + # Use the symbol + for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): + nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride + else: + # We have to replace the `strides` attribute of the inner descriptor. + inner_desc.set_shape(inner_desc.shape, new_strides) + + # Now find the free symbols that the new strides need. + # Note that usually `free_symbols` returns `set[str]`, but here, because + # we fall back on SymPy, we get back symbols. We will keep them, because + # then we can use them to extract the type form them, which we need later. 
+ new_strides_symbols: list[dace.symbol] = [] + for new_stride_dim in new_strides: + if dace.symbolic.issymbolic(new_stride_dim): + new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) + else: + # It is not already a symbol, so we turn it into a symbol. + # However, we only add it, if it is also a symbol, for example `1`. + # should not be added. + new_stride_symbol = dace.symbolic.pystr_to_symbolic(new_stride_dim) + if new_stride_symbol.is_symbol: + new_strides_symbols.append(new_stride_symbol) + + # Now we determine the set of symbols that should be mapped inside the NestedSDFG. + # We will exclude all that are already inside the `symbol_mapping` (we do not + # check if they map to the same value, we just hope it). Furthermore, + # we will exclude all symbols that are listed in the `symbols` property + # of the SDFG that is nested, and hope that it has the same meaning. + # TODO(phimuell): Add better checks to avoid overwriting. + missing_symbol_mappings: set[dace.symbol] = { + sym + for sym in new_strides_symbols + if not (sym.name in nsdfg_node.sdfg.symbols or sym.name in nsdfg_node.symbol_mapping) + } + + # Now propagate the symbols from the parent SDFG to the NestedSDFG. + for sym in missing_symbol_mappings: + assert sym.name in sdfg.symbols, f"Expected that '{sym}' is defined in the parent SDFG." + nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) + nsdfg_node.symbol_mapping[sym.name] = sym + + +def _gt_find_toplevel_data_accesses( + sdfg: dace.SDFG, + only_transients: bool, only_arrays: bool = False, -) -> set[str]: - """Find all top level transients in the SDFG. +) -> dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]]: + """Find all data that is accessed on the top level. The function will scan the SDFG, ignoring nested one, and return the - name of all transients that have an access node at the top level. - However, it will ignore access nodes that refers to registers. 
+ name of all data that only have AccessNodes on the top level. In data + is found that has an AccessNode on both the top level and in a nested + scope and error is generated. + By default the function will return transient and non transient data, + however, if `only_transients` is `True` then only transient data will + be returned. + Furthermore, the function will ignore an access in the following cases: + - The AccessNode refers to data that is a register. + - The AccessNode refers to a View. + + Args: + sdfg: The SDFG to process. + only_transients: If `True` only include transients. + only_arrays: If `True`, defaults to `False`, only arrays are returned. + + Returns: + A `dict` that maps the name of a data container, to a list of tuples + containing the state where the AccessNode was found and the AccessNode. """ - top_level_transients: set[str] = set() + # List of data that is accessed on the top level and all its access node. + top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict() + + # List of all data that were found not on top level. + not_top_level_data: set[str] = set() + for state in sdfg.states(): scope_dict = state.scope_dict() for dnode in state.data_nodes(): data: str = dnode.data if scope_dict[dnode] is not None: - if data in top_level_transients: - top_level_transients.remove(data) + # The node was not found on the top level. So we can ignore it. + # We also check if it was ever found on the top level, this should + # not happen, as everything should go through Maps. But some strange + # DaCe transformation might do it. + assert ( + data not in top_level_data + ), f"Found {data} on the top level and inside a scope." + not_top_level_data.add(data) continue - elif data in top_level_transients: + + elif data in top_level_data: + # The data is already known to be in top level data, so we must add the + # AccessNode to the list of known nodes. But nothing else. 
+ top_level_data[data].append((state, dnode)) continue + elif gtx_transformations.util.is_view(dnode, sdfg): + # The AccessNode refers to a View so we ignore it anyway. continue + + # We have found a new data node that is on the top node and is unknown. + assert ( + data not in not_top_level_data + ), f"Found {data} on the top level and inside a scope." desc: dace_data.Data = dnode.desc(sdfg) - if not desc.transient: + # Check if we only accept arrays + if only_arrays and not isinstance(desc, dace_data.Array): continue - elif only_arrays and not isinstance(desc, dace_data.Array): + + # For now we ignore registers. + # We do this because register are allocated on the stack, so the compiler + # has all information and should organize the best thing possible. + # TODO(phimuell): verify this. + elif desc.storage is dace.StorageType.Register: continue - top_level_transients.add(data) - return top_level_transients + + # We are only interested in transients + if only_transients and (not desc.transient): + continue + + # Now create the new entry in the list and record the AccessNode. 
+ top_level_data[data] = [(state, dnode)] + return top_level_data diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index 1a4ce6d047..a98eac3c2c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -22,10 +22,6 @@ import dace -def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: - return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} - - def _make_test_sdfg( output_name: str = "G", input_name: str = "G", @@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply(): validate_all=True, ) assert count == 0 + + +def test_map_buffer_elimination_with_nested_sdfgs(): + """ + After removing a transient connected to a nested SDFG node, ensure that the strides + are propagated to the arrays in nested SDFG. 
+ """ + + stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)] + + # top-level sdfg + sdfg = dace.SDFG(util.unique_name("map_buffer")) + inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64) + out, out_desc = sdfg.add_array( + "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3) + ) + tmp, _ = sdfg.add_temp_transient_like(out_desc) + state = sdfg.add_state() + tmp_node = state.add_access(tmp) + + nsdfg1 = dace.SDFG(util.unique_name("map_buffer")) + inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64) + out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64) + tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc) + state1 = nsdfg1.add_state() + tmp1_node = state1.add_access(tmp1) + + nsdfg2 = dace.SDFG(util.unique_name("map_buffer")) + inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64) + out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64) + tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc) + state2 = nsdfg2.add_state() + tmp2_node = state2.add_access(tmp2) + + state2.add_mapped_tasklet( + "broadcast2", + map_ranges={"__i": "0:10"}, + code="__oval = __ival + 1.0", + inputs={ + "__ival": dace.Memlet(f"{inp2}[__i]"), + }, + outputs={ + "__oval": dace.Memlet(f"{tmp2}[__i]"), + }, + output_nodes={tmp2_node}, + external_edges=True, + ) + state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc)) + + nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"}) + me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"}) + state1.add_memlet_path( + state1.add_access(inp1), + me1, + nsdfg2_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp1, inp1_desc), + ) + state1.add_memlet_path( + nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]") + ) + state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc)) + + nsdfg1_node = 
state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"}) + me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"}) + state.add_memlet_path( + state.add_access(inp), + me, + nsdfg1_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp, inp_desc), + ) + state.add_memlet_path( + nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]") + ) + state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc)) + + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 3 + assert out1_desc.strides == out_desc.strides[1:] + assert out2_desc.strides == out_desc.strides[2:] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py new file mode 100644 index 0000000000..5b16e41bc3 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -0,0 +1,541 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace import symbolic as dace_symbolic +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . 
import util + +import dace + + +def _make_strides_propagation_level3_sdfg() -> dace.SDFG: + """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" + sdfg = dace.SDFG(util.unique_name("level3")) + state = sdfg.add_state(is_start_block=True) + names = ["a3", "c3"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL3", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a3[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("c3[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + """Generates the level 2 SDFG (nested) SDFG for `test_strides_propagation()`. + + The function returns the level 2 SDFG and the NestedSDFG node that contains + the level 3 SDFG. + """ + sdfg = dace.SDFG(util.unique_name("level2")) + state = sdfg.add_state(is_start_block=True) + names = ["a2", "a2_alias", "b2", "c2"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL2_1", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + + state.add_mapped_tasklet( + "compL2_2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("c2[__i0]")}, + code="__out = __in1", + outputs={"__out": dace.Memlet("a2_alias[__i0]")}, + external_edges=True, + ) + + # This is the nested SDFG we have here. 
+    sdfg_level3 = _make_strides_propagation_level3_sdfg()
+
+    nsdfg = state.add_nested_sdfg(
+        sdfg=sdfg_level3,
+        parent=sdfg,
+        inputs={"a3"},
+        outputs={"c3"},
+        symbol_mapping={s3: s3 for s3 in sdfg_level3.free_symbols},
+    )
+
+    state.add_edge(state.add_access("a2"), None, nsdfg, "a3", dace.Memlet("a2[0:10]"))
+    state.add_edge(nsdfg, "c3", state.add_access("c2"), None, dace.Memlet("c2[0:10]"))
+    sdfg.validate()
+
+    return sdfg, nsdfg
+
+
+def _make_strides_propagation_level1_sdfg() -> (
+    tuple[dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG]
+):
+    """Generates the level 1 SDFG (top) SDFG for `test_strides_propagation()`.
+
+    Note that the SDFG is valid, but will be indeterminate. The only point of
+    this SDFG is to have a lot of different situations that have to be handled
+    for renaming.
+
+    Returns:
+        A tuple of length three, with the following members:
+        - The top level SDFG.
+        - The NestedSDFG node that contains the level 2 SDFG (member of the top level SDFG).
+        - The NestedSDFG node that contains the level 3 SDFG (member of the level 2 SDFG).
+ """ + + sdfg = dace.SDFG(util.unique_name("level1")) + state = sdfg.add_state(is_start_block=True) + names = ["a1", "b1", "c1"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + sdfg_level2, nsdfg_level3 = _make_strides_propagation_level2_sdfg() + + nsdfg_level2: dace_nodes.NestedSDFG = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg, + inputs={"a2", "c2"}, + outputs={"a2_alias", "b2", "c2"}, + symbol_mapping={s: s for s in sdfg_level2.free_symbols}, + ) + + for inner_name in nsdfg_level2.in_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + state.add_access(outer_name), + None, + nsdfg_level2, + inner_name, + dace.Memlet(f"{outer_name}[0:10]"), + ) + for inner_name in nsdfg_level2.out_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + nsdfg_level2, + inner_name, + state.add_access(outer_name), + None, + dace.Memlet(f"{outer_name}[0:10]"), + ) + + sdfg.validate() + + return sdfg, nsdfg_level2, nsdfg_level3 + + +def test_strides_propagation_use_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." 
+ + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=False) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=False) + sdfg_level1.validate() + + # Because `ignore_symbol_mapping=False` the strides of the data descriptor should + # not have changed. But the `symbol_mapping` has been updated for `a` and `b`. + # However, the symbols will only point one level above. + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + + if aname.startswith("c"): + target_symbol = f"{aname}_stride" + else: + target_symbol = f"{aname[0]}{level - 1}_stride" + + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." 
+ + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False) + sdfg_level1.validate() + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + target_symbol = f"{aname[0]}{level-1}_stride" + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + +def test_strides_propagation_ignore_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. 
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+
+    # After the propagation `a` and `b` should use the same stride (the one that
+    #  it has on level 1), but `c` should still be level dependent.
+    for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+        for aname, adesc in sdfg.arrays.items():
+            original_stride = f"{aname}_stride"
+            if aname.startswith("c"):
+                exp_stride = f"{aname}_stride"
+            else:
+                exp_stride = f"{aname[0]}1_stride"
+            assert len(adesc.strides) == 1
+            assert (
+                str(adesc.strides[0]) == exp_stride
+            ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+            nsdfg = sdfg.parent_nsdfg_node
+            if nsdfg is not None:
+                assert original_stride in nsdfg.symbol_mapping
+                assert str(nsdfg.symbol_mapping[original_stride]) == original_stride
+
+    # Now we also propagate `c` thus now all data descriptors have the same stride
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+    for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+        for aname, adesc in sdfg.arrays.items():
+            exp_stride = f"{aname[0]}1_stride"
+            original_stride = f"{aname}_stride"
+            assert len(adesc.strides) == 1
+            assert (
+                str(adesc.strides[0]) == exp_stride
+            ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+            nsdfg = sdfg.parent_nsdfg_node
+            if nsdfg is not None:
+                # The symbol mapping must not be updated.
+ assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride + + +def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + array_names = ["a2", "b2"] + for name in array_names: + stride_sym = dace.symbol(f"{name}_stride", dtype=dace.uint64) + sdfg.add_symbol(stride_sym.name, stride_sym.dtype) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + array_names = ["a1", "b1"] + for name in array_names: + stride_sym1 = dace.symbol(f"{name}_1stride", dtype=dace.uint64) + stride_sym2 = dace.symbol(f"{name}_2stride", dtype=dace.int64) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + sdfg_level1.add_symbol(stride_sym2.name, stride_sym2.dtype) + stride_sym = stride_sym1 * stride_sym2 + sdfg_level1.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_dependent_symbol_nsdfg() + + for sym, sym_dtype in sdfg_level2.symbols.items(): + sdfg_level1.add_symbol(sym, sym_dtype) + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", 
dace.Memlet("a1[0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_dependent_symbol(): + sdfg_level1, nsdfg_level2 = _make_strides_propagation_dependent_symbol_sdfg() + sym1_dtype = dace.uint64 + sym2_dtype = dace.int64 + + # Ensure that the special symbols are not already present inside the nested SDFG. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in {fs.name for fs in adesc.strides[0].free_symbols} + assert sym not in nsdfg_level2.symbol_mapping + assert sym not in nsdfg_level2.sdfg.symbols + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + + # Now propagate `a1` and `b1`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() + + # Now we check if the update has worked. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + adesc2 = nsdfg_level2.sdfg.arrays[aname.replace("1", "2")] + assert adesc2.strides == adesc.strides + + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in nsdfg_level2.symbol_mapping + assert nsdfg_level2.symbol_mapping[sym].name == sym + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + assert sym in nsdfg_level2.sdfg.symbols + assert nsdfg_level2.sdfg.symbols[sym] == dtype + + +def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + # NOTE: Both arrays have the same symbols used for strides. 
+ array_names = ["a2", "b2"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=(stride_sym0, stride_sym1), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={ + "__i0": "0:10", + "__i1": "0:10", + }, + inputs={"__in1": dace.Memlet("a2[__i0, __i1]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + # NOTE: Both arrays use the same symbols as strides. + # Furthermore, they are the same as in the nested SDFG, i.e. they are shared. 
+ array_names = ["a1", "b1"]
+ stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64)
+ stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64)
+ sdfg_level1.add_symbol(stride_sym0.name, stride_sym0.dtype)
+ sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype)
+ for name in array_names:
+ sdfg_level1.add_array(
+ name,
+ shape=(10, 10),
+ dtype=dace.float64,
+ strides=(
+ stride_sym0,
+ stride_sym1,
+ ),
+ transient=False,
+ )
+
+ sdfg_level2 = _make_strides_propagation_shared_symbols_nsdfg()
+ nsdfg = state.add_nested_sdfg(
+ sdfg=sdfg_level2,
+ parent=sdfg_level1,
+ inputs={"a2"},
+ outputs={"b2"},
+ symbol_mapping={s: s for s in sdfg_level2.symbols},
+ )
+
+ state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10, 0:10]"))
+ state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10, 0:10]"))
+ sdfg_level1.validate()
+
+ return sdfg_level1, nsdfg
+
+
+ def test_strides_propagation_shared_symbols_sdfg():
+ """Tests what happens if symbols are (unintentionally) shared between descriptors.
+
+ This test looks rather artificial, but it is actually quite likely. Because
+ transients will most likely have the same shape and if the strides are not
+ set explicitly, which is the case, the strides will also be related to their
+ shape. This test explores the situation, where we can, for whatever reason,
+ only propagate the strides of one such data descriptor.
+
+ Note:
+ If `ignore_symbol_mapping` is `False` then this test will fail.
+ This is because the `symbol_mapping` of the NestedSDFG will act on the
+ whole SDFG. Thus it will not only change the strides of `b` but as an
+ unintended side effect also the strides of `a`.
+ """ + + def ref(a1, b1): + for i in range(10): + for j in range(10): + b1[i, j] = a1[i, j] + 10.0 + + sdfg_level1, nsdfg_level2 = _make_strides_propagation_shared_symbols_sdfg() + + res_args = { + "a1": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + "b1": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + # Now we change the strides of `b1`, and then we propagate the new strides + # into the nested SDFG. We want to keep (for whatever reasons) strides of `a1`. + stride_b1_sym0 = dace.symbol(f"__b1_stride_0", dtype=dace.uint64) + stride_b1_sym1 = dace.symbol(f"__b1_stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_b1_sym0.name, stride_b1_sym0.dtype) + sdfg_level1.add_symbol(stride_b1_sym1.name, stride_b1_sym1.dtype) + + desc_b1 = sdfg_level1.arrays["b1"] + desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1)) + + # Now we propagate the data into it. + gtx_transformations.gt_propagate_strides_of( + sdfg=sdfg_level1, + data_name="b1", + ) + + # Now we have to prepare the call arguments, i.e. adding the strides + itemsize = res_args["b1"].itemsize + res_args.update( + { + "__b1_stride_0": res_args["b1"].strides[0] // itemsize, + "__b1_stride_1": res_args["b1"].strides[1] // itemsize, + "__stride_0": res_args["a1"].strides[0] // itemsize, + "__stride_1": res_args["a1"].strides[1] // itemsize, + } + ) + ref(**ref_args) + sdfg_level1(**res_args) + assert np.allclose(ref_args["b1"], res_args["b1"])