diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b1092fafd0..d9cfa0ff48 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -62,7 +62,7 @@ repos:
## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1]
## print(f"rev: '{version}' # version from constraints.txt")
##]]]
- rev: '23.9.1' # version from constraints.txt
+ rev: '23.11.0' # version from constraints.txt
##[[[end]]]
hooks:
- id: black
@@ -73,7 +73,7 @@ repos:
## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1]
## print(f"rev: '{version}' # version from constraints.txt")
##]]]
- rev: '5.12.0' # version from constraints.txt
+ rev: '5.13.0' # version from constraints.txt
##[[[end]]]
hooks:
- id: isort
@@ -97,14 +97,14 @@ repos:
## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1]))
##]]]
- darglint==1.8.1
- - flake8-bugbear==23.9.16
- - flake8-builtins==2.1.0
+ - flake8-bugbear==23.12.2
+ - flake8-builtins==2.2.0
- flake8-debugger==4.1.2
- flake8-docstrings==1.7.0
- flake8-eradicate==1.5.0
- flake8-mutable==1.2.0
- flake8-pyproject==1.2.3
- - pygments==2.16.1
+ - pygments==2.17.2
##[[[end]]]
# - flake8-rst-docstrings # Disabled for now due to random false positives
exclude: |
@@ -146,9 +146,9 @@ repos:
## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1]
## print(f"#========= FROM constraints.txt: v{version} =========")
##]]]
- #========= FROM constraints.txt: v1.5.1 =========
+ #========= FROM constraints.txt: v1.7.1 =========
##[[[end]]]
- rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date)
+ rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date)
hooks:
- id: mypy
additional_dependencies: # versions from constraints.txt
@@ -162,26 +162,26 @@ repos:
##]]]
- astunparse==1.6.3
- attrs==23.1.0
- - black==23.9.1
- - boltons==23.0.0
+ - black==23.11.0
+ - boltons==23.1.1
- cached-property==1.5.2
- click==8.1.7
- - cmake==3.27.5
+ - cmake==3.27.9
- cytoolz==0.12.2
- - deepdiff==6.5.0
+ - deepdiff==6.7.1
- devtools==0.12.2
- - frozendict==2.3.8
+ - frozendict==2.3.10
- gridtools-cpp==2.3.1
- - importlib-resources==6.0.1
+ - importlib-resources==6.1.1
- jinja2==3.1.2
- - lark==1.1.7
- - mako==1.2.4
- - nanobind==1.5.2
- - ninja==1.11.1
+ - lark==1.1.8
+ - mako==1.3.0
+ - nanobind==1.8.0
+ - ninja==1.11.1.1
- numpy==1.24.4
- - packaging==23.1
+ - packaging==23.2
- pybind11==2.11.1
- - setuptools==68.2.2
+ - setuptools==69.0.2
- tabulate==0.9.0
- typing-extensions==4.5.0
- xxhash==3.0.0
diff --git a/constraints.txt b/constraints.txt
index b334851af1..81abd64c6e 100644
--- a/constraints.txt
+++ b/constraints.txt
@@ -6,124 +6,136 @@
#
aenum==3.1.15 # via dace
alabaster==0.7.13 # via sphinx
-asttokens==2.4.0 # via devtools
+asttokens==2.4.1 # via devtools
astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml)
attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing
-babel==2.12.1 # via sphinx
-black==23.9.1 # via gt4py (pyproject.toml)
-blinker==1.6.2 # via flask
-boltons==23.0.0 # via gt4py (pyproject.toml)
+babel==2.13.1 # via sphinx
+black==23.11.0 # via gt4py (pyproject.toml)
+blinker==1.7.0 # via flask
+boltons==23.1.1 # via gt4py (pyproject.toml)
build==1.0.3 # via pip-tools
cached-property==1.5.2 # via gt4py (pyproject.toml)
-cachetools==5.3.1 # via tox
-certifi==2023.7.22 # via requests
-cffi==1.15.1 # via cryptography
+cachetools==5.3.2 # via tox
+cerberus==1.3.5 # via plette
+certifi==2023.11.17 # via requests
+cffi==1.16.0 # via cryptography
cfgv==3.4.0 # via pre-commit
chardet==5.2.0 # via tox
-charset-normalizer==3.2.0 # via requests
-clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
+charset-normalizer==3.3.2 # via requests
+clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools
-cmake==3.27.5 # via gt4py (pyproject.toml)
+cmake==3.27.9 # via dace, gt4py (pyproject.toml)
cogapp==3.3.0 # via -r requirements-dev.in
colorama==0.4.6 # via tox
-coverage==7.3.1 # via -r requirements-dev.in, pytest-cov
-cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis
+coverage==7.3.2 # via -r requirements-dev.in, pytest-cov
+cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis
cytoolz==0.12.2 # via gt4py (pyproject.toml)
-dace==0.14.4 # via gt4py (pyproject.toml)
+dace==0.15.1 # via gt4py (pyproject.toml)
darglint==1.8.1 # via -r requirements-dev.in
-deepdiff==6.5.0 # via gt4py (pyproject.toml)
+deepdiff==6.7.1 # via gt4py (pyproject.toml)
devtools==0.12.2 # via gt4py (pyproject.toml)
dill==0.3.7 # via dace
-distlib==0.3.7 # via virtualenv
-docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
+distlib==0.3.7 # via requirementslib, virtualenv
+distro==1.8.0 # via scikit-build
+docopt==0.6.2 # via pipreqs
+docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
eradicate==2.3.0 # via flake8-eradicate
-exceptiongroup==1.1.3 # via hypothesis, pytest
+exceptiongroup==1.2.0 # via hypothesis, pytest
execnet==2.0.2 # via pytest-cache, pytest-xdist
-executing==1.2.0 # via devtools
+executing==2.0.1 # via devtools
factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy
-faker==19.6.1 # via factory-boy
-fastjsonschema==2.18.0 # via nbformat
-filelock==3.12.4 # via tox, virtualenv
+faker==20.1.0 # via factory-boy
+fastjsonschema==2.19.0 # via nbformat
+filelock==3.13.1 # via tox, virtualenv
flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings
-flake8-bugbear==23.9.16 # via -r requirements-dev.in
-flake8-builtins==2.1.0 # via -r requirements-dev.in
+flake8-bugbear==23.12.2 # via -r requirements-dev.in
+flake8-builtins==2.2.0 # via -r requirements-dev.in
flake8-debugger==4.1.2 # via -r requirements-dev.in
flake8-docstrings==1.7.0 # via -r requirements-dev.in
flake8-eradicate==1.5.0 # via -r requirements-dev.in
flake8-mutable==1.2.0 # via -r requirements-dev.in
flake8-pyproject==1.2.3 # via -r requirements-dev.in
flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in
-flask==2.3.3 # via dace
-frozendict==2.3.8 # via gt4py (pyproject.toml)
+flask==3.0.0 # via dace
+fparser==0.1.3 # via dace
+frozendict==2.3.10 # via gt4py (pyproject.toml)
gridtools-cpp==2.3.1 # via gt4py (pyproject.toml)
-hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml)
-identify==2.5.29 # via pre-commit
-idna==3.4 # via requests
+hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml)
+identify==2.5.33 # via pre-commit
+idna==3.6 # via requests
imagesize==1.4.1 # via sphinx
-importlib-metadata==6.8.0 # via build, flask, sphinx
-importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
+importlib-metadata==7.0.0 # via build, flask, fparser, sphinx
+importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
inflection==0.5.1 # via pytest-factoryboy
iniconfig==2.0.0 # via pytest
-isort==5.12.0 # via -r requirements-dev.in
+isort==5.13.0 # via -r requirements-dev.in
itsdangerous==2.1.2 # via flask
jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx
-jsonschema==4.19.0 # via nbformat
-jsonschema-specifications==2023.7.1 # via jsonschema
-jupyter-core==5.3.1 # via nbformat
-jupytext==1.15.2 # via -r requirements-dev.in
-lark==1.1.7 # via gt4py (pyproject.toml)
-mako==1.2.4 # via gt4py (pyproject.toml)
+jsonschema==4.20.0 # via nbformat
+jsonschema-specifications==2023.11.2 # via jsonschema
+jupyter-core==5.5.0 # via nbformat
+jupytext==1.16.0 # via -r requirements-dev.in
+lark==1.1.8 # via gt4py (pyproject.toml)
+mako==1.3.0 # via gt4py (pyproject.toml)
markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins
markupsafe==2.1.3 # via jinja2, mako, werkzeug
mccabe==0.7.0 # via flake8
mdit-py-plugins==0.4.0 # via jupytext
mdurl==0.1.2 # via markdown-it-py
mpmath==1.3.0 # via sympy
-mypy==1.5.1 # via -r requirements-dev.in
+mypy==1.7.1 # via -r requirements-dev.in
mypy-extensions==1.0.0 # via black, mypy
-nanobind==1.5.2 # via gt4py (pyproject.toml)
+nanobind==1.8.0 # via gt4py (pyproject.toml)
nbformat==5.9.2 # via jupytext
networkx==3.1 # via dace
-ninja==1.11.1 # via gt4py (pyproject.toml)
+ninja==1.11.1.1 # via gt4py (pyproject.toml)
nodeenv==1.8.0 # via pre-commit
numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client
ordered-set==4.1.0 # via deepdiff
-packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox
-pathspec==0.11.2 # via black
+packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox
+pathspec==0.12.1 # via black
+pep517==0.13.1 # via requirementslib
+pip-api==0.0.30 # via isort
pip-tools==7.3.0 # via -r requirements-dev.in
-pipdeptree==2.13.0 # via -r requirements-dev.in
+pipdeptree==2.13.1 # via -r requirements-dev.in
+pipreqs==0.4.13 # via isort
pkgutil-resolve-name==1.3.10 # via jsonschema
-platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv
+platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv
+plette==0.4.4 # via requirementslib
pluggy==1.3.0 # via pytest, tox
ply==3.11 # via dace
-pre-commit==3.4.0 # via -r requirements-dev.in
-psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist
+pre-commit==3.5.0 # via -r requirements-dev.in
+psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist
pybind11==2.11.1 # via gt4py (pyproject.toml)
-pycodestyle==2.11.0 # via flake8, flake8-debugger
+pycodestyle==2.11.1 # via flake8, flake8-debugger
pycparser==2.21 # via cffi
+pydantic==1.10.13 # via requirementslib
pydocstyle==6.3.0 # via flake8-docstrings
pyflakes==3.1.0 # via flake8
-pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
+pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
pyproject-api==1.6.1 # via tox
pyproject-hooks==1.0.0 # via build
-pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
+pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
pytest-cache==1.0 # via -r requirements-dev.in
pytest-cov==4.1.0 # via -r requirements-dev.in
-pytest-factoryboy==2.5.1 # via -r requirements-dev.in
-pytest-xdist==3.3.1 # via -r requirements-dev.in
+pytest-factoryboy==2.6.0 # via -r requirements-dev.in
+pytest-xdist==3.5.0 # via -r requirements-dev.in
python-dateutil==2.8.2 # via faker
pytz==2023.3.post1 # via babel
pyyaml==6.0.1 # via dace, jupytext, pre-commit
-referencing==0.30.2 # via jsonschema, jsonschema-specifications
-requests==2.31.0 # via dace, sphinx
+referencing==0.32.0 # via jsonschema, jsonschema-specifications
+requests==2.31.0 # via dace, requirementslib, sphinx, yarg
+requirementslib==3.0.0 # via isort
restructuredtext-lint==1.4.0 # via flake8-rst-docstrings
-rpds-py==0.10.3 # via jsonschema, referencing
-ruff==0.0.290 # via -r requirements-dev.in
+rpds-py==0.13.2 # via jsonschema, referencing
+ruff==0.1.7 # via -r requirements-dev.in
+scikit-build==0.17.6 # via dace
+setuptools-scm==8.0.4 # via fparser
six==1.16.0 # via asttokens, astunparse, python-dateutil
snowballstemmer==2.2.0 # via pydocstyle, sphinx
sortedcontainers==2.4.0 # via hypothesis
sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery
-sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in
+sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in
sphinxcontrib-applehelp==1.0.4 # via sphinx
sphinxcontrib-devhelp==1.0.2 # via sphinx
sphinxcontrib-htmlhelp==2.0.1 # via sphinx
@@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme
sphinxcontrib-jsmath==1.0.1 # via sphinx
sphinxcontrib-qthelp==1.0.3 # via sphinx
sphinxcontrib-serializinghtml==1.1.5 # via sphinx
-sympy==1.12 # via dace, gt4py (pyproject.toml)
+sympy==1.9 # via dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via gt4py (pyproject.toml)
toml==0.10.2 # via jupytext
-tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox
+tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox
+tomlkit==0.12.3 # via plette, requirementslib
toolz==0.12.0 # via cytoolz
-tox==4.11.3 # via -r requirements-dev.in
-traitlets==5.10.0 # via jupyter-core, nbformat
+tox==4.11.4 # via -r requirements-dev.in
+traitlets==5.14.0 # via jupyter-core, nbformat
types-aiofiles==23.2.0.0 # via types-all
types-all==1.0.0 # via -r requirements-dev.in
types-annoy==1.17.8.4 # via types-all
types-atomicwrites==1.4.5.1 # via types-all
types-backports==0.1.3 # via types-all
types-backports-abc==0.5.2 # via types-all
-types-bleach==6.0.0.4 # via types-all
+types-bleach==6.1.0.1 # via types-all
types-boto==2.49.18.9 # via types-all
-types-cachetools==5.3.0.6 # via types-all
+types-cachetools==5.3.0.7 # via types-all
types-certifi==2021.10.8.3 # via types-all
-types-cffi==1.15.1.15 # via types-jack-client
+types-cffi==1.16.0.0 # via types-jack-client
types-characteristic==14.3.7 # via types-all
types-chardet==5.0.4.6 # via types-all
types-click==7.1.8 # via types-all, types-flask
-types-click-spinner==0.1.13.5 # via types-all
+types-click-spinner==0.1.13.6 # via types-all
types-colorama==0.4.15.12 # via types-all
types-contextvars==2.4.7.3 # via types-all
-types-croniter==1.4.0.1 # via types-all
+types-croniter==2.0.0.0 # via types-all
types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt
types-dataclasses==0.6.6 # via types-all
types-dateparser==1.1.4.10 # via types-all
@@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all
types-geoip2==3.0.0 # via types-all
types-ipaddress==1.0.8 # via types-all, types-maxminddb
types-itsdangerous==1.1.6 # via types-all
-types-jack-client==0.5.10.9 # via types-all
+types-jack-client==0.5.10.10 # via types-all
types-jinja2==2.11.9 # via types-all, types-flask
types-kazoo==0.1.3 # via types-all
-types-markdown==3.4.2.10 # via types-all
+types-markdown==3.5.0.3 # via types-all
types-markupsafe==1.1.10 # via types-all, types-jinja2
types-maxminddb==1.5.0 # via types-all, types-geoip2
-types-mock==5.1.0.2 # via types-all
+types-mock==5.1.0.3 # via types-all
types-mypy-extensions==1.0.0.5 # via types-all
types-nmap==0.1.6 # via types-all
types-openssl-python==0.1.3 # via types-all
types-orjson==3.6.2 # via types-all
-types-paramiko==3.3.0.0 # via types-all, types-pysftp
+types-paramiko==3.3.0.2 # via types-all, types-pysftp
types-pathlib2==2.3.0 # via types-all
-types-pillow==10.0.0.3 # via types-all
+types-pillow==10.1.0.2 # via types-all
types-pkg-resources==0.1.3 # via types-all
types-polib==1.2.0.1 # via types-all
-types-protobuf==4.24.0.1 # via types-all
+types-protobuf==4.24.0.4 # via types-all
types-pyaudio==0.2.16.7 # via types-all
types-pycurl==7.45.2.5 # via types-all
types-pyfarmhash==0.3.1.2 # via types-all
types-pyjwt==1.7.1 # via types-all
types-pymssql==2.1.0 # via types-all
types-pymysql==1.1.0.1 # via types-all
-types-pyopenssl==23.2.0.2 # via types-redis
+types-pyopenssl==23.3.0.0 # via types-redis
types-pyrfc3339==1.1.1.5 # via types-all
types-pysftp==0.2.17.6 # via types-all
types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange
types-python-gflags==3.1.7.3 # via types-all
types-python-slugify==8.0.0.3 # via types-all
-types-pytz==2023.3.1.0 # via types-all, types-tzlocal
+types-pytz==2023.3.1.1 # via types-all, types-tzlocal
types-pyvmomi==8.0.0.6 # via types-all
-types-pyyaml==6.0.12.11 # via types-all
-types-redis==4.6.0.6 # via types-all
-types-requests==2.31.0.2 # via types-all
+types-pyyaml==6.0.12.12 # via types-all
+types-redis==4.6.0.11 # via types-all
+types-requests==2.31.0.10 # via types-all
types-retry==0.9.9.4 # via types-all
types-routes==2.5.0 # via types-all
types-scribe==2.0.0 # via types-all
-types-setuptools==68.2.0.0 # via types-cffi
+types-setuptools==69.0.0.0 # via types-cffi
types-simplejson==3.19.0.2 # via types-all
types-singledispatch==4.1.0.0 # via types-all
types-six==1.16.21.9 # via types-all
@@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all
types-toml==0.10.8.7 # via types-all
types-tornado==5.1.1 # via types-all
types-typed-ast==1.5.8.7 # via types-all
-types-tzlocal==5.0.1.1 # via types-all
+types-tzlocal==5.1.0.1 # via types-all
types-ujson==5.8.0.1 # via types-all
-types-urllib3==1.26.25.14 # via types-requests
types-waitress==2.1.4.9 # via types-all
types-werkzeug==1.0.9 # via types-all, types-flask
types-xxhash==3.0.5.2 # via types-all
-typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy
-urllib3==2.0.4 # via requests
-virtualenv==20.24.5 # via pre-commit, tox
-websockets==11.0.3 # via dace
-werkzeug==2.3.7 # via flask
-wheel==0.41.2 # via astunparse, pip-tools
+typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm
+urllib3==2.1.0 # via requests, types-requests
+virtualenv==20.25.0 # via pre-commit, tox
+websockets==12.0 # via dace
+werkzeug==3.0.1 # via flask
+wheel==0.42.0 # via astunparse, pip-tools, scikit-build
xxhash==3.0.0 # via gt4py (pyproject.toml)
-zipp==3.16.2 # via importlib-metadata, importlib-resources
+yarg==0.1.9 # via pipreqs
+zipp==3.17.0 # via importlib-metadata, importlib-resources
# The following packages are considered to be unsafe in a requirements file:
-pip==23.2.1 # via pip-tools
-setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools
+pip==23.3.1 # via pip-api, pip-tools, requirementslib
+setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm
diff --git a/examples/lap_cartesian_vs_next.ipynb b/examples/lap_cartesian_vs_next.ipynb
new file mode 100644
index 0000000000..cb80122570
--- /dev/null
+++ b/examples/lap_cartesian_vs_next.ipynb
@@ -0,0 +1,189 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "GT4Py - GridTools for Python\n",
+ "\n",
+ "Copyright (c) 2014-2023, ETH Zurich\n",
+ "All rights reserved.\n",
+ "\n",
+ "This file is part the GT4Py project and the GridTools framework.\n",
+ "GT4Py is free software: you can redistribute it and/or modify it under\n",
+ "the terms of the GNU General Public License as published by the\n",
+ "Free Software Foundation, either version 3 of the License, or any later\n",
+ "version. See the LICENSE.txt file at the top-level directory of this\n",
+    "distribution for a copy of the license or check https://www.gnu.org/licenses/.\n",
+ "\n",
+ "SPDX-License-Identifier: GPL-3.0-or-later"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Demonstrates gt4py.cartesian with gt4py.next compatibility"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "nx = 32\n",
+ "ny = 32\n",
+ "nz = 1\n",
+ "dtype = np.float64"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Storages\n",
+ "--\n",
+ "\n",
+ "We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use \"I\", \"J\", \"K\" as the dimension names."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import gt4py.next as gtx\n",
+ "\n",
+ "allocator = gtx.itir_embedded # should match the executor\n",
+ "# allocator = gtx.gtfn_cpu\n",
+ "# allocator = gtx.gtfn_gpu\n",
+ "\n",
+ "# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be \"I\", \"J\", \"K\"\n",
+ "I = gtx.Dimension(\"I\")\n",
+ "J = gtx.Dimension(\"J\")\n",
+ "K = gtx.Dimension(\"K\", kind=gtx.DimensionKind.VERTICAL)\n",
+ "\n",
+ "domain = gtx.domain({I: nx, J: ny, K: nz})\n",
+ "\n",
+ "inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n",
+ "out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n",
+ "out_next = gtx.zeros(domain, dtype, allocator=allocator)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "gt4py.cartesian\n",
+ "--"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import gt4py.cartesian.gtscript as gtscript\n",
+ "\n",
+ "cartesian_backend = \"numpy\"\n",
+ "# cartesian_backend = \"gt:cpu_ifirst\"\n",
+ "# cartesian_backend = \"gt:gpu\"\n",
+ "\n",
+ "@gtscript.stencil(backend=cartesian_backend)\n",
+ "def lap_cartesian(\n",
+ " inp: gtscript.Field[dtype],\n",
+ " out: gtscript.Field[dtype],\n",
+ "):\n",
+ " with computation(PARALLEL), interval(...):\n",
+ " out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n",
+ "\n",
+ "lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from gt4py.next import Field\n",
+ "\n",
+ "next_backend = gtx.itir_embedded\n",
+ "# next_backend = gtx.gtfn_cpu\n",
+ "# next_backend = gtx.gtfn_gpu\n",
+ "\n",
+ "Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n",
+ "Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n",
+ "\n",
+ "@gtx.field_operator\n",
+ "def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n",
+ " return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n",
+ "\n",
+ "@gtx.program(backend=next_backend)\n",
+ "def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n",
+ " lap_next(inp, out=out[1:-1, 1:-1, :])\n",
+ "\n",
+ "lap_next_program(inp, out_next, offset_provider={\"Ioff\": I, \"Joff\": J})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt
index 17709206a0..fd7724bac9 100644
--- a/min-extra-requirements-test.txt
+++ b/min-extra-requirements-test.txt
@@ -25,7 +25,7 @@ cmake==3.22
cogapp==3.3
coverage[toml]==5.0
cytoolz==0.12.0
-dace==0.14.2
+dace==0.15.1
darglint==1.6
deepdiff==5.6.0
devtools==0.6
@@ -70,7 +70,7 @@ scipy==1.7.2
setuptools==65.5.0
sphinx==4.4
sphinx_rtd_theme==1.0
-sympy==1.7
+sympy==1.9
tabulate==0.8.10
tomli==2.0.1
tox==3.2.0
diff --git a/pyproject.toml b/pyproject.toml
index 5d7a2f2cb6..675bdae9d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,15 +69,15 @@ requires-python = '>=3.8'
cuda = ['cupy>=12.0']
cuda11x = ['cupy-cuda11x>=12.0']
cuda12x = ['cupy-cuda12x>=12.0']
-dace = ['dace>=0.14.2,<0.15', 'sympy>=1.7']
+dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9']
formatting = ['clang-format>=9.0']
# Always add all extra packages to 'full' for a simple full gt4py installation
full = [
'clang-format>=9.0',
- 'dace>=0.14.2,<0.15',
+ 'dace>=0.15.1,<0.16',
'hypothesis>=6.0.0',
'pytest>=7.0',
- 'sympy>=1.7',
+ 'sympy>=1.9',
'scipy>=1.7.2',
'jax[cpu]>=0.4.13'
]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d6dcc12d21..0fa523866f 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,124 +6,136 @@
#
aenum==3.1.15 # via dace
alabaster==0.7.13 # via sphinx
-asttokens==2.4.0 # via devtools
+asttokens==2.4.1 # via devtools
astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml)
attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing
-babel==2.12.1 # via sphinx
-black==23.9.1 # via gt4py (pyproject.toml)
-blinker==1.6.2 # via flask
-boltons==23.0.0 # via gt4py (pyproject.toml)
+babel==2.13.1 # via sphinx
+black==23.11.0 # via gt4py (pyproject.toml)
+blinker==1.7.0 # via flask
+boltons==23.1.1 # via gt4py (pyproject.toml)
build==1.0.3 # via pip-tools
cached-property==1.5.2 # via gt4py (pyproject.toml)
-cachetools==5.3.1 # via tox
-certifi==2023.7.22 # via requests
-cffi==1.15.1 # via cryptography
+cachetools==5.3.2 # via tox
+cerberus==1.3.5 # via plette
+certifi==2023.11.17 # via requests
+cffi==1.16.0 # via cryptography
cfgv==3.4.0 # via pre-commit
chardet==5.2.0 # via tox
-charset-normalizer==3.2.0 # via requests
-clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
+charset-normalizer==3.3.2 # via requests
+clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools
-cmake==3.27.5 # via gt4py (pyproject.toml)
+cmake==3.27.9 # via dace, gt4py (pyproject.toml)
cogapp==3.3.0 # via -r requirements-dev.in
colorama==0.4.6 # via tox
-coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov
-cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis
+coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov
+cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis
cytoolz==0.12.2 # via gt4py (pyproject.toml)
-dace==0.14.4 # via gt4py (pyproject.toml)
+dace==0.15.1 # via gt4py (pyproject.toml)
darglint==1.8.1 # via -r requirements-dev.in
-deepdiff==6.5.0 # via gt4py (pyproject.toml)
+deepdiff==6.7.1 # via gt4py (pyproject.toml)
devtools==0.12.2 # via gt4py (pyproject.toml)
dill==0.3.7 # via dace
-distlib==0.3.7 # via virtualenv
-docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
+distlib==0.3.7 # via requirementslib, virtualenv
+distro==1.8.0 # via scikit-build
+docopt==0.6.2 # via pipreqs
+docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
eradicate==2.3.0 # via flake8-eradicate
-exceptiongroup==1.1.3 # via hypothesis, pytest
+exceptiongroup==1.2.0 # via hypothesis, pytest
execnet==2.0.2 # via pytest-cache, pytest-xdist
-executing==1.2.0 # via devtools
+executing==2.0.1 # via devtools
factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy
-faker==19.6.1 # via factory-boy
-fastjsonschema==2.18.0 # via nbformat
-filelock==3.12.4 # via tox, virtualenv
+faker==20.1.0 # via factory-boy
+fastjsonschema==2.19.0 # via nbformat
+filelock==3.13.1 # via tox, virtualenv
flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings
-flake8-bugbear==23.9.16 # via -r requirements-dev.in
-flake8-builtins==2.1.0 # via -r requirements-dev.in
+flake8-bugbear==23.12.2 # via -r requirements-dev.in
+flake8-builtins==2.2.0 # via -r requirements-dev.in
flake8-debugger==4.1.2 # via -r requirements-dev.in
flake8-docstrings==1.7.0 # via -r requirements-dev.in
flake8-eradicate==1.5.0 # via -r requirements-dev.in
flake8-mutable==1.2.0 # via -r requirements-dev.in
flake8-pyproject==1.2.3 # via -r requirements-dev.in
flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in
-flask==2.3.3 # via dace
-frozendict==2.3.8 # via gt4py (pyproject.toml)
+flask==3.0.0 # via dace
+fparser==0.1.3 # via dace
+frozendict==2.3.10 # via gt4py (pyproject.toml)
gridtools-cpp==2.3.1 # via gt4py (pyproject.toml)
-hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml)
-identify==2.5.29 # via pre-commit
-idna==3.4 # via requests
+hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml)
+identify==2.5.33 # via pre-commit
+idna==3.6 # via requests
imagesize==1.4.1 # via sphinx
-importlib-metadata==6.8.0 # via build, flask, sphinx
-importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
+importlib-metadata==7.0.0 # via build, flask, fparser, sphinx
+importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
inflection==0.5.1 # via pytest-factoryboy
iniconfig==2.0.0 # via pytest
-isort==5.12.0 # via -r requirements-dev.in
+isort==5.13.0 # via -r requirements-dev.in
itsdangerous==2.1.2 # via flask
jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx
-jsonschema==4.19.0 # via nbformat
-jsonschema-specifications==2023.7.1 # via jsonschema
-jupyter-core==5.3.1 # via nbformat
-jupytext==1.15.2 # via -r requirements-dev.in
-lark==1.1.7 # via gt4py (pyproject.toml)
-mako==1.2.4 # via gt4py (pyproject.toml)
+jsonschema==4.20.0 # via nbformat
+jsonschema-specifications==2023.11.2 # via jsonschema
+jupyter-core==5.5.0 # via nbformat
+jupytext==1.16.0 # via -r requirements-dev.in
+lark==1.1.8 # via gt4py (pyproject.toml)
+mako==1.3.0 # via gt4py (pyproject.toml)
markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins
markupsafe==2.1.3 # via jinja2, mako, werkzeug
mccabe==0.7.0 # via flake8
mdit-py-plugins==0.4.0 # via jupytext
mdurl==0.1.2 # via markdown-it-py
mpmath==1.3.0 # via sympy
-mypy==1.5.1 # via -r requirements-dev.in
+mypy==1.7.1 # via -r requirements-dev.in
mypy-extensions==1.0.0 # via black, mypy
-nanobind==1.5.2 # via gt4py (pyproject.toml)
+nanobind==1.8.0 # via gt4py (pyproject.toml)
nbformat==5.9.2 # via jupytext
networkx==3.1 # via dace
-ninja==1.11.1 # via gt4py (pyproject.toml)
+ninja==1.11.1.1 # via gt4py (pyproject.toml)
nodeenv==1.8.0 # via pre-commit
numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client
ordered-set==4.1.0 # via deepdiff
-packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox
-pathspec==0.11.2 # via black
+packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox
+pathspec==0.12.1 # via black
+pep517==0.13.1 # via requirementslib
+pip-api==0.0.30 # via isort
pip-tools==7.3.0 # via -r requirements-dev.in
-pipdeptree==2.13.0 # via -r requirements-dev.in
+pipdeptree==2.13.1 # via -r requirements-dev.in
+pipreqs==0.4.13 # via isort
pkgutil-resolve-name==1.3.10 # via jsonschema
-platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv
+platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv
+plette[validation]==0.4.4 # via requirementslib
pluggy==1.3.0 # via pytest, tox
ply==3.11 # via dace
-pre-commit==3.4.0 # via -r requirements-dev.in
-psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist
+pre-commit==3.5.0 # via -r requirements-dev.in
+psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist
pybind11==2.11.1 # via gt4py (pyproject.toml)
-pycodestyle==2.11.0 # via flake8, flake8-debugger
+pycodestyle==2.11.1 # via flake8, flake8-debugger
pycparser==2.21 # via cffi
+pydantic==1.10.13 # via requirementslib
pydocstyle==6.3.0 # via flake8-docstrings
pyflakes==3.1.0 # via flake8
-pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
+pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
pyproject-api==1.6.1 # via tox
pyproject-hooks==1.0.0 # via build
-pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
+pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
pytest-cache==1.0 # via -r requirements-dev.in
pytest-cov==4.1.0 # via -r requirements-dev.in
-pytest-factoryboy==2.5.1 # via -r requirements-dev.in
-pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in
+pytest-factoryboy==2.6.0 # via -r requirements-dev.in
+pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in
python-dateutil==2.8.2 # via faker
pytz==2023.3.post1 # via babel
pyyaml==6.0.1 # via dace, jupytext, pre-commit
-referencing==0.30.2 # via jsonschema, jsonschema-specifications
-requests==2.31.0 # via dace, sphinx
+referencing==0.32.0 # via jsonschema, jsonschema-specifications
+requests==2.31.0 # via dace, requirementslib, sphinx, yarg
+requirementslib==3.0.0 # via isort
restructuredtext-lint==1.4.0 # via flake8-rst-docstrings
-rpds-py==0.10.3 # via jsonschema, referencing
-ruff==0.0.290 # via -r requirements-dev.in
+rpds-py==0.13.2 # via jsonschema, referencing
+ruff==0.1.7 # via -r requirements-dev.in
+scikit-build==0.17.6 # via dace
+setuptools-scm==8.0.4 # via fparser
six==1.16.0 # via asttokens, astunparse, python-dateutil
snowballstemmer==2.2.0 # via pydocstyle, sphinx
sortedcontainers==2.4.0 # via hypothesis
sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery
-sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in
+sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in
sphinxcontrib-applehelp==1.0.4 # via sphinx
sphinxcontrib-devhelp==1.0.2 # via sphinx
sphinxcontrib-htmlhelp==2.0.1 # via sphinx
@@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme
sphinxcontrib-jsmath==1.0.1 # via sphinx
sphinxcontrib-qthelp==1.0.3 # via sphinx
sphinxcontrib-serializinghtml==1.1.5 # via sphinx
-sympy==1.12 # via dace, gt4py (pyproject.toml)
+sympy==1.9 # via dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via gt4py (pyproject.toml)
toml==0.10.2 # via jupytext
-tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox
+tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox
+tomlkit==0.12.3 # via plette, requirementslib
toolz==0.12.0 # via cytoolz
-tox==4.11.3 # via -r requirements-dev.in
-traitlets==5.10.0 # via jupyter-core, nbformat
+tox==4.11.4 # via -r requirements-dev.in
+traitlets==5.14.0 # via jupyter-core, nbformat
types-aiofiles==23.2.0.0 # via types-all
types-all==1.0.0 # via -r requirements-dev.in
types-annoy==1.17.8.4 # via types-all
types-atomicwrites==1.4.5.1 # via types-all
types-backports==0.1.3 # via types-all
types-backports-abc==0.5.2 # via types-all
-types-bleach==6.0.0.4 # via types-all
+types-bleach==6.1.0.1 # via types-all
types-boto==2.49.18.9 # via types-all
-types-cachetools==5.3.0.6 # via types-all
+types-cachetools==5.3.0.7 # via types-all
types-certifi==2021.10.8.3 # via types-all
-types-cffi==1.15.1.15 # via types-jack-client
+types-cffi==1.16.0.0 # via types-jack-client
types-characteristic==14.3.7 # via types-all
types-chardet==5.0.4.6 # via types-all
types-click==7.1.8 # via types-all, types-flask
-types-click-spinner==0.1.13.5 # via types-all
+types-click-spinner==0.1.13.6 # via types-all
types-colorama==0.4.15.12 # via types-all
types-contextvars==2.4.7.3 # via types-all
-types-croniter==1.4.0.1 # via types-all
+types-croniter==2.0.0.0 # via types-all
types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt
types-dataclasses==0.6.6 # via types-all
types-dateparser==1.1.4.10 # via types-all
@@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all
types-geoip2==3.0.0 # via types-all
types-ipaddress==1.0.8 # via types-all, types-maxminddb
types-itsdangerous==1.1.6 # via types-all
-types-jack-client==0.5.10.9 # via types-all
+types-jack-client==0.5.10.10 # via types-all
types-jinja2==2.11.9 # via types-all, types-flask
types-kazoo==0.1.3 # via types-all
-types-markdown==3.4.2.10 # via types-all
+types-markdown==3.5.0.3 # via types-all
types-markupsafe==1.1.10 # via types-all, types-jinja2
types-maxminddb==1.5.0 # via types-all, types-geoip2
-types-mock==5.1.0.2 # via types-all
+types-mock==5.1.0.3 # via types-all
types-mypy-extensions==1.0.0.5 # via types-all
types-nmap==0.1.6 # via types-all
types-openssl-python==0.1.3 # via types-all
types-orjson==3.6.2 # via types-all
-types-paramiko==3.3.0.0 # via types-all, types-pysftp
+types-paramiko==3.3.0.2 # via types-all, types-pysftp
types-pathlib2==2.3.0 # via types-all
-types-pillow==10.0.0.3 # via types-all
+types-pillow==10.1.0.2 # via types-all
types-pkg-resources==0.1.3 # via types-all
types-polib==1.2.0.1 # via types-all
-types-protobuf==4.24.0.1 # via types-all
+types-protobuf==4.24.0.4 # via types-all
types-pyaudio==0.2.16.7 # via types-all
types-pycurl==7.45.2.5 # via types-all
types-pyfarmhash==0.3.1.2 # via types-all
types-pyjwt==1.7.1 # via types-all
types-pymssql==2.1.0 # via types-all
types-pymysql==1.1.0.1 # via types-all
-types-pyopenssl==23.2.0.2 # via types-redis
+types-pyopenssl==23.3.0.0 # via types-redis
types-pyrfc3339==1.1.1.5 # via types-all
types-pysftp==0.2.17.6 # via types-all
types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange
types-python-gflags==3.1.7.3 # via types-all
types-python-slugify==8.0.0.3 # via types-all
-types-pytz==2023.3.1.0 # via types-all, types-tzlocal
+types-pytz==2023.3.1.1 # via types-all, types-tzlocal
types-pyvmomi==8.0.0.6 # via types-all
-types-pyyaml==6.0.12.11 # via types-all
-types-redis==4.6.0.6 # via types-all
-types-requests==2.31.0.2 # via types-all
+types-pyyaml==6.0.12.12 # via types-all
+types-redis==4.6.0.11 # via types-all
+types-requests==2.31.0.10 # via types-all
types-retry==0.9.9.4 # via types-all
types-routes==2.5.0 # via types-all
types-scribe==2.0.0 # via types-all
-types-setuptools==68.2.0.0 # via types-cffi
+types-setuptools==69.0.0.0 # via types-cffi
types-simplejson==3.19.0.2 # via types-all
types-singledispatch==4.1.0.0 # via types-all
types-six==1.16.21.9 # via types-all
@@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all
types-toml==0.10.8.7 # via types-all
types-tornado==5.1.1 # via types-all
types-typed-ast==1.5.8.7 # via types-all
-types-tzlocal==5.0.1.1 # via types-all
+types-tzlocal==5.1.0.1 # via types-all
types-ujson==5.8.0.1 # via types-all
-types-urllib3==1.26.25.14 # via types-requests
types-waitress==2.1.4.9 # via types-all
types-werkzeug==1.0.9 # via types-all, types-flask
types-xxhash==3.0.5.2 # via types-all
-typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy
-urllib3==2.0.4 # via requests
-virtualenv==20.24.5 # via pre-commit, tox
-websockets==11.0.3 # via dace
-werkzeug==2.3.7 # via flask
-wheel==0.41.2 # via astunparse, pip-tools
+typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm
+urllib3==2.1.0 # via requests, types-requests
+virtualenv==20.25.0 # via pre-commit, tox
+websockets==12.0 # via dace
+werkzeug==3.0.1 # via flask
+wheel==0.42.0 # via astunparse, pip-tools, scikit-build
xxhash==3.0.0 # via gt4py (pyproject.toml)
-zipp==3.16.2 # via importlib-metadata, importlib-resources
+yarg==0.1.9 # via pipreqs
+zipp==3.17.0 # via importlib-metadata, importlib-resources
# The following packages are considered to be unsafe in a requirements file:
-pip==23.2.1 # via pip-tools
-setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools
+pip==23.3.1 # via pip-api, pip-tools, requirementslib
+setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm
diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py
index 7d255de142..c28c5cf2d6 100644
--- a/src/gt4py/__init__.py
+++ b/src/gt4py/__init__.py
@@ -33,6 +33,6 @@
if _sys.version_info >= (3, 10):
- from . import next
+ from . import next # noqa: A004
__all__ += ["next"]
diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py
index b1e559a41e..5dae025acb 100644
--- a/src/gt4py/cartesian/backend/dace_backend.py
+++ b/src/gt4py/cartesian/backend/dace_backend.py
@@ -562,12 +562,6 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S
omp_threads = ""
omp_header = ""
- # Backward compatible state struct name change in DaCe >=0.15.x
- try:
- dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix")
- except (KeyError, TypeError):
- dace_state_suffix = "_t" # old structure name
-
interface = cls.template.definition.render(
name=sdfg.name,
backend_specifics=omp_threads,
@@ -575,7 +569,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S
functor_args=self.generate_functor_args(sdfg),
tmp_allocs=self.generate_tmp_allocs(sdfg),
allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique",
- state_suffix=dace_state_suffix,
+ state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"),
)
generated_code = textwrap.dedent(
f"""#include
diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
index db276a48b9..48b129fa87 100644
--- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
+++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
@@ -30,6 +30,7 @@
compute_dcir_access_infos,
flatten_list,
get_tasklet_symbol,
+ make_dace_subset,
union_inout_memlets,
union_node_grid_subsets,
untile_memlets,
@@ -458,6 +459,40 @@ def visit_HorizontalExecution(
write_memlets=write_memlets,
)
+ for memlet in [*read_memlets, *write_memlets]:
+ """
+ This loop handles the special case of a tasklet performing array access.
+ The memlet should pass the full array shape (no tiling) and
+ the tasklet expression for array access should use all explicit indexes.
+ """
+ array_ndims = len(global_ctx.arrays[memlet.field].shape)
+ field_decl = global_ctx.library_node.field_decls[memlet.field]
+ # calculate array subset on original memlet
+ memlet_subset = make_dace_subset(
+ global_ctx.library_node.access_infos[memlet.field],
+ memlet.access_info,
+ field_decl.data_dims,
+ )
+ # select index values for single-point grid access
+ memlet_data_index = [
+ dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32)
+ for dim_range, dim_size in zip(memlet_subset, memlet_subset.size())
+ if dim_size == 1
+ ]
+ if len(memlet_data_index) < array_ndims:
+ reshape_memlet = False
+ for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess):
+ if access_node.data_index and access_node.name == memlet.connector:
+ access_node.data_index = memlet_data_index + access_node.data_index
+ assert len(access_node.data_index) == array_ndims
+ reshape_memlet = True
+ if reshape_memlet:
+ # ensure that memlet symbols used for array indexing are defined in context
+ for sym in memlet.access_info.grid_subset.free_symbols:
+ symbol_collector.add_symbol(sym)
+ # set full shape on memlet
+ memlet.access_info = global_ctx.library_node.access_infos[memlet.field]
+
for item in reversed(expansion_items):
iteration_ctx = iteration_ctx.pop()
dcir_node = self._process_iteration_item(
diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py
index ddcb719b5f..bd8c08034c 100644
--- a/src/gt4py/cartesian/gtc/dace/nodes.py
+++ b/src/gt4py/cartesian/gtc/dace/nodes.py
@@ -121,7 +121,7 @@ def __init__(
*args,
**kwargs,
):
- super().__init__(name=name, *args, **kwargs)
+ super().__init__(*args, name=name, **kwargs)
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos
diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py
index 28ebc8cd8e..0366317360 100644
--- a/src/gt4py/cartesian/gtc/daceir.py
+++ b/src/gt4py/cartesian/gtc/daceir.py
@@ -536,7 +536,7 @@ def union(self, other):
else:
assert (
isinstance(interval2, (TileInterval, DomainInterval))
- and isinstance(interval1, IndexWithExtent)
+ and isinstance(interval1, (IndexWithExtent, DomainInterval))
) or (
isinstance(interval1, (TileInterval, DomainInterval))
and isinstance(interval2, IndexWithExtent)
@@ -573,7 +573,7 @@ def overapproximated_shape(self):
def apply_iteration(self, grid_subset: GridSubset):
res_intervals = dict(self.grid_subset.intervals)
for axis, field_interval in self.grid_subset.intervals.items():
- if axis in grid_subset.intervals:
+ if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval):
grid_interval = grid_subset.intervals[axis]
assert isinstance(field_interval, IndexWithExtent)
extent = field_interval.extent
diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py
index 8f1b9e6554..bc744b3ccc 100644
--- a/src/gt4py/eve/datamodels/core.py
+++ b/src/gt4py/eve/datamodels/core.py
@@ -814,7 +814,7 @@ def concretize(
""" # noqa: RST301 # doctest conventions confuse RST validator
concrete_cls: Type[DataModelT] = _make_concrete_with_cache(
- datamodel_cls, *type_args, class_name=class_name, module=module
+ datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type]
)
assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls)
diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py
index cd7e71588f..74c5bd41bb 100644
--- a/src/gt4py/eve/trees.py
+++ b/src/gt4py/eve/trees.py
@@ -133,7 +133,7 @@ def _pre_walk_items(
yield from _pre_walk_items(child, __key__=key)
-def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
+def _pre_walk_values(node: TreeLike) -> Iterable:
"""Create a pre-order tree traversal iterator of values."""
yield node
for child in iter_children_values(node):
@@ -153,7 +153,7 @@ def _post_walk_items(
yield __key__, node
-def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
+def _post_walk_values(node: TreeLike) -> Iterable:
"""Create a post-order tree traversal iterator of values."""
if (iter_children_values := getattr(node, "iter_children_values", None)) is not None:
for child in iter_children_values():
diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py
index 7104f7658f..624407f319 100644
--- a/src/gt4py/eve/utils.py
+++ b/src/gt4py/eve/utils.py
@@ -1225,7 +1225,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]:
[('a', 'b', 'c'), (1, 2, 3)]
"""
- return XIterable(zip(*self.iterator)) # type: ignore # mypy gets confused with *args
+ return XIterable(zip(*self.iterator))
@typing.overload
def islice(self, __stop: int) -> XIterable[T]:
@@ -1536,7 +1536,7 @@ def reduceby(
) -> Dict[K, S]:
...
- def reduceby( # type: ignore[misc] # signatures 2 and 4 are not satified due to inconsistencies with type variables
+ def reduceby(
self,
bin_op_func: Callable[[S, T], S],
key: Union[str, List[K], Callable[[T], K]],
diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py
index cbd5735949..1398af5f03 100644
--- a/src/gt4py/next/__init__.py
+++ b/src/gt4py/next/__init__.py
@@ -39,6 +39,11 @@
index_field,
np_as_located_field,
)
+from .program_processors.runners.gtfn import (
+ run_gtfn_cached as gtfn_cpu,
+ run_gtfn_gpu_cached as gtfn_gpu,
+)
+from .program_processors.runners.roundtrip import backend as itir_python
__all__ = [
@@ -74,5 +79,9 @@
"field_operator",
"program",
"scan_operator",
+ # from program_processor
+ "gtfn_cpu",
+ "gtfn_gpu",
+ "itir_python",
*fbuiltins.__all__,
]
diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py
index 29d606ccc0..949f4b461a 100644
--- a/src/gt4py/next/common.py
+++ b/src/gt4py/next/common.py
@@ -189,11 +189,12 @@ def __and__(self, other: UnitRange) -> UnitRange:
return UnitRange(max(self.start, other.start), min(self.stop, other.stop))
def __contains__(self, value: Any) -> bool:
- return (
- isinstance(value, core_defs.INTEGRAL_TYPES)
- and value >= self.start
- and value < self.stop
- )
+ # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604)
+ # and remove int cast, once the related mypy bug (#16358) gets fixed
+ if isinstance(value, core_defs.INTEGRAL_TYPES):
+ return self.start <= cast(int, value) < self.stop
+ else:
+ return False
def __le__(self, other: UnitRange) -> bool:
return self.start >= other.start and self.stop <= other.stop
@@ -574,38 +575,39 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _
...
-# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol
-class NextGTDimsInterface(Protocol):
+# TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol.
+class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol):
"""
- Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions.
+ Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.
- The dimension names are objects of type :class:`Dimension`, in contrast to
- :mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics,
- see :class:`~gt4py._core.definitions.GTDimsInterface` .
+ Note:
+ - A default implementation of the `__gt_dims__` interface from `gt4py.cartesian` is provided.
+ - No implementation of `__gt_origin__` is provided because of infinite fields.
"""
@property
- def __gt_dims__(self) -> tuple[Dimension, ...]:
+ def __gt_domain__(self) -> Domain:
+ # TODO probably should be changed to `DomainLike` (with a new concept `DimensionLike`)
+ # to allow implementations without having to import gtx.Domain.
...
-
-# TODO(egparedes): add support for this new protocol in the cartesian module
-class GTFieldInterface(Protocol):
- """Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`."""
-
@property
- def __gt_domain__(self) -> Domain:
- ...
+ def __gt_dims__(self) -> tuple[str, ...]:
+ return tuple(d.value for d in self.__gt_domain__.dims)
@extended_runtime_checkable
-class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]):
+class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]):
__gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher]
@property
def domain(self) -> Domain:
...
+ @property
+ def __gt_domain__(self) -> Domain:
+ return self.domain
+
@property
def codomain(self) -> type[core_defs.ScalarT] | Dimension:
...
@@ -923,10 +925,6 @@ def asnumpy(self) -> Never:
def domain(self) -> Domain:
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))
- @property
- def __gt_dims__(self) -> tuple[Dimension, ...]:
- return self.domain.dims
-
@property
def __gt_origin__(self) -> Never:
raise TypeError("'CartesianConnectivity' does not support this operation.")
diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py
index 8bd2673db9..9fc1b42038 100644
--- a/src/gt4py/next/embedded/nd_array_field.py
+++ b/src/gt4py/next/embedded/nd_array_field.py
@@ -107,10 +107,6 @@ def domain(self) -> common.Domain:
def shape(self) -> tuple[int, ...]:
return self._ndarray.shape
- @property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return self._domain.dims
-
@property
def __gt_origin__(self) -> tuple[int, ...]:
assert common.Domain.is_finite(self._domain)
diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py
index f50ace7687..0992401ebb 100644
--- a/src/gt4py/next/embedded/operators.py
+++ b/src/gt4py/next/embedded/operators.py
@@ -17,7 +17,7 @@
from gt4py import eve
from gt4py._core import definitions as core_defs
-from gt4py.next import common, constructors, utils
+from gt4py.next import common, constructors, errors, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context
@@ -77,17 +77,20 @@ def scan_loop(hpos):
def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
if "out" in kwargs:
# called from program or direct field_operator as program
- offset_provider = kwargs.pop("offset_provider", None)
-
new_context_kwargs = {}
if embedded_context.within_context():
# called from program
- assert offset_provider is None
+ assert "offset_provider" not in kwargs
else:
# field_operator as program
+ if "offset_provider" not in kwargs:
+ raise errors.MissingArgumentError(None, "offset_provider", True)
+ offset_provider = kwargs.pop("offset_provider", None)
+
new_context_kwargs["offset_provider"] = offset_provider
out = kwargs.pop("out")
+
domain = kwargs.pop("domain", None)
flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,))
@@ -105,7 +108,10 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
domain=out_domain,
)
else:
- # called from other field_operator
+ # called from other field_operator or missing `out` argument
+ if "offset_provider" in kwargs:
+ # assuming we wanted to call the field_operator as program, otherwise `offset_provider` would not be there
+ raise errors.MissingArgumentError(None, "out", True)
return op(*args, **kwargs)
diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py
index 61441e83b9..dd48d6f0f9 100644
--- a/src/gt4py/next/errors/__init__.py
+++ b/src/gt4py/next/errors/__init__.py
@@ -21,6 +21,7 @@
from .exceptions import (
DSLError,
InvalidParameterAnnotationError,
+ MissingArgumentError,
MissingAttributeError,
MissingParameterAnnotationError,
UndefinedSymbolError,
@@ -33,6 +34,7 @@
"InvalidParameterAnnotationError",
"MissingAttributeError",
"MissingParameterAnnotationError",
+ "MissingArgumentError",
"UndefinedSymbolError",
"UnsupportedPythonFeatureError",
"set_verbose_exceptions",
diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py
index 081453c023..858f969447 100644
--- a/src/gt4py/next/errors/exceptions.py
+++ b/src/gt4py/next/errors/exceptions.py
@@ -81,6 +81,18 @@ def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None:
self.attr_name = attr_name
+class MissingArgumentError(DSLError):
+ arg_name: str
+ is_kwarg: bool
+
+ def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: bool) -> None:
+ super().__init__(
+ location, f"Expected {'keyword-' if is_kwarg else ''}argument '{arg_name}'."
+ )
+ self.attr_name = arg_name
+ self.is_kwarg = is_kwarg
+
+
class TypeError_(DSLError):
def __init__(self, location: Optional[SourceLocation], message: str) -> None:
super().__init__(location, message)
diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py
index 53159008f0..05cbe1c882 100644
--- a/src/gt4py/next/ffront/decorator.py
+++ b/src/gt4py/next/ffront/decorator.py
@@ -29,10 +29,11 @@
from devtools import debug
+from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
-from gt4py.next import allocators as next_allocators, embedded as next_embedded
+from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
@@ -61,11 +62,10 @@
sym,
)
from gt4py.next.program_processors import processor_interface as ppi
-from gt4py.next.program_processors.runners import roundtrip
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
-DEFAULT_BACKEND: Callable = roundtrip.executor
+DEFAULT_BACKEND: Callable = None
def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]:
@@ -176,15 +176,15 @@ class Program:
past_node: past.Program
closure_vars: dict[str, Any]
- definition: Optional[types.FunctionType] = None
- backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
- grid_type: Optional[GridType] = None
+ definition: Optional[types.FunctionType]
+ backend: Optional[ppi.ProgramExecutor]
+ grid_type: Optional[GridType]
@classmethod
def from_function(
cls,
definition: types.FunctionType,
- backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
+ backend: Optional[ppi.ProgramExecutor],
grid_type: Optional[GridType] = None,
) -> Program:
source_def = SourceDefinition.from_function(definition)
@@ -453,27 +453,32 @@ def _process_args(self, args: tuple, kwargs: dict):
) from err
full_args = [*args]
+ full_kwargs = {**kwargs}
for index, param in enumerate(self.past_node.params):
if param.id in self.bound_args.keys():
- full_args.insert(index, self.bound_args[param.id])
+ if index < len(full_args):
+ full_args.insert(index, self.bound_args[param.id])
+ else:
+ full_kwargs[str(param.id)] = self.bound_args[param.id]
- return super()._process_args(tuple(full_args), kwargs)
+ return super()._process_args(tuple(full_args), full_kwargs)
@functools.cached_property
def itir(self):
new_itir = super().itir
for new_clos in new_itir.closures:
- for key in self.bound_args.keys():
+ new_args = [ref(inp.id) for inp in new_clos.inputs]
+ for key, value in self.bound_args.items():
index = next(
index
for index, closure_input in enumerate(new_clos.inputs)
if closure_input.id == key
)
+ new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator(
+ literal_from_value(value)
+ )
new_clos.inputs.pop(index)
- new_args = [ref(inp.id) for inp in new_clos.inputs]
params = [sym(inp.id) for inp in new_clos.inputs]
- for value in self.bound_args.values():
- new_args.append(promote_to_const_iterator(literal_from_value(value)))
expr = itir.FunCall(
fun=new_clos.stencil,
args=new_args,
@@ -495,7 +500,7 @@ def program(*, backend: Optional[ppi.ProgramExecutor]) -> Callable[[types.Functi
def program(
definition=None,
*,
- backend=None,
+ backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution)
grid_type=None,
) -> Program | Callable[[types.FunctionType], Program]:
"""
@@ -517,7 +522,9 @@ def program(
"""
def program_inner(definition: types.FunctionType) -> Program:
- return Program.from_function(definition, backend, grid_type)
+ return Program.from_function(
+ definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type
+ )
return program_inner if definition is None else program_inner(definition)
@@ -549,9 +556,9 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
foast_node: OperatorNodeT
closure_vars: dict[str, Any]
- definition: Optional[types.FunctionType] = None
- backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
- grid_type: Optional[GridType] = None
+ definition: Optional[types.FunctionType]
+ backend: Optional[ppi.ProgramExecutor]
+ grid_type: Optional[GridType]
operator_attributes: Optional[dict[str, Any]] = None
_program_cache: dict = dataclasses.field(default_factory=dict)
@@ -559,7 +566,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
def from_function(
cls,
definition: types.FunctionType,
- backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
+ backend: Optional[ppi.ProgramExecutor],
grid_type: Optional[GridType] = None,
*,
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator,
@@ -686,6 +693,7 @@ def as_program(
self._program_cache[hash_] = Program(
past_node=past_node,
closure_vars=closure_vars,
+ definition=None,
backend=self.backend,
grid_type=self.grid_type,
)
@@ -698,7 +706,12 @@ def __call__(
) -> None:
if not next_embedded.context.within_context() and self.backend is not None:
# non embedded execution
- offset_provider = kwargs.pop("offset_provider", None)
+ if "offset_provider" not in kwargs:
+ raise errors.MissingArgumentError(None, "offset_provider", True)
+ offset_provider = kwargs.pop("offset_provider")
+
+ if "out" not in kwargs:
+ raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
# TODO(tehrengruber): check all offset providers are given
@@ -744,7 +757,7 @@ def field_operator(
...
-def field_operator(definition=None, *, backend=None, grid_type=None):
+def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None):
"""
Generate an implementation of the field operator from a Python function object.
@@ -762,7 +775,9 @@ def field_operator(definition=None, *, backend=None, grid_type=None):
"""
def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]:
- return FieldOperator.from_function(definition, backend, grid_type)
+ return FieldOperator.from_function(
+ definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type
+ )
return field_operator_inner if definition is None else field_operator_inner(definition)
@@ -798,7 +813,7 @@ def scan_operator(
axis: Dimension,
forward: bool = True,
init: core_defs.Scalar = 0.0,
- backend=None,
+ backend=eve.NOTHING,
grid_type: GridType = None,
) -> (
FieldOperator[foast.ScanOperator]
@@ -836,7 +851,7 @@ def scan_operator(
def scan_operator_inner(definition: types.FunctionType) -> FieldOperator:
return FieldOperator.from_function(
definition,
- backend,
+ DEFAULT_BACKEND if backend is eve.NOTHING else backend,
grid_type,
operator_node_cls=foast.ScanOperator,
operator_attributes={"axis": axis, "forward": forward, "init": init},
diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py
index 278dde9180..cd75538da7 100644
--- a/src/gt4py/next/ffront/fbuiltins.py
+++ b/src/gt4py/next/ffront/fbuiltins.py
@@ -15,7 +15,7 @@
import dataclasses
import functools
import inspect
-from builtins import bool, float, int, tuple
+from builtins import bool, float, int, tuple # noqa: A004
from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast
import numpy as np
diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py
index ef70a2e645..390bec4312 100644
--- a/src/gt4py/next/iterator/embedded.py
+++ b/src/gt4py/next/iterator/embedded.py
@@ -172,7 +172,7 @@ class LocatedField(Protocol):
@property
@abc.abstractmethod
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
+ def __gt_domain__(self) -> common.Domain:
...
# TODO(havogt): define generic Protocol to provide a concrete return type
@@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any:
@property
def __gt_origin__(self) -> tuple[int, ...]:
- return tuple([0] * len(self.__gt_dims__))
+ return tuple([0] * len(self.__gt_domain__.dims))
@runtime_checkable
@@ -675,12 +675,18 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]:
def _get_axes(
field_or_tuple: LocatedField | tuple,
) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField
+ return _get_domain(field_or_tuple).dims
+
+
+def _get_domain(
+ field_or_tuple: LocatedField | tuple,
+) -> common.Domain: # arbitrary nesting of tuples of LocatedField
if isinstance(field_or_tuple, tuple):
- first = _get_axes(field_or_tuple[0])
- assert all(first == _get_axes(f) for f in field_or_tuple)
+ first = _get_domain(field_or_tuple[0])
+ assert all(first == _get_domain(f) for f in field_or_tuple)
return first
else:
- return field_or_tuple.__gt_dims__
+ return field_or_tuple.__gt_domain__
def _single_vertical_idx(
@@ -894,14 +900,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField):
_ndarrayfield: common.Field
@property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return self._ndarrayfield.__gt_dims__
+ def __gt_domain__(self) -> common.Domain:
+ return self._ndarrayfield.__gt_domain__
def _translate_named_indices(
self, _named_indices: NamedFieldIndices
) -> common.AbsoluteIndexSequence:
named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = {
- d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__
+ d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims
}
domain_slice: list[common.NamedRange | common.NamedIndex] = []
for d, v in named_indices.items():
@@ -1046,8 +1052,8 @@ class IndexField(common.Field):
_dimension: common.Dimension
@property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return (self._dimension,)
+ def __gt_domain__(self) -> common.Domain:
+ return self.domain
@property
def __gt_origin__(self) -> tuple[int, ...]:
@@ -1165,8 +1171,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]):
_value: core_defs.ScalarT
@property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return tuple()
+ def __gt_domain__(self) -> common.Domain:
+ return self.domain
@property
def __gt_origin__(self) -> tuple[int, ...]:
@@ -1452,7 +1458,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices:
class TupleOfFields(TupleField):
def __init__(self, data):
self.data = data
- self.__gt_dims__ = _get_axes(data)
+ self.__gt_domain__ = _get_domain(data)
def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
return _build_tuple_result(self.data, named_indices)
diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py
index 30fec1f9fd..05ebd02352 100644
--- a/src/gt4py/next/iterator/tracing.py
+++ b/src/gt4py/next/iterator/tracing.py
@@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg):
# other `np.int32`). We just ignore the error here and postpone fixing this to when
# the new storages land (The implementation here works for LocatedFieldImpl).
- return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__)
+ return common.is_field(arg) and any(dim is None for dim in arg.domain.dims)
def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]:
diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py
index d9d3d18213..0033f36cab 100644
--- a/src/gt4py/next/iterator/transforms/global_tmps.py
+++ b/src/gt4py/next/iterator/transforms/global_tmps.py
@@ -22,6 +22,7 @@
from gt4py.eve import Coerced, NodeTranslator
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
+from gt4py.next import common
from gt4py.next.iterator import ir, type_inference
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
@@ -437,9 +438,12 @@ def _group_offsets(
return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly
-def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]):
+def update_domains(
+ node: FencilWithTemporaries,
+ offset_provider: Mapping[str, Any],
+ symbolic_sizes: Optional[dict[str, str]],
+):
horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider)
-
closures: list[ir.StencilClosure] = []
domains = dict[str, ir.FunCall]()
for closure in reversed(node.fencil.closures):
@@ -479,16 +483,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
# cartesian shift
dim = offset_provider[offset_name].value
consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset)
- elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider):
+ elif isinstance(offset_provider[offset_name], common.Connectivity):
# unstructured shift
nbt_provider = offset_provider[offset_name]
old_axis = nbt_provider.origin_axis.value
new_axis = nbt_provider.neighbor_axis.value
- consumed_domain.ranges.pop(old_axis)
- assert new_axis not in consumed_domain.ranges
- consumed_domain.ranges[new_axis] = SymbolicRange(
- im.literal("0", ir.INTEGER_INDEX_BUILTIN),
- im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN),
+
+ assert new_axis not in consumed_domain.ranges or old_axis == new_axis
+
+ if symbolic_sizes is None:
+ new_range = SymbolicRange(
+ im.literal("0", ir.INTEGER_INDEX_BUILTIN),
+ im.literal(
+ str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN
+ ),
+ )
+ else:
+ new_range = SymbolicRange(
+ im.literal("0", ir.INTEGER_INDEX_BUILTIN),
+ im.ref(symbolic_sizes[new_axis]),
+ )
+ consumed_domain.ranges = dict(
+ (axis, range_) if axis != old_axis else (new_axis, new_range)
+ for axis, range_ in consumed_domain.ranges.items()
)
else:
raise NotImplementedError
@@ -570,7 +587,11 @@ class CreateGlobalTmps(NodeTranslator):
"""
def visit_FencilDefinition(
- self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any]
+ self,
+ node: ir.FencilDefinition,
+ *,
+ offset_provider: Mapping[str, Any],
+ symbolic_sizes: Optional[dict[str, str]],
) -> FencilWithTemporaries:
# Split closures on lifted function calls and introduce temporaries
res = split_closures(node, offset_provider=offset_provider)
@@ -581,6 +602,6 @@ def visit_FencilDefinition(
# Perform an eta-reduction which should put all calls at the highest level of a closure
res = EtaReduction().visit(res)
# Perform a naive extent analysis to compute domain sizes of closures and temporaries
- res = update_domains(res, offset_provider)
+ res = update_domains(res, offset_provider, symbolic_sizes)
# Use type inference to determine the data type of the temporaries
return collect_tmps_info(res, offset_provider=offset_provider)
diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py
index 2e05391634..08897861c2 100644
--- a/src/gt4py/next/iterator/transforms/pass_manager.py
+++ b/src/gt4py/next/iterator/transforms/pass_manager.py
@@ -13,6 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import enum
+from typing import Optional
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms import simple_inline_heuristic
@@ -81,6 +82,7 @@ def apply_common_transforms(
common_subexpression_elimination=True,
force_inline_lambda_args=False,
unconditionally_collapse_tuples=False,
+ symbolic_domain_sizes: Optional[dict[str, str]] = None,
):
if lift_mode is None:
lift_mode = LiftMode.FORCE_INLINE
@@ -147,7 +149,9 @@ def apply_common_transforms(
if lift_mode != LiftMode.FORCE_INLINE:
assert offset_provider is not None
- ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider)
+ ir = CreateGlobalTmps().visit(
+ ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes
+ )
ir = InlineLifts().visit(ir)
# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py
new file mode 100644
index 0000000000..ac71f2747d
--- /dev/null
+++ b/src/gt4py/next/iterator/transforms/power_unrolling.py
@@ -0,0 +1,84 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+import dataclasses
+import math
+
+from gt4py.eve import NodeTranslator
+from gt4py.next.iterator import ir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas
+
+
+def _is_power_call(
+ node: ir.FunCall,
+) -> bool:
+ """Match expressions of the form `power(base, integral_literal)`."""
+ return (
+ isinstance(node.fun, ir.SymRef)
+ and node.fun.id == "power"
+ and isinstance(node.args[1], ir.Literal)
+ and float(node.args[1].value) == int(node.args[1].value)
+ and node.args[1].value >= im.literal_from_value(0).value
+ )
+
+
+def _compute_integer_power_of_two(exp: int) -> int:
+ return math.floor(math.log2(exp))
+
+
+@dataclasses.dataclass
+class PowerUnrolling(NodeTranslator):
+ max_unroll: int
+
+ @classmethod
+ def apply(cls, node: ir.Node, max_unroll: int = 5) -> ir.Node:
+ return cls(max_unroll=max_unroll).visit(node)
+
+ def visit_FunCall(self, node: ir.FunCall):
+ new_node = self.generic_visit(node)
+
+ if _is_power_call(new_node):
+ assert len(new_node.args) == 2
+ # Check if unroll should be performed or if exponent is too large
+ base, exponent = new_node.args[0], int(new_node.args[1].value)
+ if 1 <= exponent <= self.max_unroll:
+ # Calculate and store powers of two of the base as long as they are smaller than the exponent.
+ # Do the same (using the stored values) with the remainder and multiply computed values.
+ pow_cur = _compute_integer_power_of_two(exponent)
+ pow_max = pow_cur
+ remainder = exponent
+
+ # Build target expression
+ ret = im.ref(f"power_{2 ** pow_max}")
+ remainder -= 2**pow_cur
+ while remainder > 0:
+ pow_cur = _compute_integer_power_of_two(remainder)
+ remainder -= 2**pow_cur
+
+ ret = im.multiplies_(ret, f"power_{2 ** pow_cur}")
+
+ # Nest target expression to avoid multiple redundant evaluations
+ for i in range(pow_max, 0, -1):
+ ret = im.let(
+ f"power_{2 ** i}",
+ im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}"),
+ )(ret)
+ ret = im.let("power_1", base)(ret)
+
+ # Simplify expression in case of SymRef by resolving let statements
+ if isinstance(base, ir.SymRef):
+ return InlineLambdas.apply(ret, opcount_preserving=True)
+ else:
+ return ret
+ return new_node
diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py
index 68627cfd89..d65f67b266 100644
--- a/src/gt4py/next/iterator/type_inference.py
+++ b/src/gt4py/next/iterator/type_inference.py
@@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints):
axis = offset_provider[offset]
if isinstance(axis, gtx.Dimension):
continue # Cartesian shifts don’t change the location type
- elif isinstance(
- axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider)
- ):
+ elif isinstance(axis, Connectivity):
assert (
axis.origin_axis.kind
== axis.neighbor_axis.kind
@@ -964,7 +962,7 @@ def visit_FencilDefinition(
def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None:
for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES):
try:
- child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined]
+ child_node.annex.type = types[id(child_node)]
except KeyError:
if not (
isinstance(child_node, ir.SymRef)
diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py
index ed8b768972..3a82f9c738 100644
--- a/src/gt4py/next/otf/workflow.py
+++ b/src/gt4py/next/otf/workflow.py
@@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self:
if not dataclasses.is_dataclass(self):
raise TypeError(f"'{self.__class__}' is not a dataclass.")
assert not isinstance(self, type)
- return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`?
+ return dataclasses.replace(self, **kwargs)
class ChainableWorkflowMixin(Workflow[StartT, EndT]):
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py
deleted file mode 100644
index 4183f52550..0000000000
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# GT4Py - GridTools Framework
-#
-# Copyright (c) 2014-2023, ETH Zurich
-# All rights reserved.
-#
-# This file is part of the GT4Py project and the GridTools framework.
-# GT4Py is free software: you can redistribute it and/or modify it under
-# the terms of the GNU General Public License as published by the
-# Free Software Foundation, either version 3 of the License, or any later
-# version. See the LICENSE.txt file at the top-level directory of this
-# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from typing import Any
-
-import gt4py.next.iterator.ir as itir
-from gt4py.eve import codegen
-from gt4py.eve.exceptions import EveValueError
-from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms
-from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
-from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering
-from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering
-
-
-def _lower(
- program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any
-):
- offset_provider = kwargs.get("offset_provider")
- assert isinstance(offset_provider, dict)
- if enable_itir_transforms:
- program = apply_common_transforms(
- program,
- lift_mode=kwargs.get("lift_mode"),
- offset_provider=offset_provider,
- unroll_reduce=do_unroll,
- unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
- )
- gtfn_ir = GTFN_lowering.apply(
- program,
- offset_provider=offset_provider,
- column_axis=kwargs.get("column_axis"),
- )
- return gtfn_ir
-
-
-def generate(
- program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any
-) -> str:
- if kwargs.get("imperative", False):
- try:
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=False,
- **kwargs,
- )
- except EveValueError:
- # if we don't unroll, there may be lifts left in the itir which can't be lowered to
- # gtfn. In this case, just retry with unrolled reductions.
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=True,
- **kwargs,
- )
- gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs)
- generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs)
- else:
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=True,
- **kwargs,
- )
- generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs)
- return codegen.format_source("cpp", generated_code, style="LLVM")
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
index 4abdaa6eea..718fef72af 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
@@ -15,21 +15,24 @@
from __future__ import annotations
import dataclasses
+import functools
import warnings
from typing import Any, Final, Optional
import numpy as np
from gt4py._core import definitions as core_defs
-from gt4py.eve import trees, utils
+from gt4py.eve import codegen, trees, utils
from gt4py.next import common
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
-from gt4py.next.iterator.transforms import LiftMode
+from gt4py.next.iterator.transforms import LiftMode, pass_manager
from gt4py.next.otf import languages, stages, step_types, workflow
from gt4py.next.otf.binding import cpp_interface, interface
-from gt4py.next.program_processors.codegens.gtfn import gtfn_backend
+from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
+from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering
+from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering
from gt4py.next.type_system import type_specifications as ts, type_translation
@@ -54,6 +57,7 @@ class GTFNTranslationStep(
use_imperative_backend: bool = False
lift_mode: Optional[LiftMode] = None
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
+ symbolic_domain_sizes: Optional[dict[str, str]] = None
def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
match self.device_type:
@@ -171,6 +175,70 @@ def _process_connectivity_args(
return parameters, arg_exprs
+ def _preprocess_program(
+ self,
+ program: itir.FencilDefinition,
+ offset_provider: dict[str, Connectivity | Dimension],
+ runtime_lift_mode: Optional[LiftMode] = None,
+ ) -> itir.FencilDefinition:
+ # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
+ # to the interface of all (or at least all of concern) backends, but instead should be
+ # configured in the backend itself (like it is here), until then we respect the argument
+ # here and warn the user if it differs from the one configured.
+ lift_mode = runtime_lift_mode or self.lift_mode
+ if lift_mode != self.lift_mode:
+ warnings.warn(
+ f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
+                f"overridden to be {str(runtime_lift_mode)} at runtime."
+ )
+
+ if not self.enable_itir_transforms:
+ return program
+
+ apply_common_transforms = functools.partial(
+ pass_manager.apply_common_transforms,
+ lift_mode=lift_mode,
+ offset_provider=offset_provider,
+ # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
+ unconditionally_collapse_tuples=True,
+ symbolic_domain_sizes=self.symbolic_domain_sizes,
+ )
+
+ new_program = apply_common_transforms(
+ program, unroll_reduce=not self.use_imperative_backend
+ )
+
+ if self.use_imperative_backend and any(
+ node.id == "neighbors"
+ for node in new_program.pre_walk_values().if_isinstance(itir.SymRef)
+ ):
+ # if we don't unroll, there may be lifts left in the itir which can't be lowered to
+ # gtfn. In this case, just retry with unrolled reductions.
+ new_program = apply_common_transforms(program, unroll_reduce=True)
+
+ return new_program
+
+ def generate_stencil_source(
+ self,
+ program: itir.FencilDefinition,
+ offset_provider: dict[str, Connectivity | Dimension],
+ column_axis: Optional[common.Dimension],
+ runtime_lift_mode: Optional[LiftMode] = None,
+ ) -> str:
+ new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode)
+ gtfn_ir = GTFN_lowering.apply(
+ new_program,
+ offset_provider=offset_provider,
+ column_axis=column_axis,
+ )
+
+ if self.use_imperative_backend:
+ gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir)
+ generated_code = GTFNIMCodegen.apply(gtfn_im_ir)
+ else:
+ generated_code = GTFNCodegen.apply(gtfn_ir)
+ return codegen.format_source("cpp", generated_code, style="LLVM")
+
def __call__(
self,
inp: stages.ProgramCall,
@@ -190,18 +258,6 @@ def __call__(
inp.kwargs["offset_provider"]
)
- # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
- # to the interface of all (or at least all of concern) backends, but instead should be
- # configured in the backend itself (like it is here), until then we respect the argument
- # here and warn the user if it differs from the one configured.
- runtime_lift_mode = inp.kwargs.pop("lift_mode", None)
- lift_mode = runtime_lift_mode or self.lift_mode
- if runtime_lift_mode != self.lift_mode:
- warnings.warn(
- f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
- "overriden to be {str(runtime_lift_mode)} at runtime."
- )
-
# combine into a format that is aligned with what the backend expects
parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters
backend_arg = self._backend_type()
@@ -213,12 +269,11 @@ def __call__(
f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});"
)
decl_src = cpp_interface.render_function_declaration(function, body=decl_body)
- stencil_src = gtfn_backend.generate(
+ stencil_src = self.generate_stencil_source(
program,
- enable_itir_transforms=self.enable_itir_transforms,
- lift_mode=lift_mode,
- imperative=self.use_imperative_backend,
- **inp.kwargs,
+ inp.kwargs["offset_provider"],
+ inp.kwargs.get("column_axis", None),
+ inp.kwargs.get("lift_mode", None),
)
source_code = interface.format_source(
self._language_settings(),
diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py
index f9fa154641..27dec77ed1 100644
--- a/src/gt4py/next/program_processors/formatters/gtfn.py
+++ b/src/gt4py/next/program_processors/formatters/gtfn.py
@@ -15,10 +15,19 @@
from typing import Any
from gt4py.next.iterator import ir as itir
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep
from gt4py.next.program_processors.processor_interface import program_formatter
+from gt4py.next.program_processors.runners.gtfn import gtfn_executor
@program_formatter
def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str:
- return generate(program, **kwargs)
+ # TODO(tehrengruber): This is a little ugly. Revisit.
+ gtfn_translation = gtfn_executor.otf_workflow.translation
+ assert isinstance(gtfn_translation, GTFNTranslationStep)
+ return gtfn_translation.generate_stencil_source(
+ program,
+ offset_provider=kwargs.get("offset_provider", None),
+ column_axis=kwargs.get("column_axis", None),
+ runtime_lift_mode=kwargs.get("lift_mode", None),
+ )
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index 7fd4794e57..fdd8a61054 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -260,10 +260,12 @@ def build_sdfg_from_itir(
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
- # TODO: According to Lex one should build the SDFG first in a general mannor.
- # Generalisation to a particular device should happen only at the end.
- sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu)
+ sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg = sdfg_genenerator.visit(program)
+ if sdfg is None:
+ raise RuntimeError(f"Visit failed for program {program.id}.")
+
+ # run DaCe transformations to simplify the SDFG
sdfg.simplify()
# run DaCe auto-optimization heuristics
@@ -274,6 +276,9 @@ def build_sdfg_from_itir(
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
+ if on_gpu:
+ sdfg.apply_gpu_transformations()
+
return sdfg
@@ -283,7 +288,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
compiler_args = kwargs.get("compiler_args", None) # `None` will take default.
build_type = kwargs.get("build_type", "RelWithDebInfo")
on_gpu = kwargs.get("on_gpu", False)
- auto_optimize = kwargs.get("auto_optimize", False)
+ auto_optimize = kwargs.get("auto_optimize", True)
lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
index e3b5ddf2ac..fb2f82fed0 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
@@ -99,20 +99,17 @@ class ItirToSDFG(eve.NodeVisitor):
offset_provider: dict[str, Any]
node_types: dict[int, next_typing.Type]
unique_id: int
- use_gpu_storage: bool
def __init__(
self,
param_types: list[ts.TypeSpec],
offset_provider: dict[str, NeighborTableOffsetProvider],
column_axis: Optional[Dimension] = None,
- use_gpu_storage: bool = False,
):
self.param_types = param_types
self.column_axis = column_axis
self.offset_provider = offset_provider
self.storage_types = {}
- self.use_gpu_storage = use_gpu_storage
def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
if isinstance(type_, ts.FieldType):
@@ -123,14 +120,7 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset
else None
)
dtype = as_dace_type(type_.dtype)
- storage = (
- dace.dtypes.StorageType.GPU_Global
- if self.use_gpu_storage
- else dace.dtypes.StorageType.Default
- )
- sdfg.add_array(
- name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage
- )
+ sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype)
elif isinstance(type_, ts.ScalarType):
sdfg.add_symbol(name, as_dace_type(type_))
@@ -246,7 +236,6 @@ def visit_StencilClosure(
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
- storage=array_table[name].storage,
transient=True,
)
closure_init_state.add_nedge(
@@ -261,7 +250,6 @@ def visit_StencilClosure(
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
- storage=array_table[name].storage,
)
else:
assert isinstance(self.storage_types[name], ts.ScalarType)
diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py
index 88a8347fe4..12649bf620 100644
--- a/src/gt4py/next/type_system/type_translation.py
+++ b/src/gt4py/next/type_system/type_translation.py
@@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec:
elif isinstance(value, common.Dimension):
symbol_type = ts.DimensionType(dim=value)
elif common.is_field(value):
- dims = list(value.__gt_dims__)
+ dims = list(value.domain.dims)
dtype = from_type_hint(value.dtype.scalar_type)
symbol_type = ts.FieldType(dims=dims, dtype=dtype)
elif isinstance(value, tuple):
diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py
index 0f7cf5d0ab..4e7ebb0c21 100644
--- a/src/gt4py/storage/cartesian/utils.py
+++ b/src/gt4py/storage/cartesian/utils.py
@@ -192,6 +192,10 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray:
def asarray(
array: FieldLike, *, device: Literal["cpu", "gpu", None] = None
) -> np.ndarray | cp.ndarray:
+ if hasattr(array, "ndarray"):
+ # extract the buffer from a gt4py.next.Field
+ # TODO(havogt): probably `Field` should provide the array interface methods when applicable
+ array = array.ndarray
if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")):
return cp.asarray(array)
if device == "cpu" or (
diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py
index e580333bc8..8cfff12df4 100644
--- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py
+++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py
@@ -312,7 +312,7 @@ def test_symbolref_validation_for_valid_tree():
SymbolTableRootNode(
nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")],
)
- SymbolTableRootNode(
+ SymbolTableRootNode( # noqa: B018
nodes=[
SymbolChildNode(name="foo"),
SymbolRefChildNode(name="foo"),
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
index e25576ebde..1f5a1f0c48 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
@@ -22,6 +22,8 @@
import gt4py.next as gtx
from gt4py.next.ffront import decorator
from gt4py.next.iterator import ir as itir
+from gt4py.next.program_processors import processor_interface as ppi
+from gt4py.next.program_processors.runners import gtfn, roundtrip
try:
@@ -36,9 +38,10 @@
import next_tests.exclusion_matrices as definitions
+@ppi.program_executor
def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None:
"""Temporary default backend to not accidentally test the wrong backend."""
- raise ValueError("No backend selected. Backend selection is mandatory in tests.")
+ raise ValueError("No backend selected! Backend selection is mandatory in tests.")
OPTIONAL_PROCESSORS = []
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py
new file mode 100644
index 0000000000..0de953d85f
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import numpy as np
+
+import gt4py.next as gtx
+from gt4py.next import int32
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import cartesian_case
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
+ fieldview_backend,
+ reduction_setup,
+)
+
+
+def test_with_bound_args(cartesian_case):
+ @gtx.field_operator
+ def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
+ if not condition:
+ scalar = 0
+ return a + scalar
+
+ @gtx.program
+ def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
+ fieldop_bound_args(a, scalar, condition, out=out)
+
+ a = cases.allocate(cartesian_case, program_bound_args, "a")()
+ scalar = int32(1)
+ ref = a + scalar
+ out = cases.allocate(cartesian_case, program_bound_args, "out")()
+
+ prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
+ cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)
+
+
+def test_with_bound_args_order_args(cartesian_case):
+ @gtx.field_operator
+ def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField:
+ scalar = 0 if not condition else scalar
+ return a + scalar
+
+ @gtx.program(backend=cartesian_case.backend)
+ def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField):
+ fieldop_args(a, condition, scalar, out=out)
+
+ a = cases.allocate(cartesian_case, program_args, "a")()
+ out = cases.allocate(cartesian_case, program_args, "out")()
+
+ prog_bounds = program_args.with_bound_args(condition=True)
+ prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={})
+    assert np.allclose(out.asnumpy(), a.asnumpy() + int32(1))
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index a08931628b..70c79d7b6c 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -898,26 +898,6 @@ def test_docstring(a: cases.IField):
cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a)
-def test_with_bound_args(cartesian_case):
- @gtx.field_operator
- def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
- if not condition:
- scalar = 0
- return a + a + scalar
-
- @gtx.program
- def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
- fieldop_bound_args(a, scalar, condition, out=out)
-
- a = cases.allocate(cartesian_case, program_bound_args, "a")()
- scalar = int32(1)
- ref = a + a + 1
- out = cases.allocate(cartesian_case, program_bound_args, "out")()
-
- prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
- cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)
-
-
def test_domain(cartesian_case):
@gtx.field_operator
def fieldop_domain(a: cases.IField) -> cases.IField:
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
index 698dce2b5c..d100cd380c 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
@@ -30,16 +30,6 @@
def test_external_local_field(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def testee(
inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32]
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
index e8d0c8b163..e2434d860a 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
@@ -46,16 +46,6 @@
ids=["positive_values", "negative_values"],
)
def test_maxover_execution_(unstructured_case, strategy):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
if unstructured_case.backend in [
gtfn.run_gtfn,
gtfn.run_gtfn_gpu,
@@ -79,16 +69,6 @@ def testee(edge_f: cases.EField) -> cases.VField:
@pytest.mark.uses_unstructured_shift
def test_minover_execution(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def minover(edge_f: cases.EField) -> cases.VField:
out = min_over(edge_f(V2E), axis=V2EDim)
@@ -102,16 +82,6 @@ def minover(edge_f: cases.EField) -> cases.VField:
@pytest.mark.uses_unstructured_shift
def test_reduction_execution(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def reduction(edge_f: cases.EField) -> cases.VField:
return neighbor_sum(edge_f(V2E), axis=V2EDim)
@@ -150,16 +120,6 @@ def fencil(edge_f: cases.EField, out: cases.VField):
@pytest.mark.uses_unstructured_shift
def test_reduction_with_common_expression(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def testee(flux: cases.EField) -> cases.VField:
return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
index 167ccbb0a5..4444742c66 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
@@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import math
-from typing import Callable
+from typing import Callable, Optional
import numpy as np
import pytest
@@ -22,6 +22,7 @@
from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast
from gt4py.next.ffront.decorator import FieldOperator
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
+from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_translation
from next_tests.integration_tests import cases
@@ -39,7 +40,7 @@
# becomes easier.
-def make_builtin_field_operator(builtin_name: str):
+def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.ProgramExecutor]):
# TODO(tehrengruber): creating a field operator programmatically should be
# easier than what we need to do here.
# construct annotation dictionary containing the input argument and return
@@ -109,8 +110,9 @@ def make_builtin_field_operator(builtin_name: str):
return FieldOperator(
foast_node=typed_foast_node,
closure_vars=closure_vars,
- backend=None,
definition=None,
+ backend=backend,
+ grid_type=None,
)
@@ -129,9 +131,7 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp
expected = ref_impl(*inputs)
out = cartesian_case.as_field([IDim], np.zeros_like(expected))
- builtin_field_op = make_builtin_field_operator(builtin_name).with_backend(
- cartesian_case.backend
- )
+ builtin_field_op = make_builtin_field_operator(builtin_name, cartesian_case.backend)
builtin_field_op(*inps, out=out, offset_provider={})
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
new file mode 100644
index 0000000000..788081b81e
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
@@ -0,0 +1,119 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check https://www.gnu.org/licenses/.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+from numpy import int32, int64
+
+from gt4py import next as gtx
+from gt4py.next import common
+from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
+from gt4py.next.program_processors import otf_compile_executor
+from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import (
+ E2V,
+ Case,
+ KDim,
+ Vertex,
+ cartesian_case,
+ unstructured_case,
+)
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
+ reduction_setup,
+)
+from next_tests.toy_connectivity import Cell, Edge
+
+
+@pytest.fixture
+def run_gtfn_with_temporaries_and_symbolic_sizes():
+ return otf_compile_executor.OTFBackend(
+ executor=otf_compile_executor.OTFCompileExecutor(
+ name="run_gtfn_with_temporaries_and_sizes",
+ otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace(
+ translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace(
+ symbolic_domain_sizes={
+ "Cell": "num_cells",
+ "Edge": "num_edges",
+ "Vertex": "num_vertices",
+ },
+ ),
+ ),
+ ),
+ allocator=run_gtfn_with_temporaries.allocator,
+ )
+
+
+@pytest.fixture
+def testee():
+ @gtx.field_operator
+ def testee_op(a: cases.VField) -> cases.EField:
+ amul = a * 2
+ return amul(E2V[0]) + amul(E2V[1])
+
+ @gtx.program
+ def prog(
+ a: cases.VField,
+ out: cases.EField,
+ num_vertices: int32,
+ num_edges: int64,
+ num_cells: int32,
+ ):
+ testee_op(a, out=out)
+
+ return prog
+
+
+def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup):
+ unstructured_case = Case(
+ run_gtfn_with_temporaries_and_symbolic_sizes,
+ offset_provider=reduction_setup.offset_provider,
+ default_sizes={
+ Vertex: reduction_setup.num_vertices,
+ Edge: reduction_setup.num_edges,
+ Cell: reduction_setup.num_cells,
+ KDim: reduction_setup.k_levels,
+ },
+ grid_type=common.GridType.UNSTRUCTURED,
+ )
+
+ a = cases.allocate(unstructured_case, testee, "a")()
+ out = cases.allocate(unstructured_case, testee, "out")()
+
+ first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1])
+ ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs]
+
+ cases.verify(
+ unstructured_case,
+ testee,
+ a,
+ out,
+ reduction_setup.num_vertices,
+ reduction_setup.num_edges,
+ reduction_setup.num_cells,
+ inout=out,
+ ref=ref,
+ )
+
+
+def test_temporary_symbols(testee, reduction_setup):
+ itir_with_tmp = apply_common_transforms(
+ testee.itir,
+ lift_mode=LiftMode.FORCE_TEMPORARIES,
+ offset_provider=reduction_setup.offset_provider,
+ )
+
+ params = ["num_vertices", "num_edges", "num_cells"]
+ for param in params:
+ assert any([param == str(p) for p in itir_with_tmp.fencil.params])
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
index e851e7b130..5af4605988 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
@@ -18,7 +18,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef, offset
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
@fundef
@@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp):
output_file = sys.argv[1]
prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False)
- generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
index 33c7d5baa7..3e8b88ac66 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
@@ -18,7 +18,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out):
output_file = sys.argv[1]
prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False)
- generated_code = generate(prog, offset_provider={})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
index f7472d4ac3..fdc57449ee 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
@@ -18,7 +18,7 @@
import gt4py.next as gtx
from gt4py.next import Field, field_operator, program
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -47,7 +47,9 @@ def copy_program(
output_file = sys.argv[1]
prog = copy_program.itir
- generated_code = generate(prog, offset_provider={})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
index 1dfd74baca..abc3755dca 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
@@ -19,7 +19,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef, offset
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative
E2V = offset("E2V")
@@ -92,13 +92,20 @@ def mapped_index(_, __) -> int:
output_file = sys.argv[1]
imperative = sys.argv[2].lower() == "true"
+ if imperative:
+ backend = run_gtfn_imperative
+ else:
+ backend = run_gtfn
+
# prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils
prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False)
offset_provider = {
"V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True),
"E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False),
}
- generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative)
+ generated_code = backend.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider=offset_provider, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
index 578a19faab..9755774fd0 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
@@ -19,7 +19,7 @@
from gt4py.next.iterator.runtime import closure, fundef
from gt4py.next.iterator.tracing import trace_fencil_definition
from gt4py.next.iterator.transforms import LiftMode
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x):
prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False)
offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")}
- generated_code = generate(
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
prog,
offset_provider=offset_provider,
- lift_mode=LiftMode.SIMPLE_HEURISTIC,
+ runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC,
column_axis=KDim,
)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py
new file mode 100644
index 0000000000..ba4b1b0cdb
--- /dev/null
+++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py
@@ -0,0 +1,137 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check https://www.gnu.org/licenses/.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import numpy as np
+import pytest
+
+from gt4py import next as gtx
+from gt4py.next import errors
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import IField, cartesian_case # noqa: F401 # fixtures
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 # fixtures
+ KDim,
+ fieldview_backend,
+)
+
+
+def test_default_backend_is_respected_field_operator(cartesian_case): # noqa: F811 # fixtures
+ """Test that manually calling the field operator without setting the backend raises an error."""
+
+ # Important not to set the backend here!
+ @gtx.field_operator
+ def copy(a: IField) -> IField:
+ return a
+
+ a = cases.allocate(cartesian_case, copy, "a")()
+
+ with pytest.raises(ValueError, match="No backend selected!"):
+ # Calling this should fail if the default backend is respected
+ # due to `fieldview_backend` fixture (dependency of `cartesian_case`)
+ # setting the default backend to something invalid.
+ _ = copy(a, out=a, offset_provider={})
+
+
+def test_default_backend_is_respected_scan_operator(cartesian_case): # noqa: F811 # fixtures
+ """Test that manually calling the scan operator without setting the backend raises an error."""
+
+ # Important not to set the backend here!
+ @gtx.scan_operator(axis=KDim, init=0.0, forward=True)
+ def sum(state: float, a: float) -> float:
+ return state + a
+
+ a = gtx.ones({KDim: 10}, allocator=cartesian_case.backend)
+
+ with pytest.raises(ValueError, match="No backend selected!"):
+ # see comment in field_operator test
+ _ = sum(a, out=a, offset_provider={})
+
+
+def test_default_backend_is_respected_program(cartesian_case): # noqa: F811 # fixtures
+ """Test that manually calling the program without setting the backend raises an error."""
+
+ @gtx.field_operator
+ def copy(a: IField) -> IField:
+ return a
+
+ # Important not to set the backend here!
+ @gtx.program
+ def copy_program(a: IField, b: IField) -> IField:
+ copy(a, out=b)
+
+ a = cases.allocate(cartesian_case, copy_program, "a")()
+ b = cases.allocate(cartesian_case, copy_program, "b")()
+
+ with pytest.raises(ValueError, match="No backend selected!"):
+ # see comment in field_operator test
+ _ = copy_program(a, b, offset_provider={})
+
+
+def test_missing_arg_field_operator(cartesian_case): # noqa: F811 # fixtures
+ """Test that calling a field_operator without required args raises an error."""
+
+ @gtx.field_operator(backend=cartesian_case.backend)
+ def copy(a: IField) -> IField:
+ return a
+
+ a = cases.allocate(cartesian_case, copy, "a")()
+
+ with pytest.raises(errors.MissingArgumentError, match="'out'"):
+ _ = copy(a, offset_provider={})
+
+ with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"):
+ _ = copy(a, out=a)
+
+
+def test_missing_arg_scan_operator(cartesian_case): # noqa: F811 # fixtures
+ """Test that calling a scan_operator without required args raises an error."""
+
+ @gtx.scan_operator(backend=cartesian_case.backend, axis=KDim, init=0.0, forward=True)
+ def sum(state: float, a: float) -> float:
+ return state + a
+
+ a = cases.allocate(cartesian_case, sum, "a")()
+
+ with pytest.raises(errors.MissingArgumentError, match="'out'"):
+ _ = sum(a, offset_provider={})
+
+ with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"):
+ _ = sum(a, out=a)
+
+
+def test_missing_arg_program(cartesian_case): # noqa: F811 # fixtures
+ """Test that calling a program without required args raises an error."""
+
+ @gtx.field_operator
+ def copy(a: IField) -> IField:
+ return a
+
+ a = cases.allocate(cartesian_case, copy, "a")()
+ b = cases.allocate(cartesian_case, copy, cases.RETURN)()
+
+ with pytest.raises(errors.DSLError, match="Invalid call"):
+
+ @gtx.program(backend=cartesian_case.backend)
+ def copy_program(a: IField, b: IField) -> IField:
+ copy(a)
+
+ _ = copy_program(a, offset_provider={})
+
+ with pytest.raises(TypeError, match="'offset_provider'"):
+
+ @gtx.program(backend=cartesian_case.backend)
+ def copy_program(a: IField, b: IField) -> IField:
+ copy(a, out=b)
+
+ _ = copy_program(a)
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
index 86c3c98c62..5c2802f90c 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
@@ -323,7 +323,7 @@ def test_update_cartesian_domains():
for a, s in (("JDim", "j"), ("KDim", "k"))
],
)
- actual = update_domains(testee, {"I": gtx.Dimension("IDim")})
+ actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None)
assert actual == expected
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py
new file mode 100644
index 0000000000..ae23becb4c
--- /dev/null
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py
@@ -0,0 +1,161 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check https://www.gnu.org/licenses/.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+
+from gt4py.eve import SymbolRef
+from gt4py.next.iterator import ir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.iterator.transforms.power_unrolling import PowerUnrolling
+
+
+def test_power_unrolling_zero():
+ pytest.xfail(
+ "Not implementeds we don't have an easy way to determine the type of the one literal (type inference is to expensive)."
+ )
+ testee = im.call("power")("x", 0)
+ expected = im.literal_from_value(1)
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_one():
+ testee = im.call("power")("x", 1)
+ expected = ir.SymRef(id=SymbolRef("x"))
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_two():
+ testee = im.call("power")("x", 2)
+ expected = im.multiplies_("x", "x")
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_two_x_plus_two():
+ testee = im.call("power")(im.plus("x", 2), 2)
+ expected = im.let("power_1", im.plus("x", 2))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2")
+ )
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_two_x_plus_one_times_three():
+ testee = im.call("power")(im.multiplies_(im.plus("x", 1), 3), 2)
+ expected = im.let("power_1", im.multiplies_(im.plus("x", 1), 3))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2")
+ )
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_three():
+ testee = im.call("power")("x", 3)
+ expected = im.multiplies_(im.multiplies_("x", "x"), "x")
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_four():
+ testee = im.call("power")("x", 4)
+ expected = im.let("power_2", im.multiplies_("x", "x"))(im.multiplies_("power_2", "power_2"))
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_five():
+ testee = im.call("power")("x", 5)
+ tmp2 = im.multiplies_("x", "x")
+ expected = im.multiplies_(im.multiplies_(tmp2, tmp2), "x")
+ expected = im.let("power_2", im.multiplies_("x", "x"))(
+ im.multiplies_(im.multiplies_("power_2", "power_2"), "x")
+ )
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_seven():
+ testee = im.call("power")("x", 7)
+ expected = im.call("power")("x", 7)
+
+ actual = PowerUnrolling.apply(testee, max_unroll=5)
+ assert actual == expected
+
+
+def test_power_unrolling_seven_unrolled():
+ testee = im.call("power")("x", 7)
+ expected = im.let("power_2", im.multiplies_("x", "x"))(
+ im.multiplies_(im.multiplies_(im.multiplies_("power_2", "power_2"), "power_2"), "x")
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=7)
+ assert actual == expected
+
+
+def test_power_unrolling_seven_x_plus_one_unrolled():
+ testee = im.call("power")(im.plus("x", 1), 7)
+ expected = im.let("power_1", im.plus("x", 1))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))(
+ im.let("power_4", im.multiplies_("power_2", "power_2"))(
+ im.multiplies_(im.multiplies_("power_4", "power_2"), "power_1")
+ )
+ )
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=7)
+ assert actual == expected
+
+
+def test_power_unrolling_eight():
+ testee = im.call("power")("x", 8)
+ expected = im.call("power")("x", 8)
+
+ actual = PowerUnrolling.apply(testee, max_unroll=5)
+ assert actual == expected
+
+
+def test_power_unrolling_eight_unrolled():
+ testee = im.call("power")("x", 8)
+ expected = im.let("power_2", im.multiplies_("x", "x"))(
+ im.let("power_4", im.multiplies_("power_2", "power_2"))(
+ im.multiplies_("power_4", "power_4")
+ )
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=8)
+ assert actual == expected
+
+
+def test_power_unrolling_eight_x_plus_one_unrolled():
+ testee = im.call("power")(im.plus("x", 1), 8)
+ expected = im.let("power_1", im.plus("x", 1))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))(
+ im.let("power_4", im.multiplies_("power_2", "power_2"))(
+ im.let("power_8", im.multiplies_("power_4", "power_4"))("power_8")
+ )
+ )
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=8)
+ assert actual == expected