From 36b48218981d295fc6e48d353367c820717d9e1d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 17 Aug 2023 13:25:46 +0200 Subject: [PATCH 01/67] deps: Update GridTools C++ dependency to 2.3.1 (#1324) --- .pre-commit-config.yaml | 16 ++++---- constraints.txt | 72 ++++++++++++++++----------------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 72 ++++++++++++++++----------------- src/gt4py/next/otf/workflow.py | 2 +- 7 files changed, 84 insertions(+), 84 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6033de35bb..d70f335bef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -104,7 +104,7 @@ repos: - flake8-eradicate==1.5.0 - flake8-mutable==1.2.0 - flake8-pyproject==1.2.3 - - pygments==2.15.1 + - pygments==2.16.1 ##[[[end]]] # - flake8-rst-docstrings # Disabled for now due to random false positives exclude: | @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.4.1 ========= + #========= FROM constraints.txt: v1.5.0 ========= ##[[[end]]] - rev: v1.4.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.5.0 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -166,22 +166,22 @@ repos: - boltons==23.0.0 - cached-property==1.5.2 - click==8.1.6 - - cmake==3.27.0 + - cmake==3.27.2 - cytoolz==0.12.2 - deepdiff==6.3.1 - devtools==0.11.0 - frozendict==2.3.8 - - gridtools-cpp==2.3.0 - - importlib-resources==6.0.0 + - gridtools-cpp==2.3.1 + - importlib-resources==6.0.1 - jinja2==3.1.2 - lark==1.1.7 - mako==1.2.4 - - nanobind==1.4.0 + - nanobind==1.5.0 - ninja==1.11.1 - numpy==1.24.4 - packaging==23.1 - pybind11==2.11.1 - - setuptools==68.0.0 + - setuptools==68.1.0 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index 1fddcab390..35e3d9e330 100644 --- a/constraints.txt +++ b/constraints.txt @@ -18,16 +18,16 @@ cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.1 # via tox certifi==2023.7.22 # via requests cffi==1.15.1 # via cryptography -cfgv==3.3.1 # via pre-commit -chardet==5.1.0 # via tox +cfgv==3.4.0 # via pre-commit +chardet==5.2.0 # via tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.0 # via gt4py (pyproject.toml) +cmake==3.27.2 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.2.7 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.2 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.3.0 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in @@ -37,11 +37,11 @@ dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.2 # via hypothesis, pytest +exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.2.0 # via factory-boy +faker==19.3.0 # via factory-boy fastjsonschema==2.18.0 # via nbformat filelock==3.12.2 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings @@ -55,19 +55,19 @@ flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in flask==2.3.2 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) -gridtools-cpp==2.3.0 # via gt4py (pyproject.toml) -hypothesis==6.82.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) +hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.26 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx importlib-metadata==6.8.0 # via flask, sphinx -importlib-resources==6.0.0 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest isort==5.12.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.18.4 # via nbformat +jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat jupytext==1.15.0 # via -r requirements-dev.in @@ -79,10 +79,10 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.4.1 # via -r requirements-dev.in +mypy==1.5.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.4.0 # via gt4py (pyproject.toml) -nbformat==5.9.1 # via jupytext +nanobind==1.5.0 # via gt4py (pyproject.toml) +nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit @@ -90,8 +90,8 @@ numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox pathspec==0.11.2 # via black -pip-tools==7.1.0 # via -r requirements-dev.in -pipdeptree==2.12.0 # via -r requirements-dev.in +pip-tools==7.3.0 # via -r requirements-dev.in +pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv pluggy==1.2.0 # via pytest, tox @@ -103,7 +103,7 @@ pycodestyle==2.11.0 # via flake8, flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.15.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx +pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx pyproject-api==1.5.3 # via tox pyproject-hooks==1.0.0 # via build pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist @@ -114,11 +114,11 @@ pytest-xdist==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.0 # via jsonschema, jsonschema-specifications +referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.280 # via -r requirements-dev.in +ruff==0.0.284 # via -r requirements-dev.in six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis @@ -136,9 +136,9 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.6.4 # via -r requirements-dev.in +tox==4.9.0 # via -r requirements-dev.in traitlets==5.9.0 # via jupyter-core, nbformat -types-aiofiles==23.1.0.5 # via types-all +types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all @@ -148,7 +148,7 @@ types-bleach==6.0.0.4 # via types-all types-boto==2.49.18.9 # via types-all types-cachetools==5.3.0.6 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.14 # via types-jack-client +types-cffi==1.15.1.15 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask @@ -162,8 +162,8 @@ types-dateparser==1.1.4.10 # via types-all types-datetimerange==2.0.0.6 # via types-all types-decorator==5.1.8.4 # via types-all types-deprecated==1.2.9.3 # via types-all -types-docopt==0.6.11.3 # via types-all -types-docutils==0.20.0.1 # via types-all +types-docopt==0.6.11.4 # via types-all +types-docutils==0.20.0.3 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -176,14 +176,14 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.8 # via types-all +types-jack-client==0.5.10.9 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 types-mock==5.1.0.1 # via types-all -types-mypy-extensions==1.0.0.4 # via types-all +types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all @@ -192,9 +192,9 @@ types-pathlib2==2.3.0 # via types-all types-pillow==10.0.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.23.0.2 # via types-all -types-pyaudio==0.2.16.6 # via types-all -types-pycurl==7.45.2.4 # via types-all +types-protobuf==4.24.0.1 # via types-all +types-pyaudio==0.2.16.7 # via types-all +types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all @@ -205,15 +205,15 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.0 # via types-all, types-tzlocal +types-pytz==2023.3.0.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.3 # via types-all +types-redis==4.6.0.4 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.0.0.3 # via types-cffi +types-setuptools==68.1.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.0.0.2 # via types-all types-six==1.16.21.9 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.2 # via pre-commit, tox +virtualenv==20.24.3 # via pre-commit, tox websockets==11.0.3 # via dace -werkzeug==2.3.6 # via flask -wheel==0.41.0 # via astunparse, pip-tools +werkzeug==2.3.7 # via flask +wheel==0.41.1 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.0.0 # via gt4py (pyproject.toml), nodeenv, pip-tools +setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 3a50641c48..17709206a0 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -40,7 +40,7 @@ flake8-pyproject==1.2.2 flake8-rst-docstrings==0.0.14 flake8==5.0.4 frozendict==2.3 -gridtools-cpp==2.3.0 +gridtools-cpp==2.3.1 hypothesis==6.0.0 importlib-resources==5.0;python_version<'3.9' isort==5.10 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 734cdc6cea..a6e5d19d1d 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -38,7 +38,7 @@ flake8-pyproject==1.2.2 flake8-rst-docstrings==0.0.14 flake8==5.0.4 frozendict==2.3 -gridtools-cpp==2.3.0 +gridtools-cpp==2.3.1 hypothesis==6.0.0 importlib-resources==5.0;python_version<'3.9' isort==5.10 diff --git a/pyproject.toml b/pyproject.toml index acc04ee8bb..9422dd4448 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ 'deepdiff>=5.6.0', 'devtools>=0.6', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.0,==2.*', + 'gridtools-cpp>=2.3.1,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index c5168e7f91..a167b2979a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,16 +18,16 @@ cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.1 # via tox certifi==2023.7.22 # via requests cffi==1.15.1 # via cryptography -cfgv==3.3.1 # via pre-commit -chardet==5.1.0 # via tox +cfgv==3.4.0 # via pre-commit +chardet==5.2.0 # via tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.0 # via gt4py (pyproject.toml) +cmake==3.27.2 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.2.7 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.2 # via types-paramiko, types-pyopenssl, types-redis +coverage[toml]==7.3.0 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in @@ -37,11 +37,11 @@ dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.2 # via hypothesis, pytest +exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.2.0 # via factory-boy +faker==19.3.0 # via factory-boy fastjsonschema==2.18.0 # via nbformat filelock==3.12.2 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings @@ -55,19 +55,19 @@ flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in flask==2.3.2 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) -gridtools-cpp==2.3.0 # via gt4py (pyproject.toml) -hypothesis==6.82.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) +hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.26 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx importlib-metadata==6.8.0 # via flask, sphinx -importlib-resources==6.0.0 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest isort==5.12.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.18.4 # via nbformat +jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat jupytext==1.15.0 # via -r requirements-dev.in @@ -79,10 +79,10 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.4.1 # via -r requirements-dev.in +mypy==1.5.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.4.0 # via gt4py (pyproject.toml) -nbformat==5.9.1 # via jupytext +nanobind==1.5.0 # via gt4py (pyproject.toml) +nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit @@ -90,8 +90,8 @@ numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox pathspec==0.11.2 # via black -pip-tools==7.1.0 # via -r requirements-dev.in -pipdeptree==2.12.0 # via -r requirements-dev.in +pip-tools==7.3.0 # via -r requirements-dev.in +pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv pluggy==1.2.0 # via pytest, tox @@ -103,7 +103,7 @@ pycodestyle==2.11.0 # via flake8, flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.15.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx +pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx pyproject-api==1.5.3 # via tox pyproject-hooks==1.0.0 # via build pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist @@ -114,11 +114,11 @@ pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.0 # via jsonschema, jsonschema-specifications +referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.280 # via -r requirements-dev.in +ruff==0.0.284 # via -r requirements-dev.in six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis @@ -136,9 +136,9 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.6.4 # via -r requirements-dev.in +tox==4.9.0 # via -r requirements-dev.in traitlets==5.9.0 # via jupyter-core, nbformat -types-aiofiles==23.1.0.5 # via types-all +types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all @@ -148,7 +148,7 @@ types-bleach==6.0.0.4 # via types-all types-boto==2.49.18.9 # via types-all types-cachetools==5.3.0.6 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.14 # via types-jack-client +types-cffi==1.15.1.15 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask @@ -162,8 +162,8 @@ types-dateparser==1.1.4.10 # via types-all types-datetimerange==2.0.0.6 # via types-all types-decorator==5.1.8.4 # via types-all types-deprecated==1.2.9.3 # via types-all -types-docopt==0.6.11.3 # via types-all -types-docutils==0.20.0.1 # via types-all +types-docopt==0.6.11.4 # via types-all +types-docutils==0.20.0.3 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -176,14 +176,14 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.8 # via types-all +types-jack-client==0.5.10.9 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 types-mock==5.1.0.1 # via types-all -types-mypy-extensions==1.0.0.4 # via types-all +types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all @@ -192,9 +192,9 @@ types-pathlib2==2.3.0 # via types-all types-pillow==10.0.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.23.0.2 # via types-all -types-pyaudio==0.2.16.6 # via types-all -types-pycurl==7.45.2.4 # via types-all +types-protobuf==4.24.0.1 # via types-all +types-pyaudio==0.2.16.7 # via types-all +types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all @@ -205,15 +205,15 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.0 # via types-all, types-tzlocal +types-pytz==2023.3.0.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.3 # via types-all +types-redis==4.6.0.4 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.0.0.3 # via types-cffi +types-setuptools==68.1.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.0.0.2 # via types-all types-six==1.16.21.9 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.2 # via pre-commit, tox +virtualenv==20.24.3 # via pre-commit, tox websockets==11.0.3 # via dace -werkzeug==2.3.6 # via flask -wheel==0.41.0 # via astunparse, pip-tools +werkzeug==2.3.7 # via flask +wheel==0.41.1 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.0.0 # via gt4py (pyproject.toml), nodeenv, pip-tools +setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 6420a4ddfa..6b6b91a310 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self: if not dataclasses.is_dataclass(self): raise TypeError(f"{self.__class__} is not a dataclass") assert not isinstance(self, type) - return dataclasses.replace(self, **kwargs) + return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? class ChainableWorkflowMixin(Workflow[StartT, EndT]): From fa306641a53a2d5a927af6cdcf0d9053efede631 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 17 Aug 2023 13:57:46 +0200 Subject: [PATCH 02/67] Replace pybind11 with nanobind (#1299) --- .../otf/binding/{pybind.py => nanobind.py} | 24 +++--- .../compilation/build_systems/cmake_lists.py | 23 +++++- .../compilation/build_systems/compiledb.py | 45 +++++++---- .../program_processors/runners/gtfn_cpu.py | 4 +- ...pybind_build.py => test_nanobind_build.py} | 6 +- .../otf_tests/binding_tests/test_nanobind.py | 24 ++++++ .../otf_tests/binding_tests/test_pybind.py | 77 ------------------- .../build_systems_tests/conftest.py | 4 +- 8 files changed, 93 insertions(+), 114 deletions(-) rename src/gt4py/next/otf/binding/{pybind.py => nanobind.py} (91%) rename tests/next_tests/integration_tests/feature_tests/otf_tests/{test_pybind_build.py => test_nanobind_build.py} (92%) create mode 100644 tests/next_tests/unit_tests/otf_tests/binding_tests/test_nanobind.py delete mode 100644 tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py diff --git a/src/gt4py/next/otf/binding/pybind.py b/src/gt4py/next/otf/binding/nanobind.py similarity index 91% rename from src/gt4py/next/otf/binding/pybind.py rename to src/gt4py/next/otf/binding/nanobind.py index 82b06a31ae..9dccddc012 100644 --- a/src/gt4py/next/otf/binding/pybind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -90,13 +90,15 @@ def _type_string(type_: ts.TypeSpec) -> str: return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): ndims = len(type_.dims) - buffer_t = "pybind11::buffer" + dtype = cpp_interface.render_scalar_type(type_.dtype) + shape = f"nanobind::shape<{', '.join(['nanobind::any'] * ndims)}>" + buffer_t = f"nanobind::ndarray<{dtype}, {shape}>" origin_t = f"std::tuple<{', '.join(['ptrdiff_t'] * ndims)}>" return f"std::pair<{buffer_t}, {origin_t}>" elif isinstance(type_, ts.ScalarType): return cpp_interface.render_scalar_type(type_) else: - raise ValueError(f"Type '{type_}' is not supported in pybind11 interfaces.") + raise ValueError(f"Type '{type_}' is not supported in nanobind interfaces.") class BindingCodeGenerator(TemplatedGenerator): @@ -131,7 +133,7 @@ class BindingCodeGenerator(TemplatedGenerator): BindingModule = as_jinja( """\ - PYBIND11_MODULE({{name}}, module) { + NB_MODULE({{name}}, module) { module.doc() = "{{doc}}"; {{"\n".join(functions)}} }\ @@ -149,9 +151,7 @@ def visit_BufferSID(self, sid: BufferSID, **kwargs): dims = [self.visit(dim) for dim in sid.dimensions] origin = f"{sid.source_buffer}.second" - as_sid = f"gridtools::as_sid<{cpp_interface.render_scalar_type(sid.scalar_type)},\ - {sid.dimensions.__len__()},\ - gridtools::sid::unknown_kind>({pybuffer})" + as_sid = f"gridtools::nanobind::as_sid({pybuffer})" shifted = f"gridtools::sid::shift_sid_origin({as_sid}, {origin})" renamed = f"gridtools::sid::rename_numbered_dimensions<{', '.join(dims)}>({shifted})" return renamed @@ -187,7 +187,7 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | CompositeS elif isinstance(type_, ts.ScalarType): return name else: - raise ValueError(f"Type '{type_}' is not supported in pybind11 interfaces.") + raise ValueError(f"Type '{type_}' is not supported in nanobind interfaces.") def create_bindings( @@ -210,9 +210,10 @@ def create_bindings( file_binding = BindingFile( callee_header_file=f"{program_source.entry_point.name}.{program_source.language_settings.header_extension}", header_files=[ - "pybind11/pybind11.h", - "pybind11/stl.h", - "gridtools/storage/adapter/python_sid_adapter.hpp", + "nanobind/nanobind.h", + "nanobind/stl/tuple.h", + "nanobind/stl/pair.h", + "nanobind/ndarray.h", "gridtools/sid/composite.hpp", "gridtools/sid/unknown_kind.hpp", "gridtools/sid/rename_dimensions.hpp", @@ -221,6 +222,7 @@ def create_bindings( "gridtools/fn/unstructured.hpp", "gridtools/fn/cartesian.hpp", "gridtools/fn/backend/naive.hpp", + "gridtools/storage/adapter/nanobind_adapter.hpp", ], wrapper=WrapperFunction( name=wrapper_name, @@ -258,7 +260,7 @@ def create_bindings( return stages.BindingSource( src, - (interface.LibraryDependency("pybind11", "2.9.2"),), + (interface.LibraryDependency("nanobind", "1.4.0"),), ) diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 6117ca377c..ef222341e3 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -82,6 +82,12 @@ def visit_FindDependency(self, dep: FindDependency): import pybind11 return f"find_package(pybind11 CONFIG REQUIRED PATHS {pybind11.get_cmake_dir()} NO_DEFAULT_PATH)" + case "nanobind": + import nanobind + + py = "find_package(Python COMPONENTS Interpreter Development REQUIRED)" + nb = f"find_package(nanobind CONFIG REQUIRED PATHS {nanobind.cmake_dir()} NO_DEFAULT_PATHS)" + return py + "\n" + nb case "gridtools": import gridtools_cpp @@ -93,13 +99,24 @@ def visit_LinkDependency(self, dep: LinkDependency): match dep.name: case "pybind11": lib_name = "pybind11::module" + case "nanobind": + lib_name = "nanobind-static" case "gridtools": lib_name = "GridTools::fn_naive" case _: raise ValueError("Library {name} is not supported".format(name=dep.name)) - return "target_link_libraries({target} PUBLIC {lib})".format( - target=dep.target, lib=lib_name - ) + + cfg = "" + if dep.name == "nanobind": + cfg = "\n".join( + [ + "nanobind_build_library(nanobind-static)", + f"nanobind_compile_options({dep.target})", + f"nanobind_link_options({dep.target})", + ] + ) + lnk = f"target_link_libraries({dep.target} PUBLIC {lib_name})" + return cfg + "\n" + lnk def generate_cmakelists_source( diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 92e5ef1f69..34f2f85081 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -18,6 +18,7 @@ import json import pathlib import re +import shutil import subprocess from typing import Optional @@ -124,6 +125,20 @@ def build(self): self._run_build() def _write_files(self): + def ignore_not_libraries(folder: str, children: list[str]) -> list[str]: + pattern = r"((lib.*\.a)|(.*\.lib))" + libraries = [child for child in children if re.match(pattern, child)] + folders = [child for child in children if (pathlib.Path(folder) / child).is_dir()] + ignored = list(set(children) - set(libraries) - set(folders)) + return ignored + + shutil.copytree( + self.compile_commands_cache.parent, + self.root_path, + ignore=ignore_not_libraries, + dirs_exist_ok=True, + ) + for name, content in self.source_files.items(): (self.root_path / name).write_text(content, encoding="utf-8") @@ -140,7 +155,7 @@ def _run_config(self): compile_db = json.loads(self.compile_commands_cache.read_text()) (self.root_path / "build").mkdir(exist_ok=True) - (self.root_path / "bin").mkdir(exist_ok=True) + (self.root_path / "build" / "bin").mkdir(exist_ok=True) for entry in compile_db: for key, value in entry.items(): @@ -155,7 +170,7 @@ def _run_config(self): build_data.write_data( build_data.BuildData( status=build_data.BuildStatus.CONFIGURED, - module=pathlib.Path(compile_db[-1]["output"]), + module=pathlib.Path(compile_db[-1]["directory"]) / compile_db[-1]["output"], entry_point_name=self.program_name, ), self.root_path, @@ -171,7 +186,7 @@ def _run_build(self): log_file_pointer.write(entry["command"] + "\n") subprocess.check_call( entry["command"], - cwd=self.root_path, + cwd=entry["directory"], shell=True, stdout=log_file_pointer, stderr=log_file_pointer, @@ -251,19 +266,17 @@ def _cc_create_compiledb( program_name=name, ) - prototype_project._write_files() - prototype_project._run_config() + prototype_project.build() log_file = cache_path / "log_compiledb.txt" with log_file.open("w") as log_file_pointer: - commands = json.loads( - subprocess.check_output( - ["ninja", "-t", "compdb"], - cwd=cache_path / "build", - stderr=log_file_pointer, - ).decode("utf-8") - ) + commands_json_str = subprocess.check_output( + ["ninja", "-t", "compdb"], + cwd=cache_path / "build", + stderr=log_file_pointer, + ).decode("utf-8") + commands = json.loads(commands_json_str) compile_db = [ cmd for cmd in commands if name in pathlib.Path(cmd["file"]).stem and cmd["command"] @@ -272,10 +285,10 @@ def _cc_create_compiledb( assert compile_db for entry in compile_db: - entry["directory"] = "$SRC_PATH" + entry["directory"] = entry["directory"].replace(str(cache_path), "$SRC_PATH") entry["command"] = ( entry["command"] - .replace(f"CMakeFiles/{name}.dir", "build") + .replace(f"CMakeFiles/{name}.dir", ".") .replace(str(cache_path), "$SRC_PATH") .replace(f"{name}.cpp", "$BINDINGS_FILE") .replace(f"{name}", "$NAME") @@ -283,13 +296,13 @@ def _cc_create_compiledb( ) entry["file"] = ( entry["file"] - .replace(f"CMakeFiles/{name}.dir", "build") + .replace(f"CMakeFiles/{name}.dir", ".") .replace(str(cache_path), "$SRC_PATH") .replace(f"{name}.cpp", "$BINDINGS_FILE") ) entry["output"] = ( entry["output"] - .replace(f"CMakeFiles/{name}.dir", "build") + .replace(f"CMakeFiles/{name}.dir", ".") .replace(f"{name}.cpp", "$BINDINGS_FILE") .replace(f"{name}", "$NAME") ) diff --git a/src/gt4py/next/program_processors/runners/gtfn_cpu.py b/src/gt4py/next/program_processors/runners/gtfn_cpu.py index 0aa755cc92..195126a6ba 100644 --- a/src/gt4py/next/program_processors/runners/gtfn_cpu.py +++ b/src/gt4py/next/program_processors/runners/gtfn_cpu.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import content_hash from gt4py.next import common from gt4py.next.otf import languages, recipes, stages, workflow -from gt4py.next.otf.binding import cpp_interface, pybind +from gt4py.next.otf.binding import cpp_interface, nanobind from gt4py.next.otf.compilation import cache, compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors import otf_compile_executor @@ -102,7 +102,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: GTFN_DEFAULT_WORKFLOW = recipes.OTFCompileWorkflow( translation=GTFN_DEFAULT_TRANSLATION_STEP, - bindings=pybind.bind_source, + bindings=nanobind.bind_source, compilation=GTFN_DEFAULT_COMPILE_STEP, decoration=convert_args, ) diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_pybind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py similarity index 92% rename from tests/next_tests/integration_tests/feature_tests/otf_tests/test_pybind_build.py rename to tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 90b471dad1..f24dc4bc59 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_pybind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -17,7 +17,7 @@ import numpy as np from gt4py.next.otf import workflow -from gt4py.next.otf.binding import pybind +from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import cache, compiler from gt4py.next.otf.compilation.build_systems import cmake, compiledb @@ -28,7 +28,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") - build_the_program = workflow.make_step(pybind.bind_source).chain( + build_the_program = workflow.make_step(nanobind.bind_source).chain( compiler.Compiler( cache_strategy=cache.Strategy.SESSION, builder_factory=cmake.CMakeFactory() ), @@ -46,7 +46,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") - build_the_program = workflow.make_step(pybind.bind_source).chain( + build_the_program = workflow.make_step(nanobind.bind_source).chain( compiler.Compiler( cache_strategy=cache.Strategy.SESSION, builder_factory=compiledb.CompiledbFactory(), diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_nanobind.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_nanobind.py new file mode 100644 index 0000000000..b081ccd138 --- /dev/null +++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_nanobind.py @@ -0,0 +1,24 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next.otf.binding import nanobind + +from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( + program_source_example, +) + + +def test_bindings(program_source_example): + module = nanobind.create_bindings(program_source_example) + assert module.library_deps[0].name == "nanobind" diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py deleted file mode 100644 index 783ff8bdfd..0000000000 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_pybind.py +++ /dev/null @@ -1,77 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from gt4py.next.otf.binding import interface, pybind - -from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( - program_source_example, -) - - -def test_bindings(program_source_example): - module = pybind.create_bindings(program_source_example) - expected_src = interface.format_source( - program_source_example.language_settings, - """\ - #include "stencil.cpp.inc" - - #include - #include - #include - #include - #include - #include - #include - #include - #include - #include - #include - - decltype(auto) stencil_wrapper( - std::pair> buf, - std::tuple>, - std::pair>> - tup, - float sc) { - return stencil( - gridtools::sid::rename_numbered_dimensions< - generated::I_t, generated::J_t>(gridtools::sid::shift_sid_origin( - gridtools::as_sid(buf.first), - buf.second)), - gridtools::sid::composite::keys, - gridtools::integral_constant>:: - make_values( - gridtools::sid::rename_numbered_dimensions( - gridtools::sid::shift_sid_origin( - gridtools::as_sid( - gridtools::tuple_util::get<0>(tup).first), - gridtools::tuple_util::get<0>(tup).second)), - gridtools::sid::rename_numbered_dimensions( - gridtools::sid::shift_sid_origin( - gridtools::as_sid( - gridtools::tuple_util::get<1>(tup).first), - gridtools::tuple_util::get<1>(tup).second))), - sc); - } - - PYBIND11_MODULE(stencil, module) { - module.doc() = ""; - module.def("stencil", &stencil_wrapper, ""); - }\ - """, - ) - assert module.library_deps[0].name == "pybind11" - assert module.source_code == expected_src diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py index 4dc3215a07..1fab2643b5 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py @@ -21,7 +21,7 @@ import gt4py.next as gtx import gt4py.next.type_system.type_specifications as ts from gt4py.next.otf import languages, stages -from gt4py.next.otf.binding import cpp_interface, interface, pybind +from gt4py.next.otf.binding import cpp_interface, interface, nanobind from gt4py.next.otf.compilation import cache @@ -99,7 +99,7 @@ def program_source_example(): def compilable_source_example(program_source_example): return stages.CompilableSource( program_source=program_source_example, - binding_source=pybind.create_bindings(program_source_example), + binding_source=nanobind.create_bindings(program_source_example), ) From 3aaca75613e74a6f19348ad9e960d71676ae1e62 Mon Sep 17 00:00:00 2001 From: Samuel Date: Tue, 22 Aug 2023 17:49:03 +0200 Subject: [PATCH 03/67] feat[next]: embedded.ndarray_field: field slicing and intersection (#1315) Enable absolute and relative slicing of fields, as well as binary operations on fields with a different domain in embedded. --------- Co-authored-by: Hannes Vogt Co-authored-by: Hannes Vogt --- src/gt4py/next/common.py | 58 +++- src/gt4py/next/embedded/nd_array_field.py | 260 ++++++++++++++-- src/gt4py/next/ffront/fbuiltins.py | 4 +- .../embedded_tests/test_nd_array_field.py | 289 +++++++++++++++++- tests/next_tests/unit_tests/test_common.py | 10 + 5 files changed, 575 insertions(+), 46 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 36167532f2..e06f9c54b1 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,7 +20,8 @@ import functools import sys from collections.abc import Sequence, Set -from typing import overload +from types import EllipsisType +from typing import TypeGuard, overload import numpy as np import numpy.typing as npt @@ -36,7 +37,6 @@ Protocol, TypeAlias, TypeVar, - Union, extended_runtime_checkable, final, runtime_checkable, @@ -50,11 +50,11 @@ class Infinity(int): @classmethod - def positive(cls) -> "Infinity": + def positive(cls) -> Infinity: return cls(sys.maxsize) @classmethod - def negative(cls) -> "Infinity": + def negative(cls) -> Infinity: return cls(-sys.maxsize) @@ -77,11 +77,6 @@ def __str__(self): return f'Dimension(value="{self.value}", kind={self.kind})' -DomainLike: TypeAlias = Union[ - Sequence[Dimension], Dimension, str -] # TODO(havogt): revisit once embedded implementation is concluded - - @dataclasses.dataclass(frozen=True) class UnitRange(Sequence[int], Set[int]): """Range from `start` to `stop` with step size one.""" @@ -95,6 +90,10 @@ def __post_init__(self): object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) + @classmethod + def infinity(cls) -> UnitRange: + return cls(Infinity.negative(), Infinity.positive()) + def __len__(self) -> int: if Infinity.positive() in (abs(self.start), abs(self.stop)): return Infinity.positive() @@ -137,7 +136,19 @@ def __and__(self, other: Set[Any]) -> UnitRange: raise NotImplementedError("Can only find the intersection between UnitRange instances.") +DomainRange: TypeAlias = UnitRange | int NamedRange: TypeAlias = tuple[Dimension, UnitRange] +NamedIndex: TypeAlias = tuple[Dimension, int] +DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] +FieldSlice: TypeAlias = ( + DomainSlice + | tuple[slice | int | EllipsisType, ...] + | slice + | int + | EllipsisType + | NamedRange + | NamedIndex +) @dataclasses.dataclass(frozen=True) @@ -149,6 +160,11 @@ def __post_init__(self): if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") + if len(self.dims) != len(self.ranges): + raise ValueError( + f"Number of provided dimensions ({len(self.dims)}) does not match number of provided ranges ({len(self.ranges)})." + ) + def __len__(self) -> int: return len(self.ranges) @@ -242,7 +258,7 @@ def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: "DomainLike") -> Field: + def restrict(self, item: FieldSlice) -> Field | core_defs.ScalarT: ... # Operators @@ -251,7 +267,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: "DomainLike") -> Field: + def __getitem__(self, item: FieldSlice) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -307,6 +323,12 @@ def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... +def is_field( + v: Any, +) -> TypeGuard[Field]: # this function is introduced to localize the `type: ignore`` + return isinstance(v, Field) # type: ignore[misc] # we use extended_runtime_checkable + + class FieldABC(Field[DimsT, core_defs.ScalarT]): """Abstract base class for implementations of the :class:`Field` protocol.""" @@ -445,3 +467,17 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: ) return topologically_sorted_list + + +def is_named_range(v: Any) -> TypeGuard[NamedRange]: + return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], UnitRange) + + +def is_named_index(v: Any) -> TypeGuard[NamedIndex]: + return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], int) + + +def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]: + return isinstance(index, Sequence) and all( + is_named_range(idx) or is_named_index(idx) for idx in index + ) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 66c097951a..9813efdd22 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -16,14 +16,17 @@ import dataclasses import functools -from collections.abc import Callable -from types import ModuleType -from typing import ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, overload +import itertools +from collections.abc import Callable, Sequence +from types import EllipsisType, ModuleType +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast, overload import numpy as np from numpy import typing as npt +from gt4py._core import definitions as core_defs from gt4py.next import common +from gt4py.next.ffront import fbuiltins try: @@ -37,12 +40,6 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] -from gt4py._core import definitions -from gt4py._core.definitions import ScalarT -from gt4py.next.common import DimsT, Domain -from gt4py.next.ffront import fbuiltins - - def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: def _builtin_unary_op(a: _BaseNdArrayField) -> common.Field: xp = a.__class__.array_ns @@ -61,12 +58,16 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: op = getattr(xp, array_builtin_name) if hasattr(b, "__gt_builtin_func__"): # isinstance(b, common.Field): if not a.domain == b.domain: - raise NotImplementedError( - f"support for different domain not implemented: {a.domain}, {b.domain}" - ) + domain_intersection = a.domain & b.domain + a_broadcasted = _broadcast(a, domain_intersection.dims) + b_broadcasted = _broadcast(b, domain_intersection.dims) + a_slices = _get_slices_from_domain_slice(a_broadcasted.domain, domain_intersection) + b_slices = _get_slices_from_domain_slice(b_broadcasted.domain, domain_intersection) + new_data = op(a_broadcasted.ndarray[a_slices], b_broadcasted.ndarray[b_slices]) + return a.__class__.from_array(new_data, domain=domain_intersection) new_data = op(a.ndarray, xp.asarray(b.ndarray)) else: - assert isinstance(b, definitions.SCALAR_TYPES) + assert isinstance(b, core_defs.SCALAR_TYPES) new_data = op(a.ndarray, b) return a.__class__.from_array(new_data, domain=a.domain) @@ -75,13 +76,13 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: return _builtin_binary_op -_Value: TypeAlias = common.Field | ScalarT +_Value: TypeAlias = common.Field | core_defs.ScalarT _P = ParamSpec("_P") _R = TypeVar("_R", _Value, tuple[_Value, ...]) @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField(common.FieldABC[DimsT, ScalarT]): +class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]): """ Shared field implementation for NumPy-like fields. @@ -91,9 +92,9 @@ class _BaseNdArrayField(common.FieldABC[DimsT, ScalarT]): function via its namespace. """ - _domain: Domain - _ndarray: definitions.NDArrayObject - _value_type: type[ScalarT] + _domain: common.Domain + _ndarray: core_defs.NDArrayObject + _value_type: type[core_defs.ScalarT] array_ns: ClassVar[ ModuleType @@ -129,24 +130,25 @@ def register_builtin_func( return cls._builtin_func_map.setdefault(op, op_func) @property - def domain(self) -> Domain: + def domain(self) -> common.Domain: return self._domain @property - def ndarray(self) -> definitions.NDArrayObject: + def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray @property - def value_type(self) -> type[definitions.ScalarT]: + def value_type(self) -> type[core_defs.ScalarT]: return self._value_type @classmethod def from_array( cls, - data: npt.ArrayLike, + data: npt.ArrayLike + | core_defs.NDArrayObject, # TODO: NDArrayObject should be part of ArrayLike /, *, - domain: Domain, + domain: common.Domain, value_type: Optional[type] = None, ) -> _BaseNdArrayField: xp = cls.array_ns @@ -157,11 +159,14 @@ def from_array( value_type = array.dtype.type # TODO add support for Dimensions as value_type - assert issubclass(array.dtype.type, definitions.SCALAR_TYPES) + assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) assert all(isinstance(d, common.Dimension) for d, r in domain), domain assert len(domain) == array.ndim - assert all(len(nr[1]) == s for nr, s in zip(domain, array.shape)) + assert all( + len(nr[1]) == s or (s == 1 and nr[1] == common.UnitRange.infinity()) + for nr, s in zip(domain, array.shape) + ) assert value_type is not None # for mypy return cls(domain, array, value_type) @@ -169,13 +174,8 @@ def from_array( def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() - def restrict(self: _BaseNdArrayField, domain) -> _BaseNdArrayField: - raise NotImplementedError() - __call__ = None # type: ignore[assignment] # TODO: remap - __getitem__ = None # type: ignore[assignment] # TODO: restrict - __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") @@ -194,6 +194,79 @@ def restrict(self: _BaseNdArrayField, domain) -> _BaseNdArrayField: __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") + def __getitem__(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: + if ( + not isinstance(index, tuple) + and not common.is_domain_slice(index) + or common.is_named_index(index) + or common.is_named_range(index) + ): + index = cast(common.FieldSlice, (index,)) + + if common.is_domain_slice(index): + return self._getitem_absolute_slice(index) + + assert isinstance(index, tuple) + if all(isinstance(idx, (slice, int)) or idx is Ellipsis for idx in index): + return self._getitem_relative_slice(index) + + raise IndexError(f"Unsupported index type: {index}") + + restrict = ( + __getitem__ # type:ignore[assignment] # TODO(havogt) I don't see the problem that mypy has + ) + + def _getitem_absolute_slice( + self, index: common.DomainSlice + ) -> common.Field | core_defs.ScalarT: + slices = _get_slices_from_domain_slice(self.domain, index) + new_ranges = [] + new_dims = [] + new = self.ndarray[slices] + + for i, dim in enumerate(self.domain.dims): + if (pos := _find_index_of_dim(dim, index)) is not None: + index_or_range = index[pos][1] + if isinstance(index_or_range, common.UnitRange): + new_ranges.append(index_or_range) + new_dims.append(dim) + else: + # dimension not mentioned in slice + new_ranges.append(self.domain.ranges[i]) + new_dims.append(dim) + + new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + + if len(new_domain) == 0: + assert core_defs.is_scalar_type(new) + return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here + else: + return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) + + def _getitem_relative_slice( + self, indices: tuple[slice | int | EllipsisType, ...] + ) -> common.Field | core_defs.ScalarT: + new = self.ndarray[indices] + new_dims = [] + new_ranges = [] + + for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... + self.domain, _expand_ellipsis(indices, len(self.domain)), fillvalue=slice(None) + ): + if isinstance(idx, slice): + new_dims.append(dim) + new_ranges.append(_slice_range(rng, idx)) + else: + assert isinstance(idx, int) # not in new_domain + + new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + + if len(new_domain) == 0: + assert core_defs.is_scalar_type(new), new + return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here + else: + return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) + # -- Specialized implementations for intrinsic operations on array fields -- @@ -234,7 +307,6 @@ class NumPyArrayField(_BaseNdArrayField): common.field.register(np.ndarray, NumPyArrayField.from_array) - # CuPy if cp: _nd_array_implementations.append(cp) @@ -254,3 +326,129 @@ class JaxArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = jnp common.field.register(jnp.ndarray, JaxArrayField.from_array) + + +def _find_index_of_dim( + dim: common.Dimension, + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], +) -> Optional[int]: + for i, (d, _) in enumerate(domain_slice): + if dim == d: + return i + return None + + +def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: + domain_slice: list[slice | None] = [] + new_domain_dims = [] + new_domain_ranges = [] + for dim in new_dimensions: + if (pos := _find_index_of_dim(dim, field.domain)) is not None: + domain_slice.append(slice(None)) + new_domain_dims.append(dim) + new_domain_ranges.append(field.domain[pos][1]) + else: + domain_slice.append(np.newaxis) + new_domain_dims.append(dim) + new_domain_ranges.append( + common.UnitRange(common.Infinity.negative(), common.Infinity.positive()) + ) + return common.field( + field.ndarray[tuple(domain_slice)], + domain=common.Domain(tuple(new_domain_dims), tuple(new_domain_ranges)), + ) + + +def _builtins_broadcast( + field: common.Field | core_defs.Scalar, new_dimensions: tuple[common.Dimension, ...] +) -> common.Field: # separated for typing reasons + if common.is_field(field): + return _broadcast(field, new_dimensions) + raise AssertionError("Scalar case not reachable from `fbuiltins.broadcast`.") + + +_BaseNdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) + + +def _get_slices_from_domain_slice( + domain: common.Domain, + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], +) -> tuple[slice | int | None, ...]: + """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. + + This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided + named ranges or indices specify the dimensions and ranges of the sub-arrays to be extracted. + + Args: + domain (common.Domain): The Domain object representing the original field. + domain_slice (DomainSlice): A sequence of dimension names and associated ranges. + + Returns: + tuple[slice | int | None, ...]: A tuple of slices representing the sub-array extraction along each dimension + specified in the Domain. If a dimension is not included in the named indices + or ranges, a None is used to indicate expansion along that axis. + """ + slice_indices: list[slice | int | None] = [] + + for pos_old, (dim, _) in enumerate(domain): + if (pos := _find_index_of_dim(dim, domain_slice)) is not None: + index_or_range = domain_slice[pos][1] + slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) + else: + slice_indices.append(slice(None)) + return tuple(slice_indices) + + +def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> slice | int: + """Compute a slice or integer based on the provided range, domain, and position. + + Args: + rng (DomainRange): The range to be computed as a slice or integer. + domain (common.Domain): The domain containing dimension information. + pos (int): The position of the dimension in the domain. + + Returns: + slice | int: Slice if `new_rng` is a UnitRange, otherwise an integer. + + Raises: + ValueError: If `new_rng` is not an integer or a UnitRange. + """ + if isinstance(rng, common.UnitRange): + if domain.ranges[pos] == common.UnitRange.infinity(): + return slice(None) + else: + return slice( + rng.start - domain.ranges[pos].start, + rng.stop - domain.ranges[pos].start, + ) + elif isinstance(rng, int): + return rng - domain.ranges[pos].start + else: + raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") + + +def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: + # handle slice(None) case + if slice_obj == slice(None): + return common.UnitRange(input_range.start, input_range.stop) + + start = ( + input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop + ) + (slice_obj.start or 0) + stop = ( + input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop + ) + (slice_obj.stop or len(input_range)) + + return common.UnitRange(start, stop) + + +def _expand_ellipsis( + indices: tuple[int | slice | EllipsisType, ...], target_size: int +) -> tuple[int | slice, ...]: + expanded_indices: list[int | slice] = [] + for idx in indices: + if idx is Ellipsis: + expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) + else: + expanded_indices.append(idx) + return tuple(expanded_indices) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 2129bd6dcd..d1d403c407 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -65,7 +65,7 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ( ts.FunctionType ) # our type of type is currently represented by the type constructor function - elif t is Tuple: + elif t is Tuple or (hasattr(t, "__origin__") and t.__origin__ is tuple): return ts.TupleType elif hasattr(t, "__origin__") and t.__origin__ is Union: types = [_type_conversion_helper(e) for e in t.__args__] # type: ignore[attr-defined] @@ -161,7 +161,7 @@ def min_over( @builtin_function -def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple, /) -> Field: +def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple[Dimension, ...], /) -> Field: raise NotImplementedError() diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 6df39654cf..a2aa3112bd 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -11,7 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - +import dataclasses import itertools import math import operator @@ -20,13 +20,20 @@ import numpy as np import pytest -from gt4py.next import common +from gt4py.next import Dimension, common +from gt4py.next.common import Domain, UnitRange from gt4py.next.embedded import nd_array_field +from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice, _slice_range from gt4py.next.ffront import fbuiltins from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data +IDim = Dimension("IDim") +JDim = Dimension("JDim") +KDim = Dimension("KDim") + + @pytest.fixture(params=nd_array_field._nd_array_implementations) def nd_array_implementation(request): yield request.param @@ -79,6 +86,30 @@ def test_binary_ops(binary_op, nd_array_implementation): assert np.allclose(result.ndarray, expected) +@pytest.mark.parametrize( + "dims,expected_indices", + [ + ((IDim,), (slice(5, 10), None)), + ((JDim,), (None, slice(5, 10))), + ], +) +def test_binary_operations_with_intersection(binary_op, dims, expected_indices): + arr1 = np.arange(10) + arr1_domain = common.Domain(dims=dims, ranges=(UnitRange(0, 10),)) + + arr2 = np.ones((5, 5)) + arr2_domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 10), UnitRange(5, 10))) + + field1 = common.field(arr1, domain=arr1_domain) + field2 = common.field(arr2, domain=arr2_domain) + + op_result = binary_op(field1, field2) + expected_result = binary_op(arr1[expected_indices[0], expected_indices[1]], arr2) + + assert op_result.ndarray.shape == (5, 5) + assert np.allclose(op_result.ndarray, expected_result) + + @pytest.fixture( params=itertools.product( nd_array_field._nd_array_implementations, nd_array_field._nd_array_implementations @@ -125,3 +156,257 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: result = fma(field_inp_a, field_inp_b, field_inp_c) assert np.allclose(result.ndarray, expected) + + +@pytest.mark.parametrize( + "new_dims,field,expected_domain", + [ + ( + ( + (IDim,), + common.field( + np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + ), + Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)), + ) + ), + ( + ( + (IDim, JDim), + common.field( + np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + ), + Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinity())), + ) + ), + ( + ( + (IDim, JDim), + common.field( + np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) + ), + Domain(dims=(IDim, JDim), ranges=(UnitRange.infinity(), UnitRange(0, 10))), + ) + ), + ( + ( + (IDim, JDim, KDim), + common.field( + np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) + ), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange.infinity(), UnitRange(0, 10), UnitRange.infinity()), + ), + ) + ), + ], +) +def test_field_broadcast(new_dims, field, expected_domain): + result = fbuiltins.broadcast(field, new_dims) + assert result.domain == expected_domain + + +@pytest.mark.parametrize( + "domain_slice", + [ + ((IDim, UnitRange(0, 10)),), + common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)), + ], +) +def test_get_slices_with_named_indices_3d_to_1d(domain_slice): + field_domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + ) + slices = _get_slices_from_domain_slice(field_domain, domain_slice) + assert slices == (slice(0, 10, None), slice(None), slice(None)) + + +def test_get_slices_with_named_index(): + field_domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + ) + named_index = ((IDim, UnitRange(0, 10)), (JDim, 2), (KDim, 3)) + slices = _get_slices_from_domain_slice(field_domain, named_index) + assert slices == (slice(0, 10, None), 2, 3) + + +def test_get_slices_invalid_type(): + field_domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + ) + new_domain = ((IDim, "1"),) + with pytest.raises(ValueError): + _get_slices_from_domain_slice(field_domain, new_domain) + + +@pytest.mark.parametrize( + "domain_slice,expected_dimensions,expected_shape", + [ + ( + ( + (IDim, UnitRange(7, 9)), + (JDim, UnitRange(8, 10)), + ), + (IDim, JDim, KDim), + (2, 2, 15), + ), + ( + ( + (IDim, UnitRange(7, 9)), + (KDim, UnitRange(12, 20)), + ), + (IDim, JDim, KDim), + (2, 10, 8), + ), + (common.Domain(dims=(IDim,), ranges=(UnitRange(7, 9),)), (IDim, JDim, KDim), (2, 10, 15)), + (((IDim, 8),), (JDim, KDim), (10, 15)), + (((JDim, 9),), (IDim, KDim), (5, 15)), + (((KDim, 11),), (IDim, JDim), (5, 10)), + ( + ( + (IDim, 8), + (JDim, UnitRange(8, 10)), + ), + (JDim, KDim), + (2, 15), + ), + ((IDim, 1), (JDim, KDim), (10, 15)), + ((IDim, UnitRange(5, 7)), (IDim, JDim, KDim), (2, 10, 15)), + ], +) +def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common.field(np.ones((5, 10, 15)), domain=domain) + indexed_field = field[domain_slice] + + assert isinstance(indexed_field, common.Field) + assert indexed_field.ndarray.shape == expected_shape + assert indexed_field.domain.dims == expected_dimensions + + +def test_absolute_indexing_value_return(): + domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) + field = common.field(np.ones((10, 10), dtype=np.int32), domain=domain) + + named_index = ((IDim, 2), (JDim, 4)) + value = field[named_index] + + assert isinstance(value, np.int32) + assert value == 1 + + +@pytest.mark.parametrize( + "index, expected_shape, expected_domain", + [ + ( + (slice(None, 5), slice(None, 2)), + (5, 2), + Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 4))), + ), + ((slice(None, 5),), (5, 10), Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 12)))), + ((Ellipsis, 1), (10,), Domain((IDim,), (UnitRange(5, 15),))), + ( + (slice(2, 3), slice(5, 7)), + (1, 2), + Domain((IDim, JDim), (UnitRange(7, 8), UnitRange(7, 9))), + ), + ( + (slice(1, 2), 0), + (1,), + Domain((IDim,), (UnitRange(6, 7),)), + ), + ], +) +def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): + domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) + field = common.field(np.ones((10, 10)), domain=domain) + indexed_field = field[index] + + assert isinstance(indexed_field, common.Field) + assert indexed_field.ndarray.shape == expected_shape + assert indexed_field.domain == expected_domain + + +@pytest.mark.parametrize( + "index, expected_shape, expected_domain", + [ + ((1, slice(None), 2), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ( + (slice(None), slice(None), 2), + (10, 15), + Domain((IDim, JDim), (UnitRange(5, 15), UnitRange(10, 25))), + ), + ( + (slice(None),), + (10, 15, 10), + Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + ), + ( + (slice(None), slice(None), slice(None)), + (10, 15, 10), + Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + ), + ( + (slice(None)), + (10, 15, 10), + Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + ), + ((0, Ellipsis, 0), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ( + Ellipsis, + (10, 15, 10), + Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + ), + ], +) +def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) + ) + field = common.field(np.ones((10, 15, 10)), domain=domain) + indexed_field = field[index] + + assert isinstance(indexed_field, common.Field) + assert indexed_field.ndarray.shape == expected_shape + assert indexed_field.domain == expected_domain + + +@pytest.mark.parametrize( + "index, expected_value", + [((1, 0), 10), ((0, 1), 1)], +) +def test_relative_indexing_value_return(index, expected_value): + domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) + field = common.field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) + indexed_field = field[index] + + assert indexed_field == expected_value + + +@pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]]) +def test_relative_indexing_out_of_bounds(lazy_slice): + domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) + field = common.field(np.ones((10, 10)), domain=domain) + + with pytest.raises(IndexError): + lazy_slice(field) + + +@pytest.mark.parametrize("index", [IDim, "1", (IDim, JDim)]) +def test_field_unsupported_index(index): + domain = common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + field = common.field(np.ones((10,)), domain=domain) + with pytest.raises(IndexError, match="Unsupported index type"): + field[index] + + +def test_slice_range(): + input_range = UnitRange(2, 10) + slice_obj = slice(2, -2) + expected = UnitRange(4, 8) + + result = _slice_range(input_range, slice_obj) + assert result == expected diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 75c793f914..8cdc96254c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -252,6 +252,16 @@ def test_domain_repeat_dims(): Domain(dims, ranges) +def test_domain_dims_ranges_length_mismatch(): + with pytest.raises( + ValueError, + match=r"Number of provided dimensions \(\d+\) does not match number of provided ranges \(\d+\)", + ): + dims = [Dimension("X"), Dimension("Y"), Dimension("Z")] + ranges = [UnitRange(0, 1), UnitRange(0, 1)] + Domain(dims=dims, ranges=ranges) + + def dimension_promotion_cases() -> ( list[tuple[list[list[Dimension]], list[Dimension] | None, None | Pattern]] ): From 3497cb8270689246e6b00f5214033ae9e0813de0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 23 Aug 2023 10:30:19 +0200 Subject: [PATCH 04/67] feature[next]: astype for scalars (#1326) Add support for casting scalars, whereas previously only fields were supported. Additionally arguments are counted from one instead of zero. --- src/gt4py/next/ffront/fbuiltins.py | 2 +- .../ffront/foast_passes/type_deduction.py | 67 +++++++++++-------- src/gt4py/next/type_system/type_info.py | 31 ++++++++- .../ffront_tests/test_execution.py | 13 +++- .../ffront_tests/test_type_deduction.py | 45 ++++++++++--- 5 files changed, 115 insertions(+), 43 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index d1d403c407..b6831b35df 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -176,7 +176,7 @@ def where( @builtin_function -def astype(field: Field, type_: type, /) -> Field: +def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index bd7eddbcdd..605b83a5f0 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -28,31 +28,32 @@ from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -def boolified_type(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.FieldType: +def with_altered_scalar_kind( + type_spec: ts.TypeSpec, new_scalar_kind: ts.ScalarKind +) -> ts.ScalarType | ts.FieldType: """ - Create a new symbol type from a symbol type, replacing the data type with ``bool``. + Given a scalar or field type, return a type with different scalar kind. Examples: --------- >>> from gt4py.next import Dimension >>> scalar_t = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - >>> print(boolified_type(scalar_t)) + >>> print(with_altered_scalar_kind(scalar_t, ts.ScalarKind.BOOL)) bool - >>> field_t = ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)) - >>> print(boolified_type(field_t)) - Field[[I], bool] + >>> field_t = ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + >>> print(with_altered_scalar_kind(field_t, ts.ScalarKind.FLOAT32)) + Field[[I], float32] """ - shape = None - if type_info.is_concrete(symbol_type): - shape = type_info.extract_dtype(symbol_type).shape - scalar_bool = ts.ScalarType(kind=ts.ScalarKind.BOOL, shape=shape) - type_class = type_info.type_class(symbol_type) - if type_class is ts.ScalarType: - return scalar_bool - elif type_class is ts.FieldType: - return ts.FieldType(dtype=scalar_bool, dims=type_info.extract_dims(symbol_type)) - raise ValueError(f"Can not boolify type {symbol_type}!") + if isinstance(type_spec, ts.FieldType): + return ts.FieldType( + dims=type_spec.dims, + dtype=ts.ScalarType(kind=new_scalar_kind, shape=type_spec.dtype.shape), + ) + elif isinstance(type_spec, ts.ScalarType): + return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) + else: + raise ValueError(f"Expected field or scalar type, but got {type_spec}.") def construct_tuple_type( @@ -563,7 +564,10 @@ def _deduce_compare_type( try: # transform operands to have bool dtype and use regular promotion # mechanism to handle dimension promotion - return type_info.promote(boolified_type(left.type), boolified_type(right.type)) + return type_info.promote( + with_altered_scalar_kind(left.type, ts.ScalarKind.BOOL), + with_altered_scalar_kind(right.type, ts.ScalarKind.BOOL), + ) except ValueError as ex: raise errors.DSLError( node.location, @@ -762,7 +766,9 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: ): return_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) elif func_name in fbuiltins.UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES: - return_type = boolified_type(cast(ts.FieldType | ts.ScalarType, node.args[0].type)) + return_type = with_altered_scalar_kind( + cast(ts.FieldType | ts.ScalarType, node.args[0].type), ts.ScalarKind.BOOL + ) elif func_name in fbuiltins.BINARY_MATH_NUMBER_BUILTIN_NAMES: try: return_type = type_info.promote( @@ -817,19 +823,22 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: - casted_obj_type = node.args[0].type - dtype_obj = node.args[1] - assert isinstance(dtype_obj, foast.Name) - dtype_obj_type = dtype_obj.type - assert isinstance(dtype_obj_type, ts.FunctionType) - assert dtype_obj.id in fbuiltins.TYPE_BUILTIN_NAMES - assert isinstance(casted_obj_type, ts.FieldType) - assert type_info.is_arithmetic(casted_obj_type) or type_info.is_logical(casted_obj_type) + value, new_type = node.args + assert isinstance( + value.type, (ts.FieldType, ts.ScalarType) + ) # already checked using generic mechanism + if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [ + kind.name for kind in ts.ScalarKind + ]: + raise errors.DSLError( + node.location, + f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", + ) - return_type = ts.FieldType( - dims=casted_obj_type.dims, - dtype=self.visit(dtype_obj_type).returns, + return_type = with_altered_scalar_kind( + value.type, getattr(ts.ScalarKind, new_type.id.upper()) ) + return foast.Call( func=node.func, args=node.args, diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5a22824740..564df7fd1a 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -23,6 +23,31 @@ from gt4py.next.type_system import type_specifications as ts +def _number_to_ordinal_number(number: int) -> str: + """ + Convert number into ordinal number. + + >>> for i in range(0, 5): + ... print(_number_to_ordinal_number(i)) + 0th + 1st + 2nd + 3rd + 4th + """ + number_as_string = str(number) + if len(number_as_string) > 1 and number_as_string[-2] == "1": + return number_as_string + "th" + last_digit = number_as_string[-1] + if last_digit == "1": + return number_as_string + "st" + if last_digit == "2": + return number_as_string + "nd" + if last_digit == "3": + return number_as_string + "rd" + return number_as_string + "th" + + def is_concrete(symbol_type: ts.TypeSpec) -> TypeGuard[ts.TypeSpec]: """Figure out if the foast type is completely deduced.""" if isinstance(symbol_type, ts.DeferredType): @@ -612,7 +637,7 @@ def function_signature_incompatibilities_func( # noqa: C901 and not is_concretizable(a_arg, to_type=b_arg) ): if i < len(func_type.pos_only_args): - arg_repr = f"{i}-th argument" + arg_repr = f"{_number_to_ordinal_number(i+1)} argument" else: arg_repr = f"argument `{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}`" yield f"Expected {arg_repr} to be of type `{a_arg}`, but got `{b_arg}`." @@ -631,11 +656,11 @@ def function_signature_incompatibilities_field( kwargs: dict[str, ts.TypeSpec], ) -> Iterator[str]: if len(args) != 1: - yield f"Function takes 1 argument(s), but {len(args)} were given." + yield f"Function takes 1 argument, but {len(args)} were given." return if not isinstance(args[0], ts.OffsetType): - yield f"Expected 0-th argument to be of type {ts.OffsetType}, but got {args[0]}." + yield f"Expected first argument to be of type {ts.OffsetType}, but got {args[0]}." return if kwargs: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8e4483fab7..0268411c71 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -339,7 +339,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) -def test_astype_bool(cartesian_case): # noqa: F811 # fixtures +def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: b = astype(a, bool) @@ -353,6 +353,17 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: ) +@pytest.mark.parametrize("inp", [0.0, 2.0]) +def test_astype_bool_scalar(cartesian_case, inp): # noqa: F811 # fixtures + @gtx.field_operator + def testee(inp: float) -> gtx.Field[[IDim], bool]: + return broadcast(astype(inp, bool), (IDim,)) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, inp, out=out, ref=bool(inp)) + + def test_astype_float(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], np.float32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 70b24ace21..7800a30e41 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -11,7 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - +import re from typing import Optional, Pattern import pytest @@ -180,7 +180,7 @@ def callable_type_info_cases(): unary_func_type, [float_type], {}, - [r"Expected 0-th argument to be of type `bool`, but got `float64`."], + [r"Expected 1st argument to be of type `bool`, but got `float64`."], None, ), ( @@ -274,7 +274,7 @@ def callable_type_info_cases(): [int_type], {"bar": bool_type, "foo": bool_type}, [ - r"Expected 0-th argument to be of type `bool`, but got `int64`", + r"Expected 1st argument to be of type `bool`, but got `int64`", r"Expected argument `foo` to be of type `int64`, but got `bool`", r"Expected keyword argument `bar` to be of type `float64`, but got `bool`", ], @@ -299,7 +299,7 @@ def callable_type_info_cases(): [ts.TupleType(types=[float_type, field_type])], {}, [ - "Expected 0-th argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" + "Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" ], ts.VoidType(), ), @@ -308,7 +308,7 @@ def callable_type_info_cases(): [int_type], {}, [ - "Expected 0-th argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" + "Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" ], ts.VoidType(), ), @@ -761,15 +761,42 @@ def tuple_where_mix_dims( def test_astype_dtype(): - ADim = Dimension("ADim") - - def simple_astype(a: Field[[ADim], float64]): + def simple_astype(a: Field[[TDim], float64]): return astype(a, bool) parsed = FieldOperatorParser.apply_to_function(simple_astype) assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) + ) + + +def test_astype_wrong_dtype(): + def simple_astype(a: Field[[TDim], float64]): + # we just use broadcast here, but anything with type function is fine + return astype(a, broadcast) + + with pytest.raises( + errors.DSLError, + match=r"Invalid call to `astype`. Second argument must be a scalar type, but got.", + ): + _ = FieldOperatorParser.apply_to_function(simple_astype) + + +def test_astype_wrong_value_type(): + def simple_astype(a: Field[[TDim], float64]): + # we just use a tuple here but anything that is not a field or scalar works + return astype((1, 2), bool) + + with pytest.raises(errors.DSLError) as exc_info: + _ = FieldOperatorParser.apply_to_function(simple_astype) + + assert ( + re.search( + "Expected 1st argument to be of type", + exc_info.value.__cause__.args[0], + ) + is not None ) From 90c4121472cba20c86b8311bc11dac7fd48b45ac Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Wed, 23 Aug 2023 21:58:06 +0200 Subject: [PATCH 05/67] Program with bound arguments (#1312) Allow for program signature to not include all required fields by the `field_operator` by setting keyword arguments ```python @gtx.program def program_bound_args(a: cases.IField, scalar: int32, bool_val: bool, out: cases.IField): fieldop_bound_args(a, scalar, bool_val, out=out) prog_bounds = program_bound_args.with_bound_args(scalar=scalar, bool_val=True) cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) ``` --- src/gt4py/next/ffront/decorator.py | 111 ++++++++++++++++++ .../iterator/transforms/constant_folding.py | 49 ++++++++ .../next/iterator/transforms/pass_manager.py | 45 ++++--- .../ffront_tests/test_arg_call_interface.py | 56 ++++++++- .../ffront_tests/test_execution.py | 20 ++++ .../transforms_tests/test_constant_folding.py | 47 ++++++++ 6 files changed, 310 insertions(+), 18 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/constant_folding.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 5b4d32b59e..12ab3955ab 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -52,6 +52,7 @@ from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_makers import literal_from_value, promote_to_const_iterator, ref, sym from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -219,6 +220,40 @@ def with_backend(self, backend: ppi.ProgramExecutor) -> Program: def with_grid_type(self, grid_type: GridType) -> Program: return dataclasses.replace(self, grid_type=grid_type) + def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: + """ + Bind scalar, i.e. non field, program arguments. + + Example (pseudo-code): + + >>> import gt4py.next as gtx + >>> @gtx.program # doctest: +SKIP + ... def program(condition: bool, out: gtx.Field[[IDim], float]): # noqa: F821 + ... sample_field_operator(condition, out=out) # noqa: F821 + + Create a new program from `program` with the `condition` parameter set to `True`: + + >>> program_with_bound_arg = program.with_bound_args(condition=True) # doctest: +SKIP + + The resulting program is equivalent to + + >>> @gtx.program # doctest: +SKIP + ... def program(condition: bool, out: gtx.Field[[IDim], float]): # noqa: F821 + ... sample_field_operator(condition=True, out=out) # noqa: F821 + + and can be executed without passing `condition`. + + >>> program_with_bound_arg(out, offset_provider={}) # doctest: +SKIP + """ + for key in kwargs.keys(): + if all(key != param.id for param in self.past_node.params): + raise TypeError(f"Keyword argument `{key}` is not a valid program parameter.") + + return ProgramWithBoundArgs( + bound_args=kwargs, + **{field.name: getattr(self, field.name) for field in dataclasses.fields(self)}, + ) + @functools.cached_property def _all_closure_vars(self) -> dict[str, Any]: return _get_closure_vars_recursively(self.closure_vars) @@ -358,6 +393,82 @@ def _column_axis(self): return iter(scanops_per_axis.keys()).__next__() +@dataclasses.dataclass(frozen=True) +class ProgramWithBoundArgs(Program): + bound_args: dict[str, typing.Union[float, int, bool]] = None + + def _process_args(self, args: tuple, kwargs: dict): + type_ = self.past_node.type + new_type = ts_ffront.ProgramType( + definition=ts.FunctionType( + kw_only_args={ + k: v + for k, v in type_.definition.kw_only_args.items() + if k not in self.bound_args.keys() + }, + pos_only_args=type_.definition.pos_only_args, + pos_or_kw_args={ + k: v + for k, v in type_.definition.pos_or_kw_args.items() + if k not in self.bound_args.keys() + }, + returns=type_.definition.returns, + ) + ) + + arg_types = [type_translation.from_value(arg) for arg in args] + kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} + + try: + # This error is also catched using `accepts_args`, but we do it manually here to give + # a better error message. + for name in self.bound_args.keys(): + if name in kwargs: + raise ValueError(f"Parameter `{name}` already set as a bound argument.") + + type_info.accepts_args( + new_type, + with_args=arg_types, + with_kwargs=kwarg_types, + raise_exception=True, + ) + except ValueError as err: + bound_arg_names = ", ".join([f"`{bound_arg}`" for bound_arg in self.bound_args.keys()]) + raise TypeError( + f"Invalid argument types in call to program `{self.past_node.id}` with " + f"bound arguments {bound_arg_names}!" + ) from err + + full_args = [*args] + for index, param in enumerate(self.past_node.params): + if param.id in self.bound_args.keys(): + full_args.insert(index, self.bound_args[param.id]) + + return super()._process_args(tuple(full_args), kwargs) + + @functools.cached_property + def itir(self): + new_itir = super().itir + for new_clos in new_itir.closures: + for key in self.bound_args.keys(): + index = next( + index + for index, closure_input in enumerate(new_clos.inputs) + if closure_input.id == key + ) + new_clos.inputs.pop(index) + new_args = [ref(inp.id) for inp in new_clos.inputs] + params = [sym(inp.id) for inp in new_clos.inputs] + for value in self.bound_args.values(): + new_args.append(promote_to_const_iterator(literal_from_value(value))) + expr = itir.FunCall( + fun=new_clos.stencil, + args=new_args, + ) + new_clos.stencil = itir.Lambda(params=params, expr=expr) + return new_itir + + @typing.overload def program(definition: types.FunctionType) -> Program: ... diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py new file mode 100644 index 0000000000..cda422f30d --- /dev/null +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -0,0 +1,49 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.eve import NodeTranslator +from gt4py.next.iterator import embedded, ir, ir_makers as im + + +class ConstantFolding(NodeTranslator): + @classmethod + def apply(cls, node: ir.Node) -> ir.Node: + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall): + # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded + new_node = self.generic_visit(node) + + if ( + isinstance(new_node.fun, ir.SymRef) + and new_node.fun.id == "if_" + and isinstance(new_node.args[0], ir.Literal) + ): # `if_(True, true_branch, false_branch)` -> `true_branch` + if new_node.args[0].value == "True": + new_node = new_node.args[1] + else: + new_node = new_node.args[2] + + if ( + isinstance(new_node, ir.FunCall) + and isinstance(new_node.fun, ir.SymRef) + and len(new_node.args) > 0 + and all(isinstance(arg, ir.Literal) for arg in new_node.args) + ): # `1 + 1` -> `2` + if new_node.fun.id in ir.ARITHMETIC_BUILTINS: + fun = getattr(embedded, str(new_node.fun.id)) + arg_values = [getattr(embedded, str(arg.type))(arg.value) for arg in new_node.args] # type: ignore[attr-defined] # arg type already established in if condition + new_node = im.literal_from_value(fun(*arg_values)) + + return new_node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 50786bacd7..62251a3e43 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -18,6 +18,7 @@ from gt4py.next.iterator.transforms import simple_inline_heuristic from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps @@ -79,26 +80,36 @@ def apply_common_transforms( ir = PruneUnreferencedFundefs().visit(ir) ir = PropagateDeref.apply(ir) ir = NormalizeShifts().visit(ir) - if lift_mode != LiftMode.FORCE_TEMPORARIES: - for _ in range(10): - inlined = _inline_lifts(ir, lift_mode) - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift=(lift_mode == LiftMode.FORCE_INLINE), - ) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError("Inlining lift and lambdas did not converge.") - else: - ir = InlineLambdas.apply( - ir, opcount_preserving=True, force_inline_lift=(lift_mode == LiftMode.FORCE_INLINE) + + for _ in range(10): + inlined = ir + + if lift_mode != LiftMode.FORCE_TEMPORARIES: + inlined = _inline_lifts(inlined, lift_mode) + + inlined = InlineLambdas.apply( + inlined, + opcount_preserving=True, + force_inline_lift=(lift_mode == LiftMode.FORCE_INLINE), ) + inlined = ConstantFolding.apply(inlined) + # This pass is required to be in the loop such that when an `if_` call with tuple arguments + # is constant-folded the surrounding tuple_get calls can be removed. + inlined = CollapseTuple.apply(inlined) - if lift_mode == LiftMode.FORCE_INLINE: + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError("Inlining lift and lambdas did not converge.") + + # Since `CollapseTuple` relies on the type inference which does not support returning tuples + # larger than the number of closure outputs as given by the unconditional collapse, we can + # only run the unconditional version here instead of in the loop above. + if unconditionally_collapse_tuples: ir = CollapseTuple.apply(ir, ignore_tuple_size=unconditionally_collapse_tuples) + + if lift_mode == LiftMode.FORCE_INLINE: ir = _inline_into_scan(ir) ir = NormalizeShifts().visit(ir) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 78c1b64e1b..ade410ef23 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import itertools +import re import typing import numpy as np @@ -20,8 +21,9 @@ from gt4py.next import errors from gt4py.next.common import Field +from gt4py.next.errors.exceptions import TypeError_ from gt4py.next.ffront.decorator import field_operator, program, scan_operator -from gt4py.next.ffront.fbuiltins import int32, int64 +from gt4py.next.ffront.fbuiltins import broadcast, int32, int64 from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases @@ -271,3 +273,55 @@ def testee_scan( @program def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float): testee_scan(qc, param_1, param_2, scalar, out=(qc, param_1, param_2)) + + +@pytest.fixture +def bound_args_testee(): + @field_operator + def fieldop_bound_args() -> cases.IField: + return broadcast(0, (IDim,)) + + @program + def program_bound_args(arg1: bool, arg2: bool, out: cases.IField): + # for the test itself we don't care what happens here, but empty programs are not supported + fieldop_bound_args(out=out) + + return program_bound_args + + +def test_bind_invalid_arg(cartesian_case, bound_args_testee): + with pytest.raises( + TypeError, match="Keyword argument `inexistent_arg` is not a valid program parameter." + ): + bound_args_testee.with_bound_args(inexistent_arg=1) + + +def test_call_bound_program_with_wrong_args(cartesian_case, bound_args_testee): + program_with_bound_arg = bound_args_testee.with_bound_args(arg1=True) + out = cases.allocate(cartesian_case, bound_args_testee, "out")() + + with pytest.raises(TypeError) as exc_info: + program_with_bound_arg(out, offset_provider={}) + + assert ( + re.search( + "Function takes 2 positional arguments, but 1 were given.", + exc_info.value.__cause__.args[0], + ) + is not None + ) + + +def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_testee): + program_with_bound_arg = bound_args_testee.with_bound_args(arg2=True) + out = cases.allocate(cartesian_case, bound_args_testee, "out")() + + with pytest.raises(TypeError) as exc_info: + program_with_bound_arg(True, out, arg2=True, offset_provider={}) + + assert ( + re.search( + "Parameter `arg2` already set as a bound argument.", exc_info.value.__cause__.args[0] + ) + is not None + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 0268411c71..9f284f4041 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -741,6 +741,26 @@ def test_docstring(a: cases.IField): cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a) +def test_with_bound_args(cartesian_case): + @gtx.field_operator + def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField: + if not condition: + scalar = 0 + return a + a + scalar + + @gtx.program + def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField): + fieldop_bound_args(a, scalar, condition, out=out) + + a = cases.allocate(cartesian_case, program_bound_args, "a")() + scalar = int32(1) + ref = a.array() + a.array() + 1 + out = cases.allocate(cartesian_case, program_bound_args, "out")() + + prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) + cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) + + def test_domain(cartesian_case): @gtx.field_operator def fieldop_domain(a: cases.IField) -> cases.IField: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py new file mode 100644 index 0000000000..5d052b1989 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding + + +def test_constant_folding_boolean(): + testee = im.not_(im.literal_from_value(True)) + expected = im.literal_from_value(False) + + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_folding_math_op(): + expected = im.literal_from_value(13) + testee = im.plus( + im.literal_from_value(4), + im.plus( + im.literal_from_value(7), im.minus(im.literal_from_value(7), im.literal_from_value(5)) + ), + ) + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_folding_if(): + expected = im.call("plus")("a", 2) + testee = im.if_( + im.literal_from_value(True), + im.plus(im.ref("a"), im.literal_from_value(2)), + im.minus(im.literal_from_value(9), im.literal_from_value(5)), + ) + actual = ConstantFolding.apply(testee) + assert actual == expected From c8dbffdcc876b141fa47abb1db1d90b501d82dfb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 24 Aug 2023 08:55:02 +0200 Subject: [PATCH 06/67] feature[next]: Collect shifts for all nodes in TraceShift pass (#1321) - Shifts are collected not only for the closure inputs, but for every iterator expression in the tree (including iterator arguments to lambdas). - The collected shifts are now represented as a set. This PR is a prerequisite for the temporary extraction heuristics. --- .../next/iterator/transforms/trace_shifts.py | 111 ++++++++++++---- .../transforms_tests/test_trace_shifts.py | 122 ++++++++++++------ 2 files changed, 168 insertions(+), 65 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index ca4539c74f..b5697ca321 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -11,9 +11,9 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses import enum from collections.abc import Callable -from dataclasses import dataclass from typing import Any, Final, Iterable, Literal from gt4py.eve import NodeTranslator @@ -26,6 +26,30 @@ class Sentinel(enum.Enum): TYPE = object() +@dataclasses.dataclass(frozen=True) +class ShiftRecorder: + recorded_shifts: dict[int, set[tuple[ir.OffsetLiteral, ...]]] = dataclasses.field( + default_factory=dict + ) + + def register_node(self, inp: ir.Expr | ir.Sym) -> None: + self.recorded_shifts.setdefault(id(inp), set()) + + def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...]) -> None: + self.recorded_shifts[id(inp)].add(offsets) + + +@dataclasses.dataclass(frozen=True) +class ForwardingShiftRecorder: + wrapped_tracer: Any + shift_recorder: ShiftRecorder + + def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...]): + self.shift_recorder(inp, offsets) + # Forward shift to wrapped tracer such it can record the shifts of the parent nodes + self.wrapped_tracer.shift(offsets).deref() + + # for performance reasons (`isinstance` is slow otherwise) we don't use abc here class IteratorTracer: def deref(self): @@ -35,29 +59,27 @@ def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): raise NotImplementedError() -@dataclass(frozen=True) -class InputTracer(IteratorTracer): - inp: str - register_deref: Callable[[str, tuple[ir.OffsetLiteral, ...]], None] +@dataclasses.dataclass(frozen=True) +class IteratorArgTracer(IteratorTracer): + arg: ir.Expr | ir.Sym + shift_recorder: ShiftRecorder | ForwardingShiftRecorder offsets: tuple[ir.OffsetLiteral, ...] = () - lift_level: int = 0 def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): - return InputTracer( - inp=self.inp, - register_deref=self.register_deref, + return IteratorArgTracer( + arg=self.arg, + shift_recorder=self.shift_recorder, offsets=self.offsets + tuple(offsets), - lift_level=self.lift_level, ) def deref(self): - self.register_deref(self.inp, self.offsets) + self.shift_recorder(self.arg, self.offsets) return Sentinel.VALUE # This class is only needed because we currently allow conditionals on iterators. Since this is # not supported in the C++ backend it can likely be removed again in the future. -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class CombinedTracer(IteratorTracer): its: tuple[IteratorTracer, ...] @@ -98,13 +120,13 @@ def apply(arg): return apply -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class AppliedLift(IteratorTracer): stencil: Callable its: tuple[IteratorTracer, ...] def shift(self, offsets): - return AppliedLift(self.stencil, tuple(_shift(it) for it in self.its)) + return AppliedLift(self.stencil, tuple(_shift(*offsets)(it) for it in self.its)) def deref(self): return self.stencil(*self.its) @@ -211,7 +233,10 @@ def _tuple_get(index, tuple_val): } +@dataclasses.dataclass(frozen=True) class TraceShifts(NodeTranslator): + shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) + def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: return Sentinel.VALUE @@ -232,30 +257,62 @@ def visit_FunCall(self, node: ir.FunCall, *, ctx: dict[str, Any]) -> Any: args = self.visit(node.args, ctx=ctx) return fun(*args) + def visit(self, node, **kwargs): + result = super().visit(node, **kwargs) + if isinstance(result, IteratorTracer): + assert isinstance(node, (ir.Sym, ir.Expr)) + + self.shift_recorder.register_node(node) + result = IteratorArgTracer( + arg=node, shift_recorder=ForwardingShiftRecorder(result, self.shift_recorder) + ) + return result + def visit_Lambda(self, node: ir.Lambda, *, ctx: dict[str, Any]) -> Callable: def fun(*args): + new_args = [] + for param, arg in zip(node.params, args, strict=True): + if isinstance(arg, IteratorTracer): + self.shift_recorder.register_node(param) + new_args.append( + IteratorArgTracer( + arg=param, + shift_recorder=ForwardingShiftRecorder(arg, self.shift_recorder), + ) + ) + else: + new_args.append(arg) + return self.visit( - node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, args, strict=True)} + node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, new_args, strict=True)} ) return fun - def visit_StencilClosure( - self, node: ir.StencilClosure, *, shifts: dict[str, list[tuple[ir.OffsetLiteral, ...]]] - ): - def register_deref(inp: str, offsets: tuple[ir.OffsetLiteral, ...]): - shifts[inp].append(offsets) - + def visit_StencilClosure(self, node: ir.StencilClosure): tracers = [] for inp in node.inputs: - shifts.setdefault(inp.id, []) - tracers.append(InputTracer(inp=inp.id, register_deref=register_deref)) + self.shift_recorder.register_node(inp) + tracers.append(IteratorArgTracer(arg=inp, shift_recorder=self.shift_recorder)) result = self.visit(node.stencil, ctx=_START_CTX)(*tracers) assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) @classmethod - def apply(cls, node: ir.StencilClosure) -> dict[str, list[tuple[ir.OffsetLiteral, ...]]]: - shifts = dict[str, list[tuple[ir.OffsetLiteral, ...]]]() - cls().visit(node, shifts=shifts) - return shifts + def apply( + cls, node: ir.StencilClosure, *, inputs_only=True + ) -> ( + dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]] + ): + instance = cls() + instance.visit(node) + + recorded_shifts = instance.shift_recorder.recorded_shifts + + if inputs_only: + inputs_shifts = {} + for inp in node.inputs: + inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)] + return inputs_shifts + + return recorded_shifts diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index a15a7a29cd..0e2fa22f05 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -23,10 +23,9 @@ def test_trivial(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"inp": [()]} + expected = {"inp": {()}} - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -51,10 +50,9 @@ def test_shift(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"inp": [(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))]} + expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -84,10 +82,9 @@ def test_lift(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"inp": [(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))]} + expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -105,16 +102,15 @@ def test_neighbors(): domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) expected = { - "inp": [ + "inp": { ( ir.OffsetLiteral(value="O"), ALL_NEIGHBORS, ) - ] + } } - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -134,10 +130,9 @@ def test_reduce(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"inp": [()]} + expected = {"inp": {()}} - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -150,10 +145,9 @@ def test_shifted_literal(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"inp": []} + expected = {"inp": set()} - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -165,13 +159,52 @@ def test_tuple_get(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"inp1": [], "inp2": [()]} # never derefed # once derefed + expected = {"inp1": set(), "inp2": {()}} # never derefed # once derefed - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected +def test_trace_non_closure_input_arg(): + x, y = im.sym("x"), im.sym("y") + testee = ir.StencilClosure( + # λ(x) → (λ(y) → ·⟪Iₒ, 1ₒ⟫(y))(⟪Iₒ, 2ₒ⟫(x)) + stencil=im.lambda_(x)( + im.call(im.lambda_(y)(im.deref(im.shift("I", 1)("y"))))(im.shift("I", 2)("x")) + ), + inputs=[ir.SymRef(id="inp")], + output=ir.SymRef(id="out"), + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + ) + + actual = TraceShifts.apply(testee, inputs_only=False) + + assert actual[id(x)] == { + ( + ir.OffsetLiteral(value="I"), + ir.OffsetLiteral(value=2), + ir.OffsetLiteral(value="I"), + ir.OffsetLiteral(value=1), + ) + } + assert actual[id(y)] == {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} + + +def test_inner_iterator(): + inner_shift = im.shift("I", 1)("x") + testee = ir.StencilClosure( + # λ(x) → ·⟪Iₒ, 1ₒ⟫(⟪Iₒ, 1ₒ⟫(x)) + stencil=im.lambda_("x")(im.deref(im.shift("I", 1)(inner_shift))), + inputs=[ir.SymRef(id="inp")], + output=ir.SymRef(id="out"), + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + ) + expected = {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} + + actual = TraceShifts.apply(testee, inputs_only=False) + assert actual[id(inner_shift)] == expected + + def test_tuple_get_on_closure_input(): testee = ir.StencilClosure( # λ(x) → (·⟪Iₒ, 1ₒ⟫(x))[0] @@ -180,10 +213,9 @@ def test_tuple_get_on_closure_input(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"inp": [(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))]} + expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -204,10 +236,9 @@ def test_if_tuple_branch_broadcasting(): output=ir.SymRef(id="out"), domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) - expected = {"cond": [()], "inp": [()]} + expected = {"cond": {()}, "inp": {()}} - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -226,8 +257,8 @@ def test_if_of_iterators(): domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) expected = { - "cond": [()], - "inp": [ + "cond": {()}, + "inp": { ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -240,11 +271,10 @@ def test_if_of_iterators(): ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1), ), - ], + }, } - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected @@ -271,8 +301,8 @@ def test_if_of_tuples_of_iterators(): domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), ) expected = { - "cond": [()], - "inp": [ + "cond": {()}, + "inp": { ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -285,9 +315,25 @@ def test_if_of_tuples_of_iterators(): ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1), ), - ], + }, } - actual = dict() - TraceShifts().visit(testee, shifts=actual) + actual = TraceShifts.apply(testee) assert actual == expected + + +def test_non_derefed_iterator(): + """ + Test that even if an iterator is not derefed the resulting dict has an (empty) entry for it. + """ + non_derefed_it = im.shift("I", 1)("x") + testee = ir.StencilClosure( + # λ(x) → (λ(non_derefed_it) → ·x)(⟪Iₒ, 1ₒ⟫(x)) + stencil=im.lambda_("x")(im.let("non_derefed_it", non_derefed_it)(im.deref("x"))), + inputs=[ir.SymRef(id="inp")], + output=ir.SymRef(id="out"), + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + ) + + actual = TraceShifts.apply(testee, inputs_only=False) + assert actual[id(non_derefed_it)] == set() From af5aa118ca3a31a9a587ce29c581b3d02cce9b46 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 31 Aug 2023 16:36:26 +0200 Subject: [PATCH 07/67] feat[next]: Add `FieldBuiltinFuncRegistry` mixin (#1330) Adds FieldBuiltinFuncRegistry to allow Field subclasses to register their own builtins Co-authored-by: Hannes Vogt --- src/gt4py/next/common.py | 24 +++++++++++++++- src/gt4py/next/embedded/nd_array_field.py | 35 ++--------------------- src/gt4py/next/ffront/fbuiltins.py | 4 --- 3 files changed, 26 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index e06f9c54b1..866b2aadb7 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -15,13 +15,14 @@ from __future__ import annotations import abc +import collections import dataclasses import enum import functools import sys from collections.abc import Sequence, Set from types import EllipsisType -from typing import TypeGuard, overload +from typing import ChainMap, TypeGuard, overload import numpy as np import numpy.typing as npt @@ -481,3 +482,24 @@ def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]: return isinstance(index, Sequence) and all( is_named_range(idx) or is_named_index(idx) for idx in index ) + + +class FieldBuiltinFuncRegistry: + _builtin_func_map: ChainMap[fbuiltins.BuiltInFunction, Callable] = collections.ChainMap() + + def __init_subclass__(cls, **kwargs): + # might break in multiple inheritance (if multiple ancestors have `_builtin_func_map`) + cls._builtin_func_map = cls._builtin_func_map.new_child() + + @classmethod + def register_builtin_func( + cls, /, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Optional[Callable[_P, _R]] = None + ) -> Any: + assert op not in cls._builtin_func_map + if op_func is None: # when used as a decorator + return functools.partial(cls.register_builtin_func, op) + return cls._builtin_func_map.setdefault(op, op_func) + + @classmethod + def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Callable[_P, _R]: + return cls._builtin_func_map.get(func, NotImplemented) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9813efdd22..ddef77bb78 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,17 +15,17 @@ from __future__ import annotations import dataclasses -import functools import itertools from collections.abc import Callable, Sequence from types import EllipsisType, ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast, overload +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs from gt4py.next import common +from gt4py.next.common import FieldBuiltinFuncRegistry from gt4py.next.ffront import fbuiltins @@ -82,7 +82,7 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]): +class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldBuiltinFuncRegistry): """ Shared field implementation for NumPy-like fields. @@ -100,35 +100,6 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]): ModuleType ] # TODO(havogt) after storage PR is merged, update to the NDArrayNamespace protocol - _builtin_func_map: ClassVar[dict[fbuiltins.BuiltInFunction, Callable]] = {} - - @classmethod - def __gt_builtin_func__(cls, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: - return cls._builtin_func_map.get(func, NotImplemented) - - @overload - @classmethod - def register_builtin_func( - cls, op: fbuiltins.BuiltInFunction[_R, _P], op_func: None - ) -> functools.partial[Callable[_P, _R]]: - ... - - @overload - @classmethod - def register_builtin_func( - cls, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Callable[_P, _R] - ) -> Callable[_P, _R]: - ... - - @classmethod - def register_builtin_func( - cls, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Optional[Callable[_P, _R]] = None - ) -> Callable[_P, _R] | functools.partial[Callable[_P, _R]]: - assert op not in cls._builtin_func_map - if op_func is None: # when used as a decorator - return functools.partial(cls.register_builtin_func, op) # type: ignore[arg-type] - return cls._builtin_func_map.setdefault(op, op_func) - @property def domain(self) -> common.Domain: return self._domain diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index b6831b35df..ba027be13c 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - import dataclasses import inspect from builtins import bool, float, int, tuple @@ -49,7 +48,6 @@ TYPE_ALIAS_NAMES = ["IndexType"] - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -205,7 +203,6 @@ def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: "trunc", ] - UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES = ["isfinite", "isinf", "isnan"] @@ -224,7 +221,6 @@ def impl(value: Field | gt4py_defs.ScalarT, /) -> Field | gt4py_defs.ScalarT: ): _make_unary_math_builtin(f) - BINARY_MATH_NUMBER_BUILTIN_NAMES = ["minimum", "maximum", "fmod", "power"] From 9d87bbd8add1a0d741b4481886f657df84c41e5b Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 5 Sep 2023 09:40:36 +0200 Subject: [PATCH 08/67] [dace] Enable tests with type inference (#1331) Enable some ITIR tests which were disabled because argument types were not propagated. Now possible to run, after improvements to type inference. --- .../feature_tests/ffront_tests/test_gt4py_builtins.py | 5 ----- .../feature_tests/iterator_tests/test_builtins.py | 8 -------- .../feature_tests/iterator_tests/test_implicit_fencil.py | 4 ---- .../iterator_tests/test_column_stencil.py | 4 ---- 4 files changed, 21 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 26f01ca813..5f19311a32 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,8 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") if unstructured_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: pytest.xfail("`maxover` broken in gtfn, see #1289.") @@ -65,9 +63,6 @@ def testee(edge_f: cases.EField) -> cases.VField: def test_minover_execution(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 13fcf3b87f..673a989122 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -171,10 +171,6 @@ def arithmetic_and_logical_test_data(): @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, expected, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) inps = asfield(*asarray(*inputs)) out = asfield((np.zeros_like(*asarray(expected))))[0] @@ -207,10 +203,6 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(program_processor, builtin_name, inputs, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) if builtin_name == "gamma": # numpy has no gamma function diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 37ac4623fd..2f7808b30e 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -59,10 +59,6 @@ def test_single_argument(program_processor, dom): def test_2_arguments(program_processor, dom): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) @fundef def fun(inp0, inp1): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 6c58ded3a9..5970b9a2a9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -292,10 +292,6 @@ def sum_shifted_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) From 8d91a7b047f5d3af57ac3c1407a412e69f6cea89 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 11:03:27 +0200 Subject: [PATCH 09/67] refactor[next] Prepare new Field for itir.embedded (#1329) - improve TypeAliases - add `domain` and `unit_range` constructors - extract domain slicing utils to `next.embedded.common` - introduce `MutableField` - add some missing operators to `Field` --------- Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- docs/user/cartesian/Makefile | 15 +- docs/user/cartesian/arrays.rst | 6 +- src/gt4py/_core/definitions.py | 54 ++- src/gt4py/next/__init__.py | 3 + src/gt4py/next/common.py | 326 ++++++++++++++---- src/gt4py/next/embedded/common.py | 127 +++++++ src/gt4py/next/embedded/exceptions.py | 38 ++ src/gt4py/next/embedded/nd_array_field.py | 272 +++++++-------- src/gt4py/next/errors/exceptions.py | 12 +- src/gt4py/next/ffront/fbuiltins.py | 1 + src/gt4py/next/utils.py | 9 +- .../ffront_tests/test_foast_pretty_printer.py | 2 +- .../unit_tests/embedded_tests/test_common.py | 137 ++++++++ .../embedded_tests/test_nd_array_field.py | 197 ++++++++--- tests/next_tests/unit_tests/test_common.py | 125 ++++--- 15 files changed, 1002 insertions(+), 322 deletions(-) create mode 100644 src/gt4py/next/embedded/common.py create mode 100644 src/gt4py/next/embedded/exceptions.py create mode 100644 tests/next_tests/unit_tests/embedded_tests/test_common.py diff --git a/docs/user/cartesian/Makefile b/docs/user/cartesian/Makefile index 091bc3b8d2..13e692b96d 100644 --- a/docs/user/cartesian/Makefile +++ b/docs/user/cartesian/Makefile @@ -2,12 +2,13 @@ # # You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -SRCDIR = ../../../src/gt4py -AUTODOCDIR = _source -BUILDDIR = _build +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +SRCDIR = ../../../src/gt4py +SPHINX_APIDOC_OPTS = --private # private modules for gt4py._core +AUTODOCDIR = _source +BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) @@ -55,7 +56,7 @@ clean: autodoc: @echo @echo "Running sphinx-apidoc..." - sphinx-apidoc ${SPHINX_OPTS} -o ${AUTODOCDIR} ${SRCDIR} + sphinx-apidoc ${SPHINX_APIDOC_OPTS} -o ${AUTODOCDIR} ${SRCDIR} @echo @echo "sphinx-apidoc finished. The generated autodocs are in $(AUTODOCDIR)." diff --git a/docs/user/cartesian/arrays.rst b/docs/user/cartesian/arrays.rst index 6788e2757f..6ef7c6e5c1 100644 --- a/docs/user/cartesian/arrays.rst +++ b/docs/user/cartesian/arrays.rst @@ -39,6 +39,8 @@ Internally, gt4py uses the utilities :code:`gt4py.utils.as_numpy` and :code:`gt4 buffers. GT4Py developers are advised to always use those utilities as to guarantee support across gt4py as the supported interfaces are extended. +.. _cartesian-arrays-dimension-mapping: + Dimension Mapping ^^^^^^^^^^^^^^^^^ @@ -56,6 +58,8 @@ which implements this lookup. Note: Support for xarray can be added manually by the user by means of the mechanism described `here `_. +.. _cartesian-arrays-default-origin: + Default Origin ^^^^^^^^^^^^^^ @@ -180,4 +184,4 @@ Additionally, these **optional** keyword-only parameters are accepted: determine the default layout for the storage. Currently supported will be :code:`"I"`, :code:`"J"`, :code:`"K"` and additional dimensions as string representations of integers, starting at :code:`"0"`. (This information is not retained in the resulting array, and needs to be specified instead - with the :code:`__gt_dims__` interface. ) \ No newline at end of file + with the :code:`__gt_dims__` interface. ) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 2546ae3e4e..059ba6c24c 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -213,18 +213,13 @@ class DType(Generic[ScalarT]): """ scalar_type: Type[ScalarT] - tensor_shape: TensorShape + tensor_shape: TensorShape = dataclasses.field(default=()) - def __init__( - self, scalar_type: Type[ScalarT], tensor_shape: Sequence[IntegralScalar] = () - ) -> None: - if not isinstance(scalar_type, type): - raise TypeError(f"Invalid scalar type '{scalar_type}'") - if not is_valid_tensor_shape(tensor_shape): - raise TypeError(f"Invalid tensor shape '{tensor_shape}'") - - object.__setattr__(self, "scalar_type", scalar_type) - object.__setattr__(self, "tensor_shape", tensor_shape) + def __post_init__(self) -> None: + if not isinstance(self.scalar_type, type): + raise TypeError(f"Invalid scalar type '{self.scalar_type}'") + if not is_valid_tensor_shape(self.tensor_shape): + raise TypeError(f"Invalid tensor shape '{self.tensor_shape}'") @functools.cached_property def kind(self) -> DTypeKind: @@ -251,6 +246,16 @@ def lanes(self) -> int: def subndim(self) -> int: return len(self.tensor_shape) + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, DType) + and self.scalar_type == other.scalar_type + and self.tensor_shape == other.tensor_shape + ) + + def __hash__(self) -> int: + return hash((self.scalar_type, self.tensor_shape)) + @dataclasses.dataclass(frozen=True) class IntegerDType(DType[IntegralT]): @@ -322,6 +327,11 @@ class Float64DType(FloatingDType[float64]): scalar_type: Final[Type[float64]] = dataclasses.field(default=float64, init=False) +@dataclasses.dataclass(frozen=True) +class BoolDType(DType[bool_]): + scalar_type: Final[Type[bool_]] = dataclasses.field(default=bool_, init=False) + + DTypeLike = Union[DType, npt.DTypeLike] @@ -332,11 +342,29 @@ def dtype(dtype_like: DTypeLike) -> DType: # -- Custom protocols -- class GTDimsInterface(Protocol): - __gt_dims__: Tuple[str, ...] + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming the buffer dimensions. + + In `gt4py.cartesian` the allowed values are `"I"`, `"J"` and `"K"` with the established semantics. + + See :ref:`cartesian-arrays-dimension-mapping` for details. + """ + + @property + def __gt_dims__(self) -> Tuple[str, ...]: + ... class GTOriginInterface(Protocol): - __gt_origin__: Tuple[int, ...] + """ + A `GTOriginInterface` is an object providing `__gt_origin__`, describing the origin of a buffer. + + See :ref:`cartesian-arrays-default-origin` for details. + """ + + @property + def __gt_origin__(self) -> Tuple[int, ...]: + ... # -- Device representation -- diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index b4d1fc0c09..cc35899668 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,6 +25,9 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType +from .embedded import ( # Just for registering field implementations + nd_array_field as _nd_array_field, +) from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 866b2aadb7..b85239cd0a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,9 +20,9 @@ import enum import functools import sys -from collections.abc import Sequence, Set -from types import EllipsisType -from typing import ChainMap, TypeGuard, overload +import types +from collections.abc import Mapping, Sequence, Set +from typing import overload import numpy as np import numpy.typing as npt @@ -37,16 +37,18 @@ ParamSpec, Protocol, TypeAlias, + TypeGuard, TypeVar, + cast, extended_runtime_checkable, - final, runtime_checkable, ) from gt4py.eve.type_definitions import StrEnum -DimT = TypeVar("DimT", bound="Dimension") -DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) +DimsT = TypeVar( + "DimsT", covariant=True +) # bound to `Sequence[Dimension]` if instance of Dimension would be a type class Infinity(int): @@ -66,7 +68,7 @@ class DimensionKind(StrEnum): LOCAL = "local" def __str__(self): - return f"{type(self).__name__}.{self.name}" + return self.value @dataclasses.dataclass(frozen=True) @@ -75,7 +77,7 @@ class Dimension: kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL) def __str__(self): - return f'Dimension(value="{self.value}", kind={self.kind})' + return f"{self.value}[{self.kind}]" @dataclasses.dataclass(frozen=True) @@ -136,36 +138,139 @@ def __and__(self, other: Set[Any]) -> UnitRange: else: raise NotImplementedError("Can only find the intersection between UnitRange instances.") + def __str__(self) -> str: + return f"({self.start}:{self.stop})" + + +RangeLike: TypeAlias = UnitRange | range | tuple[int, int] + + +def unit_range(r: RangeLike) -> UnitRange: + if isinstance(r, UnitRange): + return r + if isinstance(r, range): + if r.step != 1: + raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") + return UnitRange(r.start, r.stop) + if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): + return UnitRange(r[0], r[1]) + raise ValueError(f"`{r}` cannot be interpreted as `UnitRange`.") -DomainRange: TypeAlias = UnitRange | int + +IntIndex: TypeAlias = int | core_defs.IntegralScalar +NamedIndex: TypeAlias = tuple[Dimension, IntIndex] NamedRange: TypeAlias = tuple[Dimension, UnitRange] -NamedIndex: TypeAlias = tuple[Dimension, int] -DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] -FieldSlice: TypeAlias = ( - DomainSlice - | tuple[slice | int | EllipsisType, ...] - | slice - | int - | EllipsisType - | NamedRange - | NamedIndex -) +RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType +AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange +AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement +AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] +RelativeIndexSequence: TypeAlias = tuple[ + slice | IntIndex | types.EllipsisType, ... +] # is a tuple but called Sequence for symmetry +AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence +AnyIndexSpec: TypeAlias = AnyIndexElement | AnyIndexSequence + + +def is_int_index(p: Any) -> TypeGuard[IntIndex]: + # should be replaced by isinstance(p, IntIndex), but mypy complains with + # `Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" [arg-type]` + return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) + + +def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) + and len(v) == 2 + and isinstance(v[0], Dimension) + and isinstance(v[1], UnitRange) + ) -@dataclasses.dataclass(frozen=True) +def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) + ) + + +def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: + return ( + is_int_index(v) + or is_named_range(v) + or is_named_index(v) + or isinstance(v, slice) + or v is Ellipsis + ) + + +def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: + return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) + + +def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: + return isinstance(v, tuple) and all( + isinstance(e, slice) or is_int_index(e) or e is Ellipsis for e in v + ) + + +def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: + # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` + return cast( + AnyIndexSequence, + (index,) if is_any_index_element(index) else index, + ) + + +def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: + return (v[0], unit_range(v[1])) + + +@dataclasses.dataclass(frozen=True, init=False) class Domain(Sequence[NamedRange]): + """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" + dims: tuple[Dimension, ...] ranges: tuple[UnitRange, ...] - def __post_init__(self): + def __init__( + self, + *args: NamedRange, + dims: Optional[tuple[Dimension, ...]] = None, + ranges: Optional[tuple[UnitRange, ...]] = None, + ) -> None: + if dims is not None or ranges is not None: + if dims is None and ranges is None: + raise ValueError("Either both none of `dims` and `ranges` must be specified.") + if len(args) > 0: + raise ValueError( + "No extra `args` allowed when constructing fomr `dims` and `ranges`." + ) + + assert dims is not None and ranges is not None # for mypy + if not all(isinstance(dim, Dimension) for dim in dims): + raise ValueError( + f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." + ) + if not all(isinstance(rng, UnitRange) for rng in ranges): + raise ValueError( + f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." + ) + if len(dims) != len(ranges): + raise ValueError( + f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." + ) + + object.__setattr__(self, "dims", dims) + object.__setattr__(self, "ranges", ranges) + else: + if not all(is_named_range(arg) for arg in args): + raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") + dims, ranges = zip(*args) if args else ((), ()) + object.__setattr__(self, "dims", tuple(dims)) + object.__setattr__(self, "ranges", tuple(ranges)) + if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") - if len(self.dims) != len(self.ranges): - raise ValueError( - f"Number of provided dimensions ({len(self.dims)}) does not match number of provided ranges ({len(self.ranges)})." - ) - def __len__(self) -> int: return len(self.ranges) @@ -174,7 +279,7 @@ def __getitem__(self, index: int) -> NamedRange: ... @overload - def __getitem__(self, index: slice) -> "Domain": + def __getitem__(self, index: slice) -> Domain: ... @overload @@ -187,7 +292,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] - return Domain(dims_slice, ranges_slice) + return Domain(dims=dims_slice, ranges=ranges_slice) elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) @@ -197,7 +302,21 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: else: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") - def __and__(self, other: "Domain") -> "Domain": + def __and__(self, other: Domain) -> Domain: + """ + Intersect `Domain`s, missing `Dimension`s are considered infinite. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) + + >>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) + """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) intersected_ranges = tuple( rng1 & rng2 @@ -206,15 +325,49 @@ def __and__(self, other: "Domain") -> "Domain": _broadcast_ranges(broadcast_dims, other.dims, other.ranges), ) ) - return Domain(broadcast_dims, intersected_ranges) + return Domain(dims=broadcast_dims, ranges=intersected_ranges) + + def __str__(self) -> str: + return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + + +DomainLike: TypeAlias = ( + Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] +) # `Domain` is `Sequence[NamedRange]` and therefore a subset + + +def domain(domain_like: DomainLike) -> Domain: + """ + Construct `Domain` from `DomainLike` object. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> domain(((I, (2, 4)), (J, (3, 5)))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + + >>> domain({I: (2, 4), J: (3, 5)}) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + """ + if isinstance(domain_like, Domain): + return domain_like + if isinstance(domain_like, Sequence): + return Domain(*tuple(named_range(d) for d in domain_like)) + if isinstance(domain_like, Mapping): + return Domain( + dims=tuple(domain_like.keys()), + ranges=tuple(unit_range(r) for r in domain_like.values()), + ) + raise ValueError(f"`{domain_like}` is not `DomainLike`.") def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange(Infinity.negative(), Infinity.positive()) - for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims ) @@ -230,8 +383,22 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... +class NextGTDimsInterface(Protocol): + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions. + + The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`, + where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` . + """ + + # TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way + @property + def __gt_dims__(self) -> tuple[Dimension, ...]: + ... + + @extended_runtime_checkable -class Field(Protocol[DimsT, core_defs.ScalarT]): +class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property @@ -242,24 +409,19 @@ def domain(self) -> Domain: def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... - @property - def value_type(self) -> type[core_defs.ScalarT]: - ... - @property def ndarray(self) -> core_defs.NDArrayObject: ... def __str__(self) -> str: - codomain = self.value_type.__name__ - return f"⟨{self.domain!s} → {codomain}⟩" + return f"⟨{self.domain!s} → {self.dtype}⟩" @abc.abstractmethod def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... # Operators @@ -268,7 +430,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -279,6 +441,10 @@ def __abs__(self) -> Field: def __neg__(self) -> Field: ... + @abc.abstractmethod + def __invert__(self) -> Field: + """Only defined for `Field` of value type `bool`.""" + @abc.abstractmethod def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... @@ -323,23 +489,44 @@ def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... + @abc.abstractmethod + def __and__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __or__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __xor__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + def is_field( v: Any, -) -> TypeGuard[Field]: # this function is introduced to localize the `type: ignore`` +) -> TypeGuard[Field]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed return isinstance(v, Field) # type: ignore[misc] # we use extended_runtime_checkable -class FieldABC(Field[DimsT, core_defs.ScalarT]): - """Abstract base class for implementations of the :class:`Field` protocol.""" +@extended_runtime_checkable +class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): + @abc.abstractmethod + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: + ... - @final - def __setattr__(self, key, value) -> None: - raise TypeError("Immutable type") - @final - def __setitem__(self, key, value) -> None: - raise TypeError("Immutable type") +def is_mutable_field( + v: Field, +) -> TypeGuard[MutableField]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed + return isinstance(v, MutableField) # type: ignore[misc] # we use extended_runtime_checkable @functools.singledispatch @@ -347,8 +534,8 @@ def field( definition: Any, /, *, - domain: Optional[Any] = None, # TODO(havogt): provide domain_like to Domain conversion - value_type: Optional[type] = None, + domain: Optional[DomainLike] = None, + dtype: Optional[core_defs.DType] = None, ) -> Field: raise NotImplementedError @@ -470,26 +657,27 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: return topologically_sorted_list -def is_named_range(v: Any) -> TypeGuard[NamedRange]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], UnitRange) - - -def is_named_index(v: Any) -> TypeGuard[NamedIndex]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], int) - - -def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]: - return isinstance(index, Sequence) and all( - is_named_range(idx) or is_named_index(idx) for idx in index - ) +class FieldBuiltinFuncRegistry: + """ + Mixin for adding `fbuiltins` registry to a `Field`. + Subclasses of a `Field` with `FieldBuiltinFuncRegistry` get their own registry, + dispatching (via ChainMap) to its parent's registries. + """ -class FieldBuiltinFuncRegistry: - _builtin_func_map: ChainMap[fbuiltins.BuiltInFunction, Callable] = collections.ChainMap() + _builtin_func_map: collections.ChainMap[ + fbuiltins.BuiltInFunction, Callable + ] = collections.ChainMap() def __init_subclass__(cls, **kwargs): - # might break in multiple inheritance (if multiple ancestors have `_builtin_func_map`) - cls._builtin_func_map = cls._builtin_func_map.new_child() + cls._builtin_func_map = collections.ChainMap( + {}, # New empty `dict`` for new registrations on this class + *[ + c.__dict__["_builtin_func_map"].maps[0] # adding parent `dict`s in mro order + for c in cls.__mro__ + if "_builtin_func_map" in c.__dict__ + ], + ) @classmethod def register_builtin_func( diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py new file mode 100644 index 0000000000..37ba4954f3 --- /dev/null +++ b/src/gt4py/next/embedded/common.py @@ -0,0 +1,127 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Optional, Sequence, cast + +from gt4py.next import common +from gt4py.next.embedded import exceptions as embedded_exceptions + + +def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain: + index_sequence = common.as_any_index_sequence(index) + + if common.is_absolute_index_sequence(index_sequence): + return _absolute_sub_domain(domain, index_sequence) + + if common.is_relative_index_sequence(index_sequence): + return _relative_sub_domain(domain, index_sequence) + + raise IndexError(f"Unsupported index type: {index}") + + +def _relative_sub_domain( + domain: common.Domain, index: common.RelativeIndexSequence +) -> common.Domain: + named_ranges: list[common.NamedRange] = [] + + expanded = _expand_ellipsis(index, len(domain)) + if len(domain) < len(expanded): + raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") + expanded += (slice(None),) * (len(domain) - len(expanded)) + for (dim, rng), idx in zip(domain, expanded, strict=True): + if isinstance(idx, slice): + try: + sliced = _slice_range(rng, idx) + named_ranges.append((dim, sliced)) + except IndexError: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) + else: + # not in new domain + assert common.is_int_index(idx) + new_index = (rng.start if idx >= 0 else rng.stop) + idx + if new_index < rng.start or new_index >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) + + return common.Domain(*named_ranges) + + +def _absolute_sub_domain( + domain: common.Domain, index: common.AbsoluteIndexSequence +) -> common.Domain: + named_ranges: list[common.NamedRange] = [] + for i, (dim, rng) in enumerate(domain): + if (pos := _find_index_of_dim(dim, index)) is not None: + named_idx = index[pos] + idx = named_idx[1] + if isinstance(idx, common.UnitRange): + if not idx <= rng: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) + + named_ranges.append((dim, idx)) + else: + # not in new domain + assert common.is_int_index(idx) + if idx < rng.start or idx >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) + else: + # dimension not mentioned in slice + named_ranges.append((dim, domain.ranges[i])) + + return common.Domain(*named_ranges) + + +def _expand_ellipsis( + indices: common.RelativeIndexSequence, target_size: int +) -> tuple[common.IntIndex | slice, ...]: + if Ellipsis in indices: + idx = indices.index(Ellipsis) + indices = ( + indices[:idx] + (slice(None),) * (target_size - (len(indices) - 1)) + indices[idx + 1 :] + ) + return cast(tuple[common.IntIndex | slice, ...], indices) # mypy leave me alone and trust me! + + +def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: + if slice_obj == slice(None): + return common.UnitRange(input_range.start, input_range.stop) + + start = ( + input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop + ) + (slice_obj.start or 0) + stop = ( + input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop + ) + (slice_obj.stop or len(input_range)) + + if start < input_range.start or stop > input_range.stop: + raise IndexError("Slice out of range (no clipping following array API standard).") + + return common.UnitRange(start, stop) + + +def _find_index_of_dim( + dim: common.Dimension, + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], +) -> Optional[int]: + for i, (d, _) in enumerate(domain_slice): + if dim == d: + return i + return None diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py new file mode 100644 index 0000000000..393123db36 --- /dev/null +++ b/src/gt4py/next/embedded/exceptions.py @@ -0,0 +1,38 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next import common +from gt4py.next.errors import exceptions as gt4py_exceptions + + +class IndexOutOfBounds(gt4py_exceptions.GT4PyError): + domain: common.Domain + indices: common.AnyIndexSpec + index: common.AnyIndexElement + dim: common.Dimension + + def __init__( + self, + domain: common.Domain, + indices: common.AnyIndexSpec, + index: common.AnyIndexElement, + dim: common.Dimension, + ): + super().__init__( + f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`." + ) + self.domain = domain + self.indices = indices + self.index = index + self.dim = dim diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ddef77bb78..fcaa09e7eb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,17 +15,16 @@ from __future__ import annotations import dataclasses -import itertools from collections.abc import Callable, Sequence -from types import EllipsisType, ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast +from types import ModuleType +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import FieldBuiltinFuncRegistry +from gt4py.next.embedded import common as embedded_common from gt4py.next.ffront import fbuiltins @@ -56,7 +55,7 @@ def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_nam def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: xp = a.__class__.array_ns op = getattr(xp, array_builtin_name) - if hasattr(b, "__gt_builtin_func__"): # isinstance(b, common.Field): + if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): if not a.domain == b.domain: domain_intersection = a.domain & b.domain a_broadcasted = _broadcast(a, domain_intersection.dims) @@ -82,7 +81,9 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldBuiltinFuncRegistry): +class _BaseNdArrayField( + common.MutableField[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry +): """ Shared field implementation for NumPy-like fields. @@ -94,7 +95,6 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldB _domain: common.Domain _ndarray: core_defs.NDArrayObject - _value_type: type[core_defs.ScalarT] array_ns: ClassVar[ ModuleType @@ -104,13 +104,28 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldB def domain(self) -> common.Domain: return self._domain + @property + def shape(self) -> tuple[int, ...]: + return self._ndarray.shape + + @property + def __gt_dims__(self) -> tuple[common.Dimension, ...]: + return self._domain.dims + + @property + def __gt_origin__(self) -> tuple[int, ...]: + return tuple(-r.start for _, r in self._domain) + @property def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray + def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: + return np.asarray(self._ndarray, dtype) + @property - def value_type(self) -> type[core_defs.ScalarT]: - return self._value_type + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: + return core_defs.dtype(self._ndarray.dtype.type) @classmethod def from_array( @@ -119,38 +134,52 @@ def from_array( | core_defs.NDArrayObject, # TODO: NDArrayObject should be part of ArrayLike /, *, - domain: common.Domain, - value_type: Optional[type] = None, + domain: common.DomainLike, + dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike ) -> _BaseNdArrayField: + domain = common.domain(domain) xp = cls.array_ns - dtype = None - if value_type is not None: - dtype = xp.dtype(value_type) - array = xp.asarray(data, dtype=dtype) - value_type = array.dtype.type # TODO add support for Dimensions as value_type + xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) + array = xp.asarray(data, dtype=xp_dtype) + + if dtype_like is not None: + assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) - assert all(isinstance(d, common.Dimension) for d, r in domain), domain + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim assert all( - len(nr[1]) == s or (s == 1 and nr[1] == common.UnitRange.infinity()) - for nr, s in zip(domain, array.shape) + len(r) == s or (s == 1 and r == common.UnitRange.infinity()) + for r, s in zip(domain.ranges, array.shape) ) - assert value_type is not None # for mypy - return cls(domain, array, value_type) + return cls(domain, array) def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() + def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + new_domain, buffer_slice = self._slice(index) + + new_buffer = self.ndarray[buffer_slice] + if len(new_domain) == 0: + assert core_defs.is_scalar_type(new_buffer) + return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here + else: + return self.__class__.from_array(new_buffer, domain=new_domain) + + __getitem__ = restrict + __call__ = None # type: ignore[assignment] # TODO: remap __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") + __pos__ = _make_unary_array_field_intrinsic_func("pos", "positive") + __add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add") __sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract") @@ -165,78 +194,51 @@ def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") - def __getitem__(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: - if ( - not isinstance(index, tuple) - and not common.is_domain_slice(index) - or common.is_named_index(index) - or common.is_named_range(index) - ): - index = cast(common.FieldSlice, (index,)) + __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") - if common.is_domain_slice(index): - return self._getitem_absolute_slice(index) + def __and__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( + self, other + ) + raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") - assert isinstance(index, tuple) - if all(isinstance(idx, (slice, int)) or idx is Ellipsis for idx in index): - return self._getitem_relative_slice(index) + __rand__ = __and__ - raise IndexError(f"Unsupported index type: {index}") + def __or__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) + raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") - restrict = ( - __getitem__ # type:ignore[assignment] # TODO(havogt) I don't see the problem that mypy has - ) + __ror__ = __or__ - def _getitem_absolute_slice( - self, index: common.DomainSlice - ) -> common.Field | core_defs.ScalarT: - slices = _get_slices_from_domain_slice(self.domain, index) - new_ranges = [] - new_dims = [] - new = self.ndarray[slices] - - for i, dim in enumerate(self.domain.dims): - if (pos := _find_index_of_dim(dim, index)) is not None: - index_or_range = index[pos][1] - if isinstance(index_or_range, common.UnitRange): - new_ranges.append(index_or_range) - new_dims.append(dim) - else: - # dimension not mentioned in slice - new_ranges.append(self.domain.ranges[i]) - new_dims.append(dim) - - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + def __xor__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( + self, other + ) + raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new) - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) - - def _getitem_relative_slice( - self, indices: tuple[slice | int | EllipsisType, ...] - ) -> common.Field | core_defs.ScalarT: - new = self.ndarray[indices] - new_dims = [] - new_ranges = [] - - for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... - self.domain, _expand_ellipsis(indices, len(self.domain)), fillvalue=slice(None) - ): - if isinstance(idx, slice): - new_dims.append(dim) - new_ranges.append(_slice_range(rng, idx)) - else: - assert isinstance(idx, int) # not in new_domain - - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + __rxor__ = __xor__ - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new), new - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) + def __invert__(self) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_unary_array_field_intrinsic_func("invert", "invert")(self) + raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") + + def _slice( + self, index: common.AnyIndexSpec + ) -> tuple[common.Domain, common.RelativeIndexSequence]: + new_domain = embedded_common.sub_domain(self.domain, index) + + index_sequence = common.as_any_index_sequence(index) + slice_ = ( + _get_slices_from_domain_slice(self.domain, index_sequence) + if common.is_absolute_index_sequence(index_sequence) + else index_sequence + ) + assert common.is_relative_index_sequence(slice_) + return new_domain, slice_ # -- Specialized implementations for intrinsic operations on array fields -- @@ -266,6 +268,25 @@ def _getitem_relative_slice( fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] ) + +def _np_cp_setitem( + self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, +) -> None: + target_domain, target_slice = self._slice(index) + + if common.is_field(value): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + ) + value = value.ndarray + + assert hasattr(self.ndarray, "__setitem__") + self.ndarray[target_slice] = value + + # -- Concrete array implementations -- # NumPy _nd_array_implementations = [np] @@ -275,6 +296,8 @@ def _getitem_relative_slice( class NumPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = np + __setitem__ = _np_cp_setitem + common.field.register(np.ndarray, NumPyArrayField.from_array) @@ -286,6 +309,8 @@ class NumPyArrayField(_BaseNdArrayField): class CuPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = cp + __setitem__ = _np_cp_setitem + common.field.register(cp.ndarray, CuPyArrayField.from_array) # JAX @@ -296,38 +321,30 @@ class CuPyArrayField(_BaseNdArrayField): class JaxArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = jnp - common.field.register(jnp.ndarray, JaxArrayField.from_array) - + def __setitem__( + self, + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, + ) -> None: + # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` + raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") -def _find_index_of_dim( - dim: common.Dimension, - domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> Optional[int]: - for i, (d, _) in enumerate(domain_slice): - if dim == d: - return i - return None + common.field.register(jnp.ndarray, JaxArrayField.from_array) def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: domain_slice: list[slice | None] = [] - new_domain_dims = [] - new_domain_ranges = [] + named_ranges = [] for dim in new_dimensions: - if (pos := _find_index_of_dim(dim, field.domain)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) - new_domain_dims.append(dim) - new_domain_ranges.append(field.domain[pos][1]) + named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - new_domain_dims.append(dim) - new_domain_ranges.append( - common.UnitRange(common.Infinity.negative(), common.Infinity.positive()) + named_ranges.append( + (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) ) - return common.field( - field.ndarray[tuple(domain_slice)], - domain=common.Domain(tuple(new_domain_dims), tuple(new_domain_ranges)), - ) + return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) def _builtins_broadcast( @@ -344,7 +361,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> tuple[slice | int | None, ...]: +) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -359,10 +376,10 @@ def _get_slices_from_domain_slice( specified in the Domain. If a dimension is not included in the named indices or ranges, a None is used to indicate expansion along that axis. """ - slice_indices: list[slice | int | None] = [] + slice_indices: list[slice | common.IntIndex] = [] for pos_old, (dim, _) in enumerate(domain): - if (pos := _find_index_of_dim(dim, domain_slice)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: index_or_range = domain_slice[pos][1] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: @@ -370,7 +387,9 @@ def _get_slices_from_domain_slice( return tuple(slice_indices) -def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> slice | int: +def _compute_slice( + rng: common.UnitRange | common.IntIndex, domain: common.Domain, pos: int +) -> slice | common.IntIndex: """Compute a slice or integer based on the provided range, domain, and position. Args: @@ -392,34 +411,7 @@ def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> rng.start - domain.ranges[pos].start, rng.stop - domain.ranges[pos].start, ) - elif isinstance(rng, int): + elif common.is_int_index(rng): return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") - - -def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: - # handle slice(None) case - if slice_obj == slice(None): - return common.UnitRange(input_range.start, input_range.stop) - - start = ( - input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop - ) + (slice_obj.start or 0) - stop = ( - input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop - ) + (slice_obj.stop or len(input_range)) - - return common.UnitRange(start, stop) - - -def _expand_ellipsis( - indices: tuple[int | slice | EllipsisType, ...], target_size: int -) -> tuple[int | slice, ...]: - expanded_indices: list[int | slice] = [] - for idx in indices: - if idx is Ellipsis: - expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) - else: - expanded_indices.append(idx) - return tuple(expanded_indices) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 74230263db..e956858549 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -33,17 +33,19 @@ from . import formatting -class DSLError(Exception): +class GT4PyError(Exception): + @property + def message(self) -> str: + return self.args[0] + + +class DSLError(GT4PyError): location: Optional[SourceLocation] def __init__(self, location: Optional[SourceLocation], message: str) -> None: self.location = location super().__init__(message) - @property - def message(self) -> str: - return self.args[0] - def with_location(self, location: Optional[SourceLocation]) -> DSLError: self.location = location return self diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index ba027be13c..52aae34b3f 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + import dataclasses import inspect from builtins import bool, float, int, tuple diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 0c5de764f2..006b3057b0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeGuard, TypeVar class RecursionGuard: @@ -49,3 +49,10 @@ def __enter__(self): def __exit__(self, *exc): self.guarded_objects.remove(id(self.obj)) + + +_T = TypeVar("_T") + + +def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: + return isinstance(v, tuple) and all(isinstance(e, t) for e in v) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py index 0bc5a98a4e..c1bee4fa2f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py @@ -82,7 +82,7 @@ def scan(inp: int32) -> int32: expected = textwrap.dedent( f""" - @scan_operator(axis=Dimension(value="KDim", kind=DimensionKind.VERTICAL), forward=False, init=1) + @scan_operator(axis=KDim[vertical], forward=False, init=1) def scan(inp: int32) -> int32: {ssa.unique_name("foo", 0)} = inp return inp diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py new file mode 100644 index 0000000000..640ed326bb --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -0,0 +1,137 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Sequence + +import pytest + +from gt4py.next import common +from gt4py.next.common import UnitRange +from gt4py.next.embedded import exceptions as embedded_exceptions +from gt4py.next.embedded.common import _slice_range, sub_domain + + +@pytest.mark.parametrize( + "rng, slce, expected", + [ + (UnitRange(2, 10), slice(2, -2), UnitRange(4, 8)), + (UnitRange(2, 10), slice(2, None), UnitRange(4, 10)), + (UnitRange(2, 10), slice(None, -2), UnitRange(2, 8)), + (UnitRange(2, 10), slice(None), UnitRange(2, 10)), + ], +) +def test_slice_range(rng, slce, expected): + result = _slice_range(rng, slce) + assert result == expected + + +I = common.Dimension("I") +J = common.Dimension("J") +K = common.Dimension("K") + + +@pytest.mark.parametrize( + "domain, index, expected", + [ + ([(I, (2, 5))], 1, []), + ([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]), + ([(I, (2, 5))], (I, 2), []), + ([(I, (2, 5))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], 1, []), + ([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]), + ([(I, (-2, 3))], (I, 1), []), + ([(I, (-2, 3))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], -5, []), + ([(I, (-2, 3))], -6, IndexError), + ([(I, (-2, 3))], slice(-7, -6), IndexError), + ([(I, (-2, 3))], slice(-6, -7), IndexError), + ([(I, (-2, 3))], 4, []), + ([(I, (-2, 3))], 5, IndexError), + ([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]), + ([(I, (-2, 3))], slice(5, 6), IndexError), + ([(I, (-2, 3))], (I, -3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), + ([(I, (-2, 3))], (I, 3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + 2, + [(J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + slice(2, 3), + [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, 2), + [(J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, UnitRange(2, 3)), + [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, 3), + [(I, (2, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, UnitRange(4, 5)), + [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, 3), (I, 2)), + [(K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, UnitRange(4, 5)), (I, 2)), + [(J, (4, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(2, 3)), + [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (Ellipsis, slice(2, 3)), + [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), Ellipsis, slice(2, 3)), + [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(1, 2), Ellipsis), + [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], + ), + ], +) +def test_sub_domain(domain, index, expected): + domain = common.domain(domain) + if expected is IndexError: + with pytest.raises(embedded_exceptions.IndexOutOfBounds): + sub_domain(domain, index) + else: + expected = common.domain(expected) + result = sub_domain(domain, index) + assert result == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index a2aa3112bd..95093c8307 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -22,8 +22,8 @@ from gt4py.next import Dimension, common from gt4py.next.common import Domain, UnitRange -from gt4py.next.embedded import nd_array_field -from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice, _slice_range +from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field +from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -40,16 +40,42 @@ def nd_array_implementation(request): @pytest.fixture( - params=[operator.add, operator.sub, operator.mul, operator.truediv, operator.floordiv], + params=[ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + operator.mod, + ] ) -def binary_op(request): +def binary_arithmetic_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation): +@pytest.fixture( + params=[operator.xor, operator.and_, operator.or_], +) +def binary_logical_op(request): + yield request.param + + +@pytest.fixture(params=[operator.neg, operator.pos]) +def unary_arithmetic_op(request): + yield request.param + + +@pytest.fixture(params=[operator.invert]) +def unary_logical_op(request): + yield request.param + + +def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): + if not dtype: + dtype = nd_array_implementation.float32 return common.field( - nd_array_implementation.asarray(lst, dtype=nd_array_implementation.float32), - domain=((common.Dimension("foo"), common.UnitRange(0, len(lst))),), + nd_array_implementation.asarray(lst, dtype=dtype), + domain={common.Dimension("foo"): (0, len(lst))}, ) @@ -72,16 +98,57 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati assert np.allclose(result.ndarray, expected) -def test_binary_ops(binary_op, nd_array_implementation): +def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] inputs = [inp_a, inp_b] - expected = binary_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) + expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] - result = binary_op(*field_inputs) + result = binary_arithmetic_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_binary_logical_ops(binary_logical_op, nd_array_implementation): + inp_a = [True, True, False, False] + inp_b = [True, False, True, False] + inputs = [inp_a, inp_b] + + expected = binary_logical_op(*[np.asarray(inp) for inp in inputs]) + + field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] + + result = binary_logical_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_logical_ops(unary_logical_op, nd_array_implementation): + inp = [ + True, + False, + ] + + expected = unary_logical_op(np.asarray(inp)) + + field_input = _make_field(inp, nd_array_implementation, dtype=bool) + + result = unary_logical_op(field_input) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): + inp = [1.0, -2.0, 0.0] + + expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32)) + + field_input = _make_field(inp, nd_array_implementation) + + result = unary_arithmetic_op(field_input) assert np.allclose(result.ndarray, expected) @@ -93,7 +160,7 @@ def test_binary_ops(binary_op, nd_array_implementation): ((JDim,), (None, slice(5, 10))), ], ) -def test_binary_operations_with_intersection(binary_op, dims, expected_indices): +def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expected_indices): arr1 = np.arange(10) arr1_domain = common.Domain(dims=dims, ranges=(UnitRange(0, 10),)) @@ -103,8 +170,8 @@ def test_binary_operations_with_intersection(binary_op, dims, expected_indices): field1 = common.field(arr1, domain=arr1_domain) field2 = common.field(arr2, domain=arr2_domain) - op_result = binary_op(field1, field2) - expected_result = binary_op(arr1[expected_indices[0], expected_indices[1]], arr2) + op_result = binary_arithmetic_op(field1, field2) + expected_result = binary_arithmetic_op(arr1[expected_indices[0], expected_indices[1]], arr2) assert op_result.ndarray.shape == (5, 5) assert np.allclose(op_result.ndarray, expected_result) @@ -122,10 +189,8 @@ def product_nd_array_implementation(request): def test_mixed_fields(product_nd_array_implementation): first_impl, second_impl = product_nd_array_implementation - if (first_impl.__name__ == "cupy" and second_impl.__name__ == "numpy") or ( - first_impl.__name__ == "numpy" and second_impl.__name__ == "cupy" - ): - pytest.skip("Binary operation between CuPy and NumPy requires explicit conversion.") + if "numpy" in first_impl.__name__ and "cupy" in second_impl.__name__: + pytest.skip("Binary operation between NumPy and CuPy requires explicit conversion.") inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] @@ -271,7 +336,7 @@ def test_get_slices_invalid_type(): (JDim, KDim), (2, 15), ), - ((IDim, 1), (JDim, KDim), (10, 15)), + ((IDim, 5), (JDim, KDim), (10, 15)), ((IDim, UnitRange(5, 7)), (IDim, JDim, KDim), (2, 10, 15)), ], ) @@ -282,20 +347,20 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): field = common.field(np.ones((5, 10, 15)), domain=domain) indexed_field = field[domain_slice] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain.dims == expected_dimensions def test_absolute_indexing_value_return(): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) - field = common.field(np.ones((10, 10), dtype=np.int32), domain=domain) + field = common.field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((IDim, 2), (JDim, 4)) + named_index = ((IDim, 12), (JDim, 6)) value = field[named_index] assert isinstance(value, np.int32) - assert value == 1 + assert value == 21 @pytest.mark.parametrize( @@ -304,19 +369,23 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 4))), + Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 4))), + ), + ((slice(None, 5),), (5, 10), Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 12)))), + ( + (Ellipsis, 1), + (10,), + Domain((IDim, UnitRange(5, 15))), ), - ((slice(None, 5),), (5, 10), Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 12)))), - ((Ellipsis, 1), (10,), Domain((IDim,), (UnitRange(5, 15),))), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((IDim, JDim), (UnitRange(7, 8), UnitRange(7, 9))), + Domain((IDim, UnitRange(7, 8)), (JDim, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((IDim,), (UnitRange(6, 7),)), + Domain((IDim, UnitRange(6, 7))), ), ], ) @@ -325,7 +394,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain @@ -333,32 +402,44 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_shape, expected_domain", [ - ((1, slice(None), 2), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((1, slice(None), 2), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( (slice(None), slice(None), 2), (10, 15), - Domain((IDim, JDim), (UnitRange(5, 15), UnitRange(10, 25))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(10, 25))), ), ( (slice(None),), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None), slice(None), slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), - ((0, Ellipsis, 0), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((0, Ellipsis, 0), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( Ellipsis, (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ], ) @@ -369,7 +450,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 15, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain @@ -391,7 +472,7 @@ def test_relative_indexing_out_of_bounds(lazy_slice): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) field = common.field(np.ones((10, 10)), domain=domain) - with pytest.raises(IndexError): + with pytest.raises((embedded_exceptions.IndexOutOfBounds, IndexError)): lazy_slice(field) @@ -403,10 +484,40 @@ def test_field_unsupported_index(index): field[index] -def test_slice_range(): - input_range = UnitRange(2, 10) - slice_obj = slice(2, -2) - expected = UnitRange(4, 8) +@pytest.mark.parametrize( + "index, value", + [ + ((1, 1), 42.0), + ((1, slice(None)), np.ones((10,)) * 42.0), + ( + (1, slice(None)), + common.field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), + ), + ], +) +def test_setitem(index, value): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + expected = np.copy(field.ndarray) + expected[index] = value + + field[index] = value + + assert np.allclose(field.ndarray, expected) + + +def test_setitem_wrong_domain(): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + value_incompatible = common.field( + np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) + ) - result = _slice_range(input_range, slice_obj) - assert result == expected + with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): + field[(1, slice(None))] = value_incompatible diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8cdc96254c..31e35221ab 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -15,7 +15,17 @@ import pytest -from gt4py.next.common import Dimension, DimensionKind, Domain, Infinity, UnitRange, promote_dims +from gt4py.next.common import ( + Dimension, + DimensionKind, + Domain, + Infinity, + UnitRange, + domain, + named_range, + promote_dims, + unit_range, +) IDim = Dimension("IDim") @@ -25,15 +35,8 @@ @pytest.fixture -def domain(): - range1 = UnitRange(0, 10) - range2 = UnitRange(5, 15) - range3 = UnitRange(20, 30) - - dimensions = (IDim, JDim, KDim) - ranges = (range1, range2, range3) - - return Domain(dimensions, ranges) +def a_domain(): + return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) def test_empty_range(): @@ -53,6 +56,11 @@ def test_unit_range_length(rng): assert len(rng) == 10 +@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) +def test_unit_range_like(rng_like): + assert unit_range(rng_like) == UnitRange(2, 4) + + def test_unit_range_repr(rng): assert repr(rng) == "UnitRange(-5, 5)" @@ -142,54 +150,87 @@ def test_mixed_infinity_range(): assert len(mixed_inf_range) == Infinity.positive() -def test_domain_length(domain): - assert len(domain) == 3 +@pytest.mark.parametrize( + "named_rng_like", + [ + (IDim, (2, 4)), + (IDim, range(2, 4)), + (IDim, UnitRange(2, 4)), + ], +) +def test_named_range_like(named_rng_like): + assert named_range(named_rng_like) == (IDim, UnitRange(2, 4)) + + +def test_domain_length(a_domain): + assert len(a_domain) == 3 -def test_domain_iteration(domain): - iterated_values = [val for val in domain] - assert iterated_values == list(zip(domain.dims, domain.ranges)) +@pytest.mark.parametrize( + "domain_like", + [ + (Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)))), + ((IDim, (2, 4)), (JDim, (3, 5))), + ({IDim: (2, 4), JDim: (3, 5)}), + ], +) +def test_domain_like(domain_like): + assert domain(domain_like) == Domain( + dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)) + ) + +def test_domain_iteration(a_domain): + iterated_values = [val for val in a_domain] + assert iterated_values == list(zip(a_domain.dims, a_domain.ranges)) -def test_domain_contains_named_range(domain): - assert (IDim, UnitRange(0, 10)) in domain - assert (IDim, UnitRange(-5, 5)) not in domain + +def test_domain_contains_named_range(a_domain): + assert (IDim, UnitRange(0, 10)) in a_domain + assert (IDim, UnitRange(-5, 5)) not in a_domain @pytest.mark.parametrize( "second_domain, expected", [ ( - Domain((IDim, JDim), (UnitRange(2, 12), UnitRange(7, 17))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30)), + ), ), ( - Domain((IDim, KDim), (UnitRange(2, 12), UnitRange(7, 27))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27))), + Domain(dims=(IDim, KDim), ranges=(UnitRange(2, 12), UnitRange(7, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27)), + ), ), ( - Domain((JDim, KDim), (UnitRange(2, 12), UnitRange(4, 27))), - Domain((IDim, JDim, KDim), (UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27))), + Domain(dims=(JDim, KDim), ranges=(UnitRange(2, 12), UnitRange(4, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27)), + ), ), ], ) -def test_domain_intersection_different_dimensions(domain, second_domain, expected): - result_domain = domain & second_domain +def test_domain_intersection_different_dimensions(a_domain, second_domain, expected): + result_domain = a_domain & second_domain print(result_domain) assert result_domain == expected -def test_domain_intersection_reversed_dimensions(domain): - dimensions = (JDim, IDim) - ranges = (UnitRange(2, 12), UnitRange(7, 17)) - domain2 = Domain(dimensions, ranges) +def test_domain_intersection_reversed_dimensions(a_domain): + domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) with pytest.raises( ValueError, match="Dimensions can not be promoted. The following dimensions appear in contradicting order: IDim, JDim.", ): - domain & domain2 + a_domain & domain2 @pytest.mark.parametrize( @@ -202,8 +243,8 @@ def test_domain_intersection_reversed_dimensions(domain): (-2, (JDim, UnitRange(5, 15))), ], ) -def test_domain_integer_indexing(domain, index, expected): - result = domain[index] +def test_domain_integer_indexing(a_domain, index, expected): + result = a_domain[index] assert result == expected @@ -214,8 +255,8 @@ def test_domain_integer_indexing(domain, index, expected): (slice(1, None), ((JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30)))), ], ) -def test_domain_slice_indexing(domain, slice_obj, expected): - result = domain[slice_obj] +def test_domain_slice_indexing(a_domain, slice_obj, expected): + result = a_domain[slice_obj] assert isinstance(result, Domain) assert len(result) == len(expected) assert all(res == exp for res, exp in zip(result, expected)) @@ -228,28 +269,28 @@ def test_domain_slice_indexing(domain, slice_obj, expected): (KDim, (KDim, UnitRange(20, 30))), ], ) -def test_domain_dimension_indexing(domain, index, expected_result): - result = domain[index] +def test_domain_dimension_indexing(a_domain, index, expected_result): + result = a_domain[index] assert result == expected_result -def test_domain_indexing_dimension_missing(domain): +def test_domain_indexing_dimension_missing(a_domain): with pytest.raises(KeyError, match=r"No Dimension of type .* is present in the Domain."): - domain[ECDim] + a_domain[ECDim] -def test_domain_indexing_invalid_type(domain): +def test_domain_indexing_invalid_type(a_domain): with pytest.raises( KeyError, match="Invalid index type, must be either int, slice, or Dimension." ): - domain["foo"] + a_domain["foo"] def test_domain_repeat_dims(): dims = (IDim, JDim, IDim) ranges = (UnitRange(0, 5), UnitRange(0, 8), UnitRange(0, 3)) with pytest.raises(NotImplementedError, match=r"Domain dimensions must be unique, not .*"): - Domain(dims, ranges) + Domain(dims=dims, ranges=ranges) def test_domain_dims_ranges_length_mismatch(): From 65d1ccbf1d562e3ad22956f4c0e7d2d2abd32b19 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" <18477+egparedes@users.noreply.github.com> Date: Thu, 7 Sep 2023 11:07:49 +0200 Subject: [PATCH 10/67] fix[cartesian]: use distutils from setuptools instead of the standard library (#1334) This PR fixes the imports in the module used to compile python extensions to use the `distutils` versions packaged into `setuptools` instead of the version from the standard library, which is deprecated and causes some warnings at compilation which can break other tools (e.g. spack). --- src/gt4py/cartesian/backend/pyext_builder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index d2aa34bdae..e12669ae0f 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -14,8 +14,6 @@ import contextlib import copy -import distutils -import distutils.sysconfig import io import os import shutil @@ -23,6 +21,7 @@ import pybind11 import setuptools +from setuptools import distutils from setuptools.command.build_ext import build_ext from gt4py.cartesian import config as gt_config From b64fdab8ed2a03ff1306a9cdd1bda6bc468e794f Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 11 Sep 2023 11:39:18 +0200 Subject: [PATCH 11/67] feat[next]: DaCe support for floordiv (#1337) This small PR adds support in DaCe backend for floordiv math built-in function. It also enables some tests for math built-in execution on DaCe backend. --- .../runners/dace_iterator/itir_to_tasklet.py | 1 + .../ffront_tests/test_execution.py | 2 -- .../test_math_builtin_execution.py | 2 -- .../ffront_tests/test_math_unary_builtins.py | 19 ------------------- .../test_with_toy_connectivity.py | 6 ++---- 5 files changed, 3 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 875a23353b..d301c3e3cf 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -96,6 +96,7 @@ def itir_type_as_dace_type(type_: next_typing.Type): "minus": "({} - {})", "multiplies": "({} * {})", "divides": "({} / {})", + "floordiv": "({} // {})", "eq": "({} == {})", "not_eq": "({} != {})", "less": "({} < {})", diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9f284f4041..b44cbb8181 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -781,8 +781,6 @@ def program_domain(a: cases.IField, out: cases.IField): def test_domain_input_bounds(cartesian_case): if cartesian_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: pytest.xfail("FloorDiv not fully supported in gtfn.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: type inference failure") lower_i = 1 upper_i = 10 diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index b484fc6f31..9ceab7f2d0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -117,8 +117,6 @@ def make_builtin_field_operator(builtin_name: str): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inputs): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Bug in type inference with math builtins, breaks dace backend.") if builtin_name == "gamma": # numpy has no gamma function ref_impl: Callable = np.vectorize(math.gamma) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 6c7dbee855..23afc707b0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -60,9 +60,6 @@ def arithmetic(inp1: cases.IFloatField, inp2: cases.IFloatField) -> gtx.Field[[I def test_power(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Bug in type inference with math builtins, breaks dace backend.") - @gtx.field_operator def pow(inp1: cases.IField) -> cases.IField: return inp1**2 @@ -74,7 +71,6 @@ def test_floordiv(cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, - dace_iterator.run_dace_iterator, ]: pytest.xfail( "FloorDiv not yet supported." @@ -201,9 +197,6 @@ def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: def test_basic_trig(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Bug in type inference with math builtins, breaks dace backend.") - @gtx.field_operator def basic_trig_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFloatField: return sin(cos(inp1)) - sinh(cosh(inp2)) + tan(inp1) - tanh(inp2) @@ -219,9 +212,6 @@ def basic_trig_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> case def test_exp_log(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Bug in type inference with math builtins, breaks dace backend.") - @gtx.field_operator def exp_log_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFloatField: return log(inp1) - exp(inp2) @@ -232,9 +222,6 @@ def exp_log_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.I def test_roots(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Bug in type inference with math builtins, breaks dace backend.") - @gtx.field_operator def roots_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFloatField: return sqrt(inp1) - cbrt(inp2) @@ -245,9 +232,6 @@ def roots_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFl def test_is_values(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Bug in type inference with math builtins, breaks dace backend.") - @gtx.field_operator def is_isinf_fieldop(inp1: cases.IFloatField) -> cases.IBoolField: return isinf(inp1) @@ -274,9 +258,6 @@ def is_isfinite_fieldop(inp1: cases.IFloatField) -> cases.IBoolField: def test_rounding_funs(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Bug in type inference with math builtins, breaks dace backend.") - @gtx.field_operator def rounding_funs_fieldop( inp1: cases.IFloatField, inp2: cases.IFloatField diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index b0d04d4379..ded65bceaa 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -168,10 +168,8 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) -def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil( - program_processor_no_dace_exec, lift_mode -): - program_processor, validate = program_processor_no_dace_exec +def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor, lift_mode): + program_processor, validate = program_processor inp = vertex_index_field() out = gtx.np_as_located_field(Cell)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(v2e_arr[c[0]][0] for c in c2e_arr)) From 5157007c1b77f1b8e4aa458941ab3b85c1c01197 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 14 Sep 2023 08:00:19 +0200 Subject: [PATCH 12/67] feat[next]: iterator.embedded with new Field implementation (#1308) Makes iterator.embedded work with the new `common.Field` implementation. - Temporarily `iterator.embedded.np_as_located_field` returns `common.Field`. In a next step all users should switch to a different allocation function. - `iterator.embedded` wraps the `common.Field` into a `LocatedField` wrapper to do some translation between `iterator.embedded` `field_getitem`/`field_setitem`. This layer should eventually be removed. - Adds a minimal (= only with requirements for `iterator.embedded`) implementation of `IndexField` and `ConstantField` to `iterator.embedded`. These can be generalized to proper `Field`s suitable for fieldview embedded. --- src/gt4py/next/iterator/embedded.py | 562 +++++++++++------- src/gt4py/next/iterator/tracing.py | 6 +- .../runners/dace_iterator/__init__.py | 8 +- .../program_processors/runners/gtfn_cpu.py | 7 +- .../program_processors/runners/roundtrip.py | 21 +- .../next/type_system/type_translation.py | 6 +- tests/next_tests/integration_tests/cases.py | 15 +- .../ffront_tests/test_execution.py | 32 +- .../ffront_tests/test_gt4py_builtins.py | 7 +- .../ffront_tests/test_math_unary_builtins.py | 12 +- .../ffront_tests/test_program.py | 6 +- .../ffront_tests/test_scalar_if.py | 2 +- .../iterator_tests/test_builtins.py | 8 + .../iterator_tests/test_conditional.py | 4 +- .../test_horizontal_indirection.py | 4 +- .../iterator_tests/test_implicit_fencil.py | 6 +- .../feature_tests/iterator_tests/test_scan.py | 4 +- .../iterator_tests/test_tuple.py | 151 +---- .../feature_tests/test_util_cases.py | 20 +- .../ffront_tests/test_icon_like_scan.py | 2 +- .../ffront_tests/test_laplacian.py | 8 +- .../iterator_tests/test_column_stencil.py | 102 +++- .../iterator_tests/test_fvm_nabla.py | 48 +- .../iterator_tests/test_vertical_advection.py | 2 +- .../iterator_tests/test_embedded_field.py | 79 --- 25 files changed, 556 insertions(+), 566 deletions(-) delete mode 100644 tests/next_tests/unit_tests/iterator_tests/test_embedded_field.py diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3147ede387..0edea35cf5 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -22,6 +22,7 @@ import dataclasses import itertools import math +import sys from typing import ( Any, Callable, @@ -29,6 +30,7 @@ Iterable, Literal, Mapping, + NoReturn, Optional, Protocol, Sequence, @@ -46,8 +48,10 @@ import numpy as np import numpy.typing as npt +from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.next import common +from gt4py.next.embedded import exceptions as embedded_exceptions from gt4py.next.iterator import builtins, runtime @@ -56,13 +60,12 @@ # Atoms Tag: TypeAlias = str -IntIndex: TypeAlias = int | np.integer -ArrayIndex: TypeAlias = slice | IntIndex +ArrayIndex: TypeAlias = slice | common.IntIndex ArrayIndexOrIndices: TypeAlias = ArrayIndex | tuple[ArrayIndex, ...] FieldIndex: TypeAlias = ( - range | slice | IntIndex + range | slice | common.IntIndex ) # A `range` FieldIndex can be negative indicating a relative position with respect to origin, not wrap-around semantics like `slice` TODO(havogt): remove slice here FieldIndices: TypeAlias = tuple[FieldIndex, ...] FieldIndexOrIndices: TypeAlias = FieldIndex | FieldIndices @@ -96,7 +99,9 @@ def __init__( self.has_skip_values = has_skip_values self.index_type = table.dtype - def mapped_index(self, primary: IntIndex, neighbor_idx: IntIndex) -> IntIndex: + def mapped_index( + self, primary: common.IntIndex, neighbor_idx: common.IntIndex + ) -> common.IntIndex: return self.table[(primary, neighbor_idx)] @@ -114,21 +119,23 @@ def __init__( self.has_skip_values = has_skip_values self.index_type = int - def mapped_index(self, primary: IntIndex, neighbor_idx: IntIndex) -> IntIndex: + def mapped_index( + self, primary: common.IntIndex, neighbor_idx: common.IntIndex + ) -> common.IntIndex: return primary * self.max_neighbors + neighbor_idx # Offsets -OffsetPart: TypeAlias = Tag | IntIndex -CompleteOffset: TypeAlias = tuple[Tag, IntIndex] +OffsetPart: TypeAlias = Tag | common.IntIndex +CompleteOffset: TypeAlias = tuple[Tag, common.IntIndex] OffsetProviderElem: TypeAlias = common.Dimension | common.Connectivity OffsetProvider: TypeAlias = dict[Tag, OffsetProviderElem] # Positions SparsePositionEntry = list[int] IncompleteSparsePositionEntry: TypeAlias = list[Optional[int]] -PositionEntry: TypeAlias = SparsePositionEntry | IntIndex -IncompletePositionEntry: TypeAlias = IncompleteSparsePositionEntry | IntIndex +PositionEntry: TypeAlias = SparsePositionEntry | common.IntIndex +IncompletePositionEntry: TypeAlias = IncompleteSparsePositionEntry | common.IntIndex ConcretePosition: TypeAlias = dict[Tag, PositionEntry] IncompletePosition: TypeAlias = dict[Tag, IncompletePositionEntry] @@ -139,17 +146,6 @@ def mapped_index(self, primary: IntIndex, neighbor_idx: IntIndex) -> IntIndex: NamedFieldIndices: TypeAlias = Mapping[Tag, FieldIndex | SparsePositionEntry] -def is_int_index(p: Any) -> TypeGuard[IntIndex]: - return isinstance(p, (int, np.integer)) - - -def _tupelize(tup): - if isinstance(tup, tuple): - return tup - else: - return (tup,) - - @runtime_checkable class ItIterator(Protocol): """ @@ -481,7 +477,7 @@ def promote_scalars(val: CompositeOfScalarOrField): """Given a scalar, field or composite thereof promote all (contained) scalars to fields.""" if isinstance(val, tuple): return tuple(promote_scalars(el) for el in val) - elif isinstance(val, LocatedField): + elif common.is_field(val): return val val_type = infer_dtype_like_type(val) if isinstance(val, Scalar): # type: ignore # mypy bug @@ -509,20 +505,6 @@ def promote_scalars(val: CompositeOfScalarOrField): globals()[math_builtin_name] = decorator(impl) -def _lookup_offset_provider(offset_provider: OffsetProvider, tag: Tag) -> OffsetProviderElem: - if tag not in offset_provider: - raise RuntimeError(f"Missing offset provider for `{tag}`") - return offset_provider[tag] - - -def _get_connectivity(offset_provider: OffsetProvider, tag: Tag) -> common.Connectivity: - if not isinstance( - connectivity := _lookup_offset_provider(offset_provider, tag), common.Connectivity - ): - raise RuntimeError(f"Expected a `Connectivity` for `{tag}`") - return connectivity - - def _named_range(axis: str, range_: Iterable[int]) -> Iterable[CompleteOffset]: return ((axis, i) for i in range_) @@ -535,7 +517,7 @@ def _domain_iterator(domain: dict[Tag, range]) -> Iterable[Position]: def execute_shift( - pos: Position, tag: Tag, index: IntIndex, *, offset_provider: OffsetProvider + pos: Position, tag: Tag, index: common.IntIndex, *, offset_provider: OffsetProvider ) -> MaybePosition: assert pos is not None if isinstance(tag, SparseTag): @@ -558,7 +540,7 @@ def execute_shift( offset_implementation = offset_provider[tag] if isinstance(offset_implementation, common.Dimension): new_pos = copy.copy(pos) - if is_int_index(value := new_pos[offset_implementation.value]): + if common.is_int_index(value := new_pos[offset_implementation.value]): new_pos[offset_implementation.value] = value + index else: raise AssertionError() @@ -569,7 +551,7 @@ def execute_shift( new_pos = pos.copy() new_pos.pop(offset_implementation.origin_axis.value) cur_index = pos[offset_implementation.origin_axis.value] - assert is_int_index(cur_index) + assert common.is_int_index(cur_index) if offset_implementation.mapped_index(cur_index, index) in [ None, -1, @@ -623,6 +605,12 @@ class Undefined: def __float__(self): return np.nan + def __int__(self): + return sys.maxsize + + def __repr__(self): + return "_UNDEFINED" + @classmethod def _setup_math_operations(cls): ops = [ @@ -694,17 +682,18 @@ def _get_axes( def _single_vertical_idx( - indices: NamedFieldIndices, column_axis: Tag, column_index: IntIndex + indices: NamedFieldIndices, column_axis: Tag, column_index: common.IntIndex ) -> NamedFieldIndices: transformed = { - axis: (index if axis != column_axis else column_index) for axis, index in indices.items() + axis: (index if axis != column_axis else index.start + column_index) # type: ignore[union-attr] # trust me, `index` is range in case of `column_axis` + for axis, index in indices.items() } return transformed @overload def _make_tuple( - field_or_tuple: tuple[tuple | LocatedField, ...], # arbitrary nesting of tuples of LocatedField + field_or_tuple: tuple[tuple | LocatedField, ...], # arbitrary nesting of tuples of Field named_indices: NamedFieldIndices, *, column_axis: Tag, @@ -714,11 +703,11 @@ def _make_tuple( @overload def _make_tuple( - field_or_tuple: tuple[tuple | LocatedField, ...], # arbitrary nesting of tuples of LocatedField + field_or_tuple: tuple[tuple | LocatedField, ...], # arbitrary nesting of tuples of Field named_indices: NamedFieldIndices, *, column_axis: Literal[None] = None, -) -> tuple[tuple | npt.DTypeLike, ...]: # arbitrary nesting +) -> tuple[tuple | npt.DTypeLike | Undefined, ...]: # arbitrary nesting ... @@ -735,7 +724,7 @@ def _make_tuple( named_indices: NamedFieldIndices, *, column_axis: Literal[None] = None, -) -> npt.DTypeLike: +) -> npt.DTypeLike | Undefined: ... @@ -744,48 +733,57 @@ def _make_tuple( named_indices: NamedFieldIndices, *, column_axis: Optional[Tag] = None, -) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike, ...]: - column_range = column_range_cvar.get() - if isinstance(field_or_tuple, tuple): - if column_axis is not None: - assert column_range - # construct a Column of tuples - first = tuple( - _make_tuple(f, _single_vertical_idx(named_indices, column_axis, column_range.start)) - for f in field_or_tuple - ) - col = Column( - column_range.start, np.zeros(len(column_range), dtype=_column_dtype(first)) - ) - col[0] = first - for i in column_range[1:]: - col[i] = tuple( - _make_tuple(f, _single_vertical_idx(named_indices, column_axis, i)) - for f in field_or_tuple - ) - return col - else: +) -> Column | npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike | Undefined, ...] | Undefined: + if column_axis is None: + if isinstance(field_or_tuple, tuple): return tuple(_make_tuple(f, named_indices) for f in field_or_tuple) - else: - data = field_or_tuple.field_getitem(named_indices) - if column_axis is not None: - # wraps a vertical slice of an input field into a `Column` - assert column_range is not None - return Column(column_range.start, data) else: - return data + try: + data = field_or_tuple.field_getitem(named_indices) + return data + except embedded_exceptions.IndexOutOfBounds: + return _UNDEFINED + else: + column_range = column_range_cvar.get() + assert column_range is not None + col: list[ + npt.DTypeLike | tuple[tuple | Column | npt.DTypeLike | Undefined, ...] | Undefined + ] = [] + for i in column_range: + # we don't know the buffer size, therefore we have to try. + try: + col.append( + tuple( + _make_tuple( + f, + _single_vertical_idx( + named_indices, column_axis, i - column_range.start + ), + ) + for f in field_or_tuple + ) + if isinstance(field_or_tuple, tuple) + else _make_tuple( + field_or_tuple, + _single_vertical_idx(named_indices, column_axis, i - column_range.start), + ) + ) + except embedded_exceptions.IndexOutOfBounds: + col.append(_UNDEFINED) -def _axis_idx(axes: Sequence[common.Dimension], axis: Tag) -> Optional[int]: - for i, a in enumerate(axes): - if a.value == axis: - return i - return None + first = next((v for v in col if v != _UNDEFINED), None) + if first is None: + raise RuntimeError( + "Found 'Undefined' value, this should not happen for a legal program." + ) + dtype = _column_dtype(first) + return Column(column_range.start, np.asarray(col, dtype=dtype)) @dataclasses.dataclass(frozen=True) class MDIterator: - field: LocatedField + field: LocatedField | tuple[LocatedField | tuple, ...] # arbitrary nesting pos: MaybePosition column_axis: Optional[Tag] = dataclasses.field(default=None, kw_only=True) @@ -842,12 +840,21 @@ def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[Tag]: ] +def _wrap_field(field: common.Field | tuple) -> NDArrayLocatedFieldWrapper | tuple: + if isinstance(field, tuple): + return tuple(_wrap_field(f) for f in field) + else: + assert common.is_field(field) + return NDArrayLocatedFieldWrapper(field) + + def make_in_iterator( - inp: LocatedField, + inp_: common.Field, pos: Position, *, column_axis: Optional[Tag], ) -> ItIterator: + inp = _wrap_field(inp_) axes = _get_axes(inp) sparse_dimensions = _get_sparse_dimensions(axes) new_pos: Position = pos.copy() @@ -878,60 +885,49 @@ def make_in_iterator( builtins.builtin_dispatch.push_key(EMBEDDED) # makes embedded the default -class LocatedFieldImpl(MutableLocatedField): - """A Field with named dimensions/axes.""" +@dataclasses.dataclass(frozen=True) +class NDArrayLocatedFieldWrapper(MutableLocatedField): + """A temporary helper until we sorted out all Field conventions between frontend and iterator.embedded.""" + + _ndarrayfield: common.Field @property def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._axes + return self._ndarrayfield.__gt_dims__ + + def _translate_named_indices( + self, _named_indices: NamedFieldIndices + ) -> common.AbsoluteIndexSequence: + named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = { + d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__ + } + domain_slice: list[common.NamedRange | common.NamedIndex] = [] + for d, v in named_indices.items(): + if isinstance(v, range): + domain_slice.append((d, common.UnitRange(v.start, v.stop))) + elif isinstance(v, list): + assert len(v) == 1 # only 1 sparse dimension is supported + assert common.is_int_index( + v[0] + ) # derefing a concrete element in a sparse field, not a slice + domain_slice.append((d, v[0])) + else: + assert common.is_int_index(v) + domain_slice.append((d, v)) + return tuple(domain_slice) - def __init__( - self, - getter: Callable[[FieldIndexOrIndices], Any], - axes: tuple[common.Dimension, ...], - dtype, - *, - setter: Callable[[FieldIndexOrIndices, Any], None], - array: Callable[[], npt.NDArray], - origin: Optional[dict[common.Dimension, int]] = None, - ): - self.getter = getter - self._axes = axes - self.setter = setter - self.array = array - self.dtype = dtype - self.origin = origin - - def __getitem__(self, indices: ArrayIndexOrIndices) -> Any: - return self.array()[indices] - - # TODO in a stable implementation of the Field concept we should make this behavior the default behavior for __getitem__ def field_getitem(self, named_indices: NamedFieldIndices) -> Any: - return self.getter(get_ordered_indices(self._axes, named_indices)) - - def __setitem__(self, indices: ArrayIndexOrIndices, value: Any): - self.array()[indices] = value + return self._ndarrayfield[self._translate_named_indices(named_indices)] def field_setitem(self, named_indices: NamedFieldIndices, value: Any): - self.setter(get_ordered_indices(self._axes, named_indices), value) - - def __array__(self) -> np.ndarray: - return self.array() + if common.is_mutable_field(self._ndarrayfield): + self._ndarrayfield[self._translate_named_indices(named_indices)] = value + else: + raise RuntimeError("Assigment into a non-mutable Field.") @property def __gt_origin__(self) -> tuple[int, ...]: - if not self.origin: - return tuple([0] * len(self.__gt_dims__)) - return cast( - tuple[int], - get_ordered_indices(self.__gt_dims__, {k.value: v for k, v in self.origin.items()}), - ) - - @property - def shape(self): - if self.array is None: - raise TypeError("`shape` not supported for this field") - return self.array().shape + return self._ndarrayfield.__gt_origin__ def _is_field_axis(axis: Axis) -> TypeGuard[FieldAxis]: @@ -975,17 +971,17 @@ def _shift_range(range_or_index: range, offset: int) -> slice: @overload -def _shift_range(range_or_index: IntIndex, offset: int) -> IntIndex: +def _shift_range(range_or_index: common.IntIndex, offset: int) -> common.IntIndex: ... -def _shift_range(range_or_index: range | IntIndex, offset: int) -> ArrayIndex: +def _shift_range(range_or_index: range | common.IntIndex, offset: int) -> ArrayIndex: if isinstance(range_or_index, range): # range_or_index describes a range in the field assert range_or_index.step == 1 return slice(range_or_index.start + offset, range_or_index.stop + offset) else: - assert is_int_index(range_or_index) + assert common.is_int_index(range_or_index) return range_or_index + offset @@ -995,11 +991,11 @@ def _range2slice(r: range) -> slice: @overload -def _range2slice(r: IntIndex) -> IntIndex: +def _range2slice(r: common.IntIndex) -> common.IntIndex: ... -def _range2slice(r: range | IntIndex) -> slice | IntIndex: +def _range2slice(r: range | common.IntIndex) -> slice | common.IntIndex: if isinstance(r, range): assert r.start >= 0 and r.stop >= r.start return slice(r.start, r.stop) @@ -1007,7 +1003,7 @@ def _range2slice(r: range | IntIndex) -> slice | IntIndex: def _shift_field_indices( - ranges_or_indices: tuple[range | IntIndex, ...], + ranges_or_indices: tuple[range | common.IntIndex, ...], offsets: tuple[int, ...], ) -> tuple[ArrayIndex, ...]: return tuple( @@ -1018,74 +1014,231 @@ def _shift_field_indices( def np_as_located_field( *axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None -) -> Callable[[np.ndarray], LocatedFieldImpl]: - def _maker(a: np.ndarray) -> LocatedFieldImpl: +) -> Callable[[np.ndarray], common.Field]: + origin = origin or {} + + def _maker(a) -> common.Field: if a.ndim != len(axes): - raise TypeError("ndarray.ndim incompatible with number of given axes") + raise TypeError("ndarray.ndim incompatible with number of given dimensions") + ranges = [] + for d, s in zip(axes, a.shape): + offset = origin.get(d, 0) + ranges.append(common.UnitRange(-offset, s - offset)) - if origin is not None: - offsets = get_ordered_indices(axes, {k.value: v for k, v in origin.items()}) - else: - offsets = None - - def setter(indices, value): - indices = _tupelize(indices) - a[_shift_field_indices(indices, offsets) if offsets else indices] = value - - def getter(indices): - return a[_shift_field_indices(indices, offsets) if offsets else indices] - - return LocatedFieldImpl( - getter, - axes, - dtype=a.dtype, - setter=setter, - array=a.__array__, - origin=origin, - ) + res = common.field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges))) + return res return _maker -class IndexField(LocatedField): - def __init__(self, axis: common.Dimension, dtype: npt.DTypeLike) -> None: - self.axis = axis - self.dtype = np.dtype(dtype) +@dataclasses.dataclass(frozen=True) +class IndexField(common.Field): + """ + Minimal index field implementation. - def field_getitem(self, named_indices: NamedFieldIndices) -> Any: - index = get_ordered_indices(self.__gt_dims__, named_indices) - if isinstance(index, int): - return self.dtype.type(index) - else: - assert isinstance(index, tuple) and len(index) == 1 and isinstance(index[0], int) - return self.dtype.type(index[0]) + TODO: Improve implementation (e.g. support slicing) and move out of this module. + """ + + _dimension: common.Dimension + + @property + def __gt_dims__(self) -> tuple[common.Dimension, ...]: + return (self._dimension,) + + @property + def __gt_origin__(self) -> tuple[int, ...]: + return (0,) + + @classmethod + def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype + raise NotImplementedError() + + @property + def domain(self) -> common.Domain: + return common.Domain((self._dimension, common.UnitRange.infinity())) + + @property + def dtype(self) -> core_defs.Int32DType: + return core_defs.Int32DType() @property - def __gt_dims__(self) -> tuple[common.Dimension]: - return (self.axis,) + def ndarray(self) -> core_defs.NDArrayObject: + return AttributeError("Cannot get `ndarray` of an infinite Field.") + + def remap(self, index_field: common.Field) -> common.Field: + # TODO can be implemented by constructing and ndarray (but do we know of which kind?) + raise NotImplementedError() + + def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32: + if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code + d, r = item[0] + assert d == self._dimension + assert isinstance(r, int) + return self.dtype.scalar_type(r) + # TODO set a domain... + raise NotImplementedError() + + __call__ = remap + __getitem__ = restrict + def __abs__(self) -> common.Field: + raise NotImplementedError() -def index_field(axis: common.Dimension, dtype: npt.DTypeLike = np.int32) -> LocatedField: - return IndexField(axis, dtype) + def __neg__(self) -> common.Field: + raise NotImplementedError() + def __invert__(self) -> common.Field: + raise NotImplementedError() -class ConstantField(LocatedField): - def __init__(self, value: Any, dtype: npt.DTypeLike): - self.value = value - self.dtype = np.dtype(dtype).type + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() - def field_getitem(self, _: NamedFieldIndices) -> Any: - return self.dtype(self.value) + def __radd__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __sub__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rsub__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __mul__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rmul__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __floordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rfloordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __truediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __pow__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + +def index_field(axis: common.Dimension) -> common.Field: + return IndexField(axis) + + +@dataclasses.dataclass(frozen=True) +class ConstantField(common.Field[Any, core_defs.ScalarT]): + """ + Minimal constant field implementation. + + TODO: Improve implementation (e.g. support slicing) and move out of this module. + """ + + _value: core_defs.ScalarT + + @property + def __gt_dims__(self) -> tuple[common.Dimension, ...]: + return tuple() + + @property + def __gt_origin__(self) -> tuple[int, ...]: + return tuple() + + @classmethod + def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype + raise NotImplementedError() + + @property + def domain(self) -> common.Domain: + return common.Domain(dims=(), ranges=()) + + @property + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: + return core_defs.dtype(type(self._value)) @property - def __gt_dims__(self) -> tuple[()]: - return () + def ndarray(self) -> core_defs.NDArrayObject: + return AttributeError("Cannot get `ndarray` of an infinite Field.") + def remap(self, index_field: common.Field) -> common.Field: + # TODO can be implemented by constructing and ndarray (but do we know of which kind?) + raise NotImplementedError() -def constant_field(value: Any, dtype: Optional[npt.DTypeLike] = None) -> LocatedField: - if dtype is None: - dtype = infer_dtype_like_type(value) - return ConstantField(value, dtype) + def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + # TODO set a domain... + return self._value + + __call__ = remap + __getitem__ = restrict + + def __abs__(self) -> common.Field: + raise NotImplementedError() + + def __neg__(self) -> common.Field: + raise NotImplementedError() + + def __invert__(self) -> common.Field: + raise NotImplementedError() + + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __radd__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __sub__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rsub__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __mul__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rmul__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __floordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rfloordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __truediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __pow__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + +def constant_field(value: Any, dtype_like: Optional[core_defs.DTypeLike] = None) -> common.Field: + if dtype_like is None: + dtype_like = infer_dtype_like_type(value) + dtype = core_defs.dtype(dtype_like) + return ConstantField(dtype.scalar_type(value)) @builtins.shift.register(EMBEDDED) @@ -1216,42 +1369,27 @@ def shift(self, *offsets: OffsetPart) -> ScanArgIterator: def shifted_scan_arg(k_pos: int) -> Callable[[ItIterator], ScanArgIterator]: def impl(it: ItIterator) -> ScanArgIterator: - return ScanArgIterator(it, k_pos=k_pos) + return ScanArgIterator(it, k_pos=k_pos) # here we evaluate the full column in every step return impl def is_located_field(field: Any) -> TypeGuard[LocatedField]: - return isinstance(field, LocatedField) # TODO(havogt): avoid isinstance on Protocol + return isinstance(field, LocatedField) def is_mutable_located_field(field: Any) -> TypeGuard[MutableLocatedField]: - return isinstance(field, MutableLocatedField) # TODO(havogt): avoid isinstance on Protocol - - -def has_uniform_tuple_element(field) -> bool: - return field.dtype.fields is not None and all( - next(iter(field.dtype.fields))[0] == f[0] for f in iter(field.dtype.fields) - ) + return isinstance(field, MutableLocatedField) def is_tuple_of_field(field) -> bool: return isinstance(field, tuple) and all( - is_located_field(f) or is_tuple_of_field(f) for f in field + common.is_field(f) or is_tuple_of_field(f) for f in field ) -def is_field_of_tuple(field) -> bool: - return is_located_field(field) and has_uniform_tuple_element(field) - - -def can_be_tuple_field(field) -> bool: - return is_tuple_of_field(field) or is_field_of_tuple(field) - - class TupleFieldMeta(type): - def __instancecheck__(self, arg): - return super().__instancecheck__(arg) or is_field_of_tuple(arg) + ... class TupleField(metaclass=TupleFieldMeta): @@ -1283,8 +1421,6 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: class TupleOfFields(TupleField): def __init__(self, data): - if not is_tuple_of_field(data): - raise TypeError("Can only be instantiated with a tuple of fields") self.data = data self.__gt_dims__ = _get_axes(data) @@ -1298,13 +1434,9 @@ def field_setitem(self, named_indices: NamedFieldIndices, value: Any): def as_tuple_field(field: tuple | TupleField) -> TupleField: - assert can_be_tuple_field(field) - - if is_tuple_of_field(field): - return TupleOfFields(field) - - assert isinstance(field, TupleField) # e.g. field of tuple is already TupleField - return field + assert is_tuple_of_field(field) + assert not isinstance(field, TupleField) + return TupleOfFields(tuple(_wrap_field(f) for f in field)) def _column_dtype(elem: Any) -> np.dtype: @@ -1355,12 +1487,12 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): def closure( domain_: Domain, sten: Callable[..., Any], - out: MutableLocatedField, - ins: list[LocatedField], + out, #: MutableLocatedField, + ins: list[common.Field], ) -> None: _validate_domain(domain_, kwargs["offset_provider"]) domain: dict[Tag, range] = _dimension_to_tag(domain_) - if not (is_located_field(out) or can_be_tuple_field(out)): + if not (common.is_field(out) or is_tuple_of_field(out)): raise TypeError("Out needs to be a located field.") column_range = None @@ -1372,13 +1504,7 @@ def closure( column_range = column.col_range - out = ( - as_tuple_field( # type:ignore[assignment] - out # type:ignore[arg-type] # TODO(havogt) improve the code around TupleField construction - ) - if can_be_tuple_field(out) - else out - ) + out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) def _closure_runner(): # Set context variables before executing the closure diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index b26f59cfd4..fbe6a2ae82 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -21,7 +21,6 @@ from gt4py.eve import Node from gt4py.next import common, iterator from gt4py.next.iterator import builtins, ir_makers as im -from gt4py.next.iterator.embedded import LocatedField from gt4py.next.iterator.ir import ( AxisLiteral, Expr, @@ -253,9 +252,8 @@ def _contains_tuple_dtype_field(arg): # various implementations have different behaviour (some return e.g. `np.dtype("int32")` # other `np.int32`). We just ignore the error here and postpone fixing this to when # the new storages land (The implementation here works for LocatedFieldImpl). - return isinstance(arg, LocatedField) and ( - arg.dtype.fields is not None or any(dim is None for dim in arg.__gt_dims__) - ) + + return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__) def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 718537713e..f78d90095c 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -18,7 +18,8 @@ import numpy as np import gt4py.next.iterator.ir as itir -from gt4py.next.iterator.embedded import LocatedField, NeighborTableOffsetProvider +from gt4py.next import common +from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.processor_interface import program_executor @@ -29,11 +30,12 @@ def convert_arg(arg: Any): - if isinstance(arg, LocatedField): + if common.is_field(arg): sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) ndim = len(sorted_dims) dim_indices = [dim[0] for dim in sorted_dims] - return np.moveaxis(np.asarray(arg), range(ndim), dim_indices) + assert isinstance(arg.ndarray, np.ndarray) + return np.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg diff --git a/src/gt4py/next/program_processors/runners/gtfn_cpu.py b/src/gt4py/next/program_processors/runners/gtfn_cpu.py index 195126a6ba..72f2637d87 100644 --- a/src/gt4py/next/program_processors/runners/gtfn_cpu.py +++ b/src/gt4py/next/program_processors/runners/gtfn_cpu.py @@ -14,7 +14,6 @@ from typing import Any -import numpy as np import numpy.typing as npt from gt4py.eve.utils import content_hash @@ -32,9 +31,9 @@ def convert_arg(arg: Any) -> Any: if isinstance(arg, tuple): return tuple(convert_arg(a) for a in arg) - if hasattr(arg, "__array__") and hasattr(arg, "__gt_dims__"): - arr = np.asarray(arg) - origin = getattr(arg, "__gt_origin__", tuple([0] * arr.ndim)) + if common.is_field(arg): + arr = arg.ndarray + origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain))) return arr, origin else: return arg diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 85030482cb..3560384eb4 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -30,6 +30,13 @@ from gt4py.next.program_processors.processor_interface import program_executor +def _create_tmp(axes, origin, shape, dtype): + if isinstance(dtype, tuple): + return f"({','.join(_create_tmp(axes, origin, shape, dt) for dt in dtype)},)" + else: + return f"gtx.np_as_located_field({axes}, origin={origin})(np.empty({shape}, dtype=np.dtype('{dtype}')))" + + class EmbeddedDSL(codegen.TemplatedGenerator): Sym = as_fmt("{id}") SymRef = as_fmt("{id}") @@ -60,14 +67,7 @@ def ${id}(${','.join(params)}): def visit_FencilWithTemporaries(self, node, **kwargs): params = self.visit(node.params) - def np_dtype(dtype): - if isinstance(dtype, int): - return params[dtype] + ".dtype" - if isinstance(dtype, tuple): - return "np.dtype([" + ", ".join(f"('', {np_dtype(d)})" for d in dtype) + "])" - return f"np.dtype('{dtype}')" - - tmps = "\n ".join(self.visit(node.tmps, np_dtype=np_dtype)) + tmps = "\n ".join(self.visit(node.tmps)) args = ", ".join(params + [tmp.id for tmp in node.tmps]) params = ", ".join(params) fencil = self.visit(node.fencil) @@ -79,7 +79,7 @@ def np_dtype(dtype): + f"\n {node.fencil.id}({args}, **kwargs)\n" ) - def visit_Temporary(self, node, *, np_dtype, **kwargs): + def visit_Temporary(self, node, **kwargs): assert isinstance(node.domain, itir.FunCall) and node.domain.fun.id in ( "cartesian_domain", "unstructured_domain", @@ -92,8 +92,7 @@ def visit_Temporary(self, node, *, np_dtype, **kwargs): axes = ", ".join(label for label, _, _ in domain_ranges) origin = "{" + ", ".join(f"{label}: -{start}" for label, start, _ in domain_ranges) + "}" shape = "(" + ", ".join(f"{stop}-{start}" for _, start, stop in domain_ranges) + ")" - dtype = np_dtype(node.dtype) - return f"{node.id} = gtx.np_as_located_field({axes}, origin={origin})(np.empty({shape}, dtype={dtype}))" + return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" _BACKEND_NAME = "roundtrip" diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 047e5bbc5f..007a83844c 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -161,8 +161,6 @@ def from_type_hint( def from_value(value: Any) -> ts.TypeSpec: # TODO(tehrengruber): use protocol from gt4py.next.common when available # instead of importing from the embedded implementation - from gt4py.next.iterator.embedded import LocatedField - """Make a symbol node from a Python value.""" # TODO(tehrengruber): What we expect here currently is a GTCallable. Maybe # we should check for the protocol in the future? @@ -185,9 +183,9 @@ def from_value(value: Any) -> ts.TypeSpec: return candidate_type elif isinstance(value, common.Dimension): symbol_type = ts.DimensionType(dim=value) - elif isinstance(value, LocatedField): + elif common.is_field(value): dims = list(value.__gt_dims__) - dtype = from_type_hint(value.dtype.type) + dtype = from_type_hint(value.dtype.scalar_type) symbol_type = ts.FieldType(dims=dims, dtype=dtype) elif isinstance(value, tuple): # Since the elements of the tuple might be one of the special cases diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d5edfdee6a..ee0074e65f 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -29,7 +29,6 @@ from gt4py.eve.extended_typing import Self from gt4py.next import common from gt4py.next.ffront import decorator -from gt4py.next.iterator import embedded from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -75,7 +74,7 @@ C2E = gtx.FieldOffset("E2V", source=Edge, target=(Cell, C2EDim)) ScalarValue: TypeAlias = np.int32 | np.int64 | np.float32 | np.float64 | np.generic -FieldValue: TypeAlias = gtx.Field | embedded.LocatedFieldImpl +FieldValue: TypeAlias = gtx.Field FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...] FieldViewInout: TypeAlias = FieldValue | tuple["FieldViewInout", ...] ReferenceValue: TypeAlias = ( @@ -341,7 +340,9 @@ def allocate( Useful for shifted fields, which must start off bigger than the output field in the shifted dimension. """ - sizes = extend_sizes(case.default_sizes | (sizes or {}), extend) + sizes = extend_sizes( + case.default_sizes | (sizes or {}), extend + ) # TODO: this should take into account the Domain of the allocated field arg_type = get_param_types(fieldview_prog)[name] if strategy is None: if name in ["out", RETURN]: @@ -421,8 +422,8 @@ def verify( out_comp = out or inout out_comp_str = str(out_comp) assert out_comp is not None - if hasattr(out_comp, "array"): - out_comp_str = str(out_comp.array()) + if hasattr(out_comp, "ndarray"): + out_comp_str = str(out_comp.ndarray) assert comparison(ref, out_comp), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" @@ -447,12 +448,12 @@ def verify_with_default_data( case: The test case. fieldview_prog: The field operator or program to be verified. ref: A callable which will be called with all the input arguments - of the fieldview code, after applying ``.array()`` on the fields. + of the fieldview code, after applying ``.ndarray`` on the fields. comparison: A comparison function, which will be called as ``comparison(ref, )`` and should return a boolean. """ inps, kwfields = get_default_data(case, fieldop) - ref_args = tuple(i.array() if hasattr(i, "array") else i for i in inps) + ref_args = tuple(i.ndarray if hasattr(i, "ndarray") else i for i in inps) verify( case, fieldop, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index b44cbb8181..e425457224 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -158,7 +158,7 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: b = cases.allocate(cartesian_case, testee, "b").extend({cases.IDim: (0, 2)})() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - cases.verify(cartesian_case, testee, a, b, out=out, ref=a[1:] + b[2:]) + cases.verify(cartesian_case, testee, a, b, out=out, ref=a.ndarray[1:] + b.ndarray[2:]) def test_tuples(cartesian_case): # noqa: F811 # fixtures @@ -223,7 +223,7 @@ def testee(a: cases.IJKField, b: int32) -> cases.IJKField: a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, 1)})() b = cases.allocate(cartesian_case, testee, "b")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = a.array()[1:] * b + ref = a[1:] * b cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) @@ -250,7 +250,7 @@ def testee(size: gtx.IndexType, out: gtx.Field[[IDim], gtx.IndexType]): testee, size, out=out, - ref=np.full_like(out.array(), size, dtype=gtx.IndexType), + ref=np.full_like(out, size, dtype=gtx.IndexType), ) @@ -410,7 +410,7 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD comparison=lambda out, ref: np.all(out == ref), ) - assert np.allclose(out.array(), ref) + assert np.allclose(out, ref) def test_nested_tuple_return(cartesian_case): @@ -524,7 +524,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I return 3 * a[0][0] + a[0][1] + a[1] cases.verify_with_default_data( - cartesian_case, testee, ref=lambda a: 3 * a[0][0].array() + a[0][1].array() + a[1].array() + cartesian_case, testee, ref=lambda a: 3 * a[0][0] + a[0][1] + a[1] ) @@ -695,9 +695,9 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): cartesian_case, testee, ref=lambda: (expected + 1.0, (expected + 2.0, expected + 3.0)), - comparison=lambda ref, out: np.all(out[0].array() == ref[0]) - and np.all(out[1][0].array() == ref[1][0]) - and np.all(out[1][1].array() == ref[1][1]), + comparison=lambda ref, out: np.all(out[0] == ref[0]) + and np.all(out[1][0] == ref[1][0]) + and np.all(out[1][1] == ref[1][1]), ) @@ -754,7 +754,7 @@ def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cas a = cases.allocate(cartesian_case, program_bound_args, "a")() scalar = int32(1) - ref = a.array() + a.array() + 1 + ref = a + a + 1 out = cases.allocate(cartesian_case, program_bound_args, "out")() prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) @@ -773,9 +773,7 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - cases.verify( - cartesian_case, program_domain, a, out, inout=out.array()[1:9], ref=a.array()[1:9] * 2 - ) + cases.verify(cartesian_case, program_domain, a, out, inout=out[1:9], ref=a[1:9] * 2) def test_domain_input_bounds(cartesian_case): @@ -809,8 +807,8 @@ def program_domain( out, lower_i, upper_i, - inout=out.array()[lower_i : int(upper_i / 2)], - ref=inp.array()[lower_i : int(upper_i / 2)] * 2, + inout=out[lower_i : int(upper_i / 2)], + ref=inp[lower_i : int(upper_i / 2)] * 2, ) @@ -851,8 +849,8 @@ def program_domain( upper_i, lower_j, upper_j, - inout=out.array()[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j], - ref=a.array()[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2, + inout=out[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j], + ref=a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2, ) @@ -888,7 +886,7 @@ def program_domain_tuple( out0, out1, inout=(out0[1:9, 4:6], out1[1:9, 4:6]), - ref=(inp0.array()[1:9, 4:6] + inp1.array()[1:9, 4:6], inp1.array()[1:9, 4:6]), + ref=(inp0[1:9, 4:6] + inp1[1:9, 4:6], inp1[1:9, 4:6]), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 5f19311a32..de694eefcd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,6 +46,8 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): + if unstructured_case.backend == dace_iterator.run_dace_iterator: + pytest.xfail("Not supported in DaCe backend: reductions") if unstructured_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: pytest.xfail("`maxover` broken in gtfn, see #1289.") @@ -58,11 +60,14 @@ def testee(edge_f: cases.EField) -> cases.VField: out = cases.allocate(unstructured_case, testee, cases.RETURN)() v2e_table = unstructured_case.offset_provider["V2E"].table - ref = np.max(inp[v2e_table], axis=1) + ref = np.max(inp.ndarray[v2e_table], axis=1) cases.verify(unstructured_case, testee, inp, ref=ref, out=out) def test_minover_execution(unstructured_case): + if unstructured_case.backend == dace_iterator.run_dace_iterator: + pytest.xfail("Not supported in DaCe backend: reductions") + @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 23afc707b0..e8de3d9264 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -100,7 +100,7 @@ def mod_fieldop(inp1: cases.IField) -> cases.IField: inp1 = gtx.np_as_located_field(IDim)(np.asarray(range(10), dtype=int32) - 5) out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)() - cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1.array() % 2) + cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) def test_bit_xor(cartesian_case): @@ -117,7 +117,7 @@ def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolFie cases.ConstInitializer(bool_field) )() out = cases.allocate(cartesian_case, binary_xor, cases.RETURN)() - cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1.array() ^ inp2.array()) + cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1 ^ inp2) def test_bit_and(cartesian_case): @@ -134,7 +134,7 @@ def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: cases.ConstInitializer(bool_field) )() out = cases.allocate(cartesian_case, bit_and, cases.RETURN)() - cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1.array() & inp2.array()) + cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1 & inp2) def test_bit_or(cartesian_case): @@ -151,7 +151,7 @@ def bit_or(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: cases.ConstInitializer(bool_field) )() out = cases.allocate(cartesian_case, bit_or, cases.RETURN)() - cases.verify(cartesian_case, bit_or, inp1, inp2, out=out, ref=inp1.array() | inp2.array()) + cases.verify(cartesian_case, bit_or, inp1, inp2, out=out, ref=inp1 | inp2) # Unary builtins @@ -176,7 +176,7 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: cases.ConstInitializer(bool_field) )() out = cases.allocate(cartesian_case, tilde_fieldop, cases.RETURN)() - cases.verify(cartesian_case, tilde_fieldop, inp1, out=out, ref=~inp1.array()) + cases.verify(cartesian_case, tilde_fieldop, inp1, out=out, ref=~inp1) def test_unary_not(cartesian_case): @@ -190,7 +190,7 @@ def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: cases.ConstInitializer(bool_field) )() out = cases.allocate(cartesian_case, not_fieldop, cases.RETURN)() - cases.verify(cartesian_case, not_fieldop, inp1, out=out, ref=~inp1.array()) + cases.verify(cartesian_case, not_fieldop, inp1, out=out, ref=~inp1) # Trig builtins diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index f12e10fcab..d7c50e83f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -75,8 +75,8 @@ def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatFie shift_by_one_program, in_field, out_field, - inout=out_field.array()[:-1], - ref=in_field.array()[1:-1], + inout=out_field[:-1], + ref=in_field[1:-1], ) @@ -184,7 +184,7 @@ def prog( cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={}) - assert np.allclose((a.array()[1:], b.array()[1:]), (out_a.array()[1:], out_b.array()[1:])) + assert np.allclose((a[1:], b[1:]), (out_a[1:], out_b[1:])) assert out_a[0] == 0 and out_b[0] == 0 diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index f44b662f22..17a1ea11cb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -105,7 +105,7 @@ def simple_if( condition1, condition2, out=out, - ref=(a if condition1 else b).array() + (0 if condition2 else 1), + ref=(a if condition1 else b) + (0 if condition2 else 1), ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 673a989122..13fcf3b87f 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -171,6 +171,10 @@ def arithmetic_and_logical_test_data(): @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, expected, as_column): program_processor, validate = program_processor + if program_processor == run_dace_iterator: + pytest.xfail( + "Not supported in DaCe backend: argument types are not propagated for ITIR tests" + ) inps = asfield(*asarray(*inputs)) out = asfield((np.zeros_like(*asarray(expected))))[0] @@ -203,6 +207,10 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(program_processor, builtin_name, inputs, as_column): program_processor, validate = program_processor + if program_processor == run_dace_iterator: + pytest.xfail( + "Not supported in DaCe backend: argument types are not propagated for ITIR tests" + ) if builtin_name == "gamma": # numpy has no gamma function diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index f4aa217eaf..d20ec2ee3d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -53,5 +53,5 @@ def test_conditional_w_tuple(program_processor): offset_provider={}, ) if validate: - assert np.all(out[np.asarray(inp) == 0] == 3.0) - assert np.all(out[np.asarray(inp) == 1] == 7.0) + assert np.all(out.ndarray[np.asarray(inp) == 0] == 3.0) + assert np.all(out.ndarray[np.asarray(inp) == 1] == 7.0) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index f01f021434..a22fef0d49 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -74,9 +74,9 @@ def test_simple_indirection(program_processor): cond = gtx.np_as_located_field(IDim)(rng.normal(size=shape)) out = gtx.np_as_located_field(IDim)(np.zeros(shape, dtype=inp.dtype)) - ref = np.zeros(shape) + ref = np.zeros(shape, dtype=inp.dtype) for i in range(shape[0]): - ref[i] = inp[i + 1 - 1] if cond[i] < 0 else inp[i + 1 + 1] + ref[i] = inp.ndarray[i + 1 - 1] if cond[i] < 0.0 else inp.ndarray[i + 1 + 1] run_processor( conditional_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))], diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 2f7808b30e..2076cdd864 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -59,6 +59,10 @@ def test_single_argument(program_processor, dom): def test_2_arguments(program_processor, dom): program_processor, validate = program_processor + if program_processor == run_dace_iterator: + pytest.xfail( + "Not supported in DaCe backend: argument types are not propagated for ITIR tests" + ) @fundef def fun(inp0, inp1): @@ -71,7 +75,7 @@ def fun(inp0, inp1): run_processor(fun[dom], program_processor, inp0, inp1, out=out, offset_provider={}) if validate: - assert np.allclose(inp0.array() + inp1.array(), out) + assert np.allclose(inp0 + inp1, out) def test_lambda_domain(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index 9d901410aa..e0460b67b1 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -38,9 +38,9 @@ def test_scan_in_stencil(program_processor, lift_mode): out = gtx.np_as_located_field(IDim, KDim)(np.zeros((isize, ksize))) reference = np.zeros((isize, ksize - 1)) - reference[:, 0] = inp[:, 0] + inp[:, 1] + reference[:, 0] = inp.ndarray[:, 0] + inp.ndarray[:, 1] for k in range(1, ksize - 1): - reference[:, k] = reference[:, k - 1] + inp[:, k] + inp[:, k + 1] + reference[:, k] = reference[:, k - 1] + inp.ndarray[:, k] + inp.ndarray[:, k + 1] @fundef def sum(state, k, kp): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 3f7a98f7bc..5a6ffe2891 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -94,58 +94,6 @@ def tuple_of_tuple_output2(inp1, inp2, inp3, inp4): return make_tuple(deref(inp1), deref(inp2)), make_tuple(deref(inp3), deref(inp4)) -@pytest.mark.parametrize( - "stencil", - [tuple_of_tuple_output1, tuple_of_tuple_output2], -) -def test_tuple_of_field_of_tuple_output(program_processor_no_gtfn_exec, stencil): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - - shape = [5, 7, 9] - rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp3 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp4 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - - out_np1 = np.zeros(shape, dtype="f8, f8") - out1 = gtx.np_as_located_field(IDim, JDim, KDim)(out_np1) - out_np2 = np.zeros(shape, dtype="f8, f8") - out2 = gtx.np_as_located_field(IDim, JDim, KDim)(out_np2) - out = (out1, out2) - - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } - run_processor( - stencil[dom], - program_processor, - inp1, - inp2, - inp3, - inp4, - out=out, - offset_provider={}, - ) - if validate: - assert np.allclose(inp1, out_np1[:]["f0"]) - assert np.allclose(inp2, out_np1[:]["f1"]) - assert np.allclose(inp3, out_np2[:]["f0"]) - assert np.allclose(inp4, out_np2[:]["f1"]) - - def test_tuple_of_tuple_of_field_output(program_processor): program_processor, validate = program_processor if program_processor == run_dace_iterator: @@ -203,38 +151,6 @@ def stencil(inp1, inp2, inp3, inp4): assert np.allclose(inp4, out[1][1]) -@pytest.mark.parametrize( - "stencil", - [tuple_output1, tuple_output2], -) -def test_field_of_tuple_output(program_processor_no_gtfn_exec, stencil): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - - shape = [5, 7, 9] - rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])), - ) - - out_np = np.zeros(shape, dtype="f8, f8") - out = gtx.np_as_located_field(IDim, JDim, KDim)(out_np) - - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } - run_processor(stencil[dom], program_processor, inp1, inp2, out=out, offset_provider={}) - if validate: - assert np.allclose(inp1, out_np[:]["f0"]) - assert np.allclose(inp2, out_np[:]["f1"]) - - @pytest.mark.parametrize( "stencil", [tuple_output1, tuple_output2], @@ -344,6 +260,7 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): assert np.allclose(inp3, out3) +@pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") @pytest.mark.parametrize( "stencil", [tuple_output1, tuple_output2], @@ -408,35 +325,7 @@ def test_tuple_field_input(program_processor): assert np.allclose(np.asarray(inp1) + np.asarray(inp2), out) -def test_field_of_tuple_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - - shape = [5, 7, 9] - rng = np.random.default_rng() - - inp1 = rng.normal(size=(shape[0], shape[1], shape[2])) - inp2 = rng.normal(size=(shape[0], shape[1], shape[2])) - inp = np.zeros(shape, dtype="f8, f8") - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - inp[i, j, k] = (inp1[i, j, k], inp2[i, j, k]) - - inp = gtx.np_as_located_field(IDim, JDim, KDim)(inp) - out = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) - - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } - run_processor(tuple_input[dom], program_processor, inp, out=out, offset_provider={}) - if validate: - assert np.allclose(np.asarray(inp1) + np.asarray(inp2), out) - - +@pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") def test_field_of_extra_dim_input(program_processor_no_gtfn_exec): program_processor, validate = program_processor_no_gtfn_exec if program_processor == run_dace_iterator: @@ -473,41 +362,6 @@ def tuple_tuple_input(inp): ) -def test_tuple_of_field_of_tuple_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - - shape = [5, 7, 9] - rng = np.random.default_rng() - - inp1 = rng.normal(size=(shape[0], shape[1], shape[2])) - inp2 = rng.normal(size=(shape[0], shape[1], shape[2])) - inp = np.zeros(shape, dtype="f8, f8") - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - inp[i, j, k] = (inp1[i, j, k], inp2[i, j, k]) - - inp = gtx.np_as_located_field(IDim, JDim, KDim)(inp) - out = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) - - dom = { - IDim: range(0, shape[0]), - JDim: range(0, shape[1]), - KDim: range(0, shape[2]), - } - run_processor( - tuple_tuple_input[dom], - program_processor, - (inp, inp), - out=out, - offset_provider={}, - ) - if validate: - assert np.allclose(2.0 * (np.asarray(inp1) + np.asarray(inp2)), out) - - def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor if program_processor == run_dace_iterator: @@ -549,6 +403,7 @@ def test_tuple_of_tuple_of_field_input(program_processor): ) +@pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") def test_field_of_2_extra_dim_input(program_processor_no_gtfn_exec): program_processor, validate = program_processor_no_gtfn_exec if program_processor == run_dace_iterator: diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index e3cecfa88f..3eaefa76de 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -41,30 +41,30 @@ def mixed_args( def test_allocate_default_unique(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, mixed_args, "a")() - assert np.min(a.array()) == 0 - assert np.max(a.array()) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 + assert np.min(a) == 0 + assert np.max(a) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 b = cases.allocate(cartesian_case, mixed_args, "b")() - assert b == np.max(a.array()) + 1 + assert b == np.max(a) + 1 c = cases.allocate(cartesian_case, mixed_args, "c")() - assert np.min(c.array()) == b + 1 - assert np.max(c.array()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + assert np.min(c) == b + 1 + assert np.max(c) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 def test_allocate_return_default_zeros(cartesian_case): # noqa: F811 # fixtures a, (b, c) = cases.allocate(cartesian_case, mixed_args, cases.RETURN)() - assert np.all(a.array() == 0) - assert np.all(a.array() == b.array()) - assert np.all(b.array() == c.array()) + assert np.all(np.asarray(a) == 0) + assert np.all(np.asarray(a) == b) + assert np.all(np.asarray(b) == c) def test_allocate_const(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, mixed_args, "a").strategy(cases.ConstInitializer(42))() - assert np.all(a.array() == 42) + assert np.all(np.asarray(a) == 42) b = cases.allocate(cartesian_case, mixed_args, "b").strategy(cases.ConstInitializer(42))() assert b == 42.0 @@ -88,7 +88,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures out = cases.allocate(cartesian_case, addition, cases.RETURN)() with pytest.raises(errors.DSLError): - cases.verify(cartesian_case, addition, a, b, out=out, ref=a.array() + b.array()) + cases.verify(cartesian_case, addition, a, b, out=out, ref=a + b) @pytest.mark.parametrize("fieldview_backend", [roundtrip.executor]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 931cb813c5..5f0f273d0f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -204,7 +204,7 @@ class setup: w = gtx.np_as_located_field(Cell, KDim)( np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q_ref, w_ref = reference(z_alpha, z_beta, z_q, w) + z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) dummy = gtx.np_as_located_field(Cell, KDim)(np.zeros((cell_size, k_size), dtype=bool)) z_q_out = gtx.np_as_located_field(Cell, KDim)(np.zeros((cell_size, k_size))) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index b7af70fb71..d275a977dd 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -69,8 +69,8 @@ def test_ffront_lap(cartesian_case): lap_program, in_field, out_field, - inout=out_field.array()[1:-1, 1:-1], - ref=lap_ref(np.asarray(in_field)), + inout=out_field[1:-1, 1:-1], + ref=lap_ref(in_field.ndarray), ) in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() @@ -81,6 +81,6 @@ def test_ffront_lap(cartesian_case): laplap_program, in_field, out_field, - inout=out_field.array()[2:-2, 2:-2], - ref=lap_ref(lap_ref(np.asarray(in_field))), + inout=out_field[2:-2, 2:-2], + ref=lap_ref(lap_ref(np.asarray(in_field.ndarray))), ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 5970b9a2a9..5211f2184d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -116,6 +116,79 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): assert np.allclose(ref, out) +@fundef +def k_level_condition_lower(k_idx, k_level): + return if_(deref(k_idx) > deref(k_level), deref(shift(K, -1)(k_idx)), 0) + + +@fundef +def k_level_condition_upper(k_idx, k_level): + return if_(deref(k_idx) < deref(k_level), deref(shift(K, +1)(k_idx)), 0) + + +@fundef +def k_level_condition_upper_tuple(k_idx, k_level): + shifted_val = deref(shift(K, +1)(k_idx)) + return if_( + tuple_get(0, deref(k_idx)) < deref(k_level), + tuple_get(0, shifted_val) + tuple_get(1, shifted_val), + 0, + ) + + +@pytest.mark.parametrize( + "fun, k_level, inp_function, ref_function", + [ + ( + k_level_condition_lower, + lambda inp: 0, + lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + lambda inp: np.concatenate([[0], inp[:-1]]), + ), + ( + k_level_condition_upper, + lambda inp: inp.shape[0] - 1, + lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + lambda inp: np.concatenate([inp[1:], [0]]), + ), + ( + k_level_condition_upper_tuple, + lambda inp: inp[0].shape[0] - 1, + lambda k_size: ( + gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + ), + lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), + ), + ], +) +def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): + program_processor, validate = program_processor + + if program_processor == run_dace_iterator: + pytest.xfail("Not supported in DaCe backend: tuple arguments") + + k_size = 5 + inp = inp_function(k_size) + ref = ref_function(inp) + + out = gtx.np_as_located_field(KDim)(np.zeros((5,), dtype=np.int32)) + + run_processor( + fun[{KDim: range(0, k_size)}], + program_processor, + inp, + k_level(inp), + out=out, + offset_provider={"K": KDim}, + column_axis=KDim, + lift_mode=lift_mode, + ) + + if validate: + np.allclose(ref, out) + + @fundef def sum_scanpass(state, inp): return state + deref(inp) @@ -228,16 +301,16 @@ def kdoublesum_fencil(i_size, k_start, k_end, inp0, inp1, out): [ ( 0, - np.asarray( - [[(0, 0), (1, 1), (3, 3), (6, 6), (10, 10), (15, 15), (21, 21)]], - dtype=np.dtype([("foo", np.float64), ("bar", np.int32)]), + ( + np.asarray([0, 1, 3, 6, 10, 15, 21], dtype=np.float64), + np.asarray([0, 1, 3, 6, 10, 15, 21], dtype=np.int32), ), ), ( 2, - np.asarray( - [[(0, 0), (0, 0), (2, 2), (5, 5), (9, 9), (14, 14), (20, 20)]], - dtype=np.dtype([("foo", np.float64), ("bar", np.int32)]), + ( + np.asarray([0, 0, 2, 5, 9, 14, 20], dtype=np.float64), + np.asarray([0, 0, 2, 5, 9, 14, 20], dtype=np.int32), ), ), ], @@ -255,7 +328,10 @@ def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): shape = [1, 7] inp0 = gtx.np_as_located_field(IDim, KDim)(np.asarray([list(range(7))], dtype=np.float64)) inp1 = gtx.np_as_located_field(IDim, KDim)(np.asarray([list(range(7))], dtype=np.int32)) - out = gtx.np_as_located_field(IDim, KDim)(np.zeros(shape, dtype=reference.dtype)) + out = ( + gtx.np_as_located_field(IDim, KDim)(np.zeros(shape, dtype=np.float64)), + gtx.np_as_located_field(IDim, KDim)(np.zeros(shape, dtype=np.float32)), + ) run_processor( kdoublesum_fencil, @@ -271,8 +347,8 @@ def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): ) if validate: - for n in reference.dtype.names: - assert np.allclose(reference[n], np.asarray(out)[n]) + for ref, o in zip(reference, out): + assert np.allclose(ref, o) @fundef @@ -292,12 +368,16 @@ def sum_shifted_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes(program_processor): program_processor, validate = program_processor + if program_processor == run_dace_iterator: + pytest.xfail( + "Not supported in DaCe backend: argument types are not propagated for ITIR tests" + ) k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) inp1 = gtx.np_as_located_field(KDim)(np.arange(0, k_size + 1)) out = gtx.np_as_located_field(KDim)(np.zeros(k_size, dtype=inp0.dtype)) - ref = inp0 + inp1[1:] + ref = inp0.ndarray + inp1.ndarray[1:] run_processor( sum_shifted_fencil, @@ -337,7 +417,7 @@ def test_different_vertical_sizes_with_origin(program_processor): inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) inp1 = gtx.np_as_located_field(KDim, origin={KDim: 1})(np.arange(0, k_size + 1)) out = gtx.np_as_located_field(KDim)(np.zeros(k_size, dtype=np.int64)) - ref = inp0 + np.asarray(inp1)[:-1] + ref = np.asarray(inp0) + np.asarray(inp1)[:-1] run_processor( sum_fencil, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 1ae9f01b1c..ab22e2b360 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -163,8 +163,8 @@ def test_compute_zavgS(program_processor, lift_mode): ) if validate: - assert_close(-199755464.25741270, min(zavgS)) - assert_close(388241977.58389181, max(zavgS)) + assert_close(-199755464.25741270, np.min(zavgS)) + assert_close(388241977.58389181, np.max(zavgS)) run_processor( compute_zavgS_fencil, @@ -177,8 +177,8 @@ def test_compute_zavgS(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert_close(-1000788897.3202186, min(zavgS)) - assert_close(1000788897.3202186, max(zavgS)) + assert_close(-1000788897.3202186, np.min(zavgS)) + assert_close(1000788897.3202186, np.max(zavgS)) @fendef @@ -204,9 +204,7 @@ def test_compute_zavgS2(program_processor, lift_mode): pp = gtx.np_as_located_field(Vertex)(setup.input_field) - S = gtx.np_as_located_field(Edge)( - np.array([(a, b) for a, b in zip(*(setup.S_fields[0], setup.S_fields[1]))], dtype="d,d") - ) + S = tuple(gtx.np_as_located_field(Edge)(s) for s in setup.S_fields) zavgS = ( gtx.np_as_located_field(Edge)(np.zeros((setup.edges_size))), @@ -229,11 +227,11 @@ def test_compute_zavgS2(program_processor, lift_mode): ) if validate: - assert_close(-199755464.25741270, min(zavgS[0])) - assert_close(388241977.58389181, max(zavgS[0])) + assert_close(-199755464.25741270, np.min(zavgS[0])) + assert_close(388241977.58389181, np.max(zavgS[0])) - assert_close(-1000788897.3202186, min(zavgS[1])) - assert_close(1000788897.3202186, max(zavgS[1])) + assert_close(-1000788897.3202186, np.min(zavgS[1])) + assert_close(1000788897.3202186, np.max(zavgS[1])) def test_nabla(program_processor, lift_mode): @@ -274,10 +272,10 @@ def test_nabla(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) @fendef @@ -305,9 +303,7 @@ def test_nabla2(program_processor, lift_mode): sign = gtx.np_as_located_field(Vertex, V2EDim)(setup.sign_field) pp = gtx.np_as_located_field(Vertex)(setup.input_field) - S_M = gtx.np_as_located_field(Edge)( - np.array([(a, b) for a, b in zip(*(setup.S_fields[0], setup.S_fields[1]))], dtype="d,d") - ) + S_M = tuple(gtx.np_as_located_field(Edge)(s) for s in setup.S_fields) vol = gtx.np_as_located_field(Vertex)(setup.vol_field) pnabla_MXX = gtx.np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) @@ -333,10 +329,10 @@ def test_nabla2(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) @fundef @@ -420,7 +416,7 @@ def test_nabla_sign(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 6735166ed8..04edf68919 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -157,7 +157,7 @@ def run(fencil): yield run if validate: - assert np.allclose(x, np.asarray(x_s)) + assert np.allclose(x, x_s) def test_tridiag(tridiag_test): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field.py deleted file mode 100644 index 81ee640b80..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field.py +++ /dev/null @@ -1,79 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import numpy as np - -import gt4py.next as gtx -from gt4py.eve.datamodels import field -from gt4py.next.iterator import embedded - - -I = gtx.Dimension("I") -J = gtx.Dimension("J") - - -def make_located_field(dtype=np.float64): - return gtx.np_as_located_field(I, J)(np.zeros((1, 1), dtype=dtype)) - - -def test_located_field_1d(): - foo = gtx.np_as_located_field(I)(np.zeros((1,))) - - foo[0] = 42 - - assert foo.__gt_dims__[0] == I - assert foo[0] == 42 - - -def test_located_field_2d(): - foo = gtx.np_as_located_field(I, J)(np.zeros((1, 1), dtype=np.float64)) - - foo[0, 0] = 42 - - assert foo.__gt_dims__[0] == I - assert foo[0, 0] == 42 - assert foo.dtype == np.float64 - - -def test_tuple_field_concept(): - tuple_of_fields = (make_located_field(), make_located_field()) - assert embedded.can_be_tuple_field(tuple_of_fields) - - field_of_tuples = make_located_field(dtype="f8,f8") - assert embedded.can_be_tuple_field(field_of_tuples) - - -def test_field_of_tuple(): - field_of_tuples = make_located_field(dtype="f8,f8") - assert isinstance(field_of_tuples, embedded.TupleField) - - -def test_tuple_of_field(): - tuple_of_fields = embedded.TupleOfFields((make_located_field(), make_located_field())) - assert isinstance(tuple_of_fields, embedded.TupleField) - - tuple_of_fields.field_setitem({I.value: 0, J.value: 0}, (42, 43)) - assert tuple_of_fields.field_getitem({I.value: 0, J.value: 0}) == (42, 43) - - -def test_tuple_of_tuple_of_field(): - tup = ( - (make_located_field(), make_located_field()), - (make_located_field(), make_located_field()), - ) - testee = embedded.TupleOfFields(tup) - assert isinstance(testee, embedded.TupleField) - - testee.field_setitem({I.value: 0, J.value: 0}, ((42, 43), (44, 45))) - assert testee.field_getitem({I.value: 0, J.value: 0}) == ((42, 43), (44, 45)) From 54732f6cd1277f5eeba3ad41ead862db2ec95799 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Sep 2023 12:31:38 +0200 Subject: [PATCH 13/67] feature[next]: Temporaries (#1271) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds support for automatic creation of temporaries. Right now unstructured domain sizes are fixed at code generation time, but this could easily be lifted. Cartesian domains are symbolic. Todos (will be addressed in future PRs): - Scan are not supported properly. - Many tests on Iterator level are parameterized not only on the backend, but also `LiftMode`. This PR however makes the `LiftMode` a property of the backend (as it should be), which makes the test structure awkward. Due to time constraints and the unclear interface for configuration options of the backend (still in the working in the cleanup project) this will be addressed at a later time. - Temporaries for expressions of type tuple are at least sup-optimal at this time (some extractions are just unnecessary). - A proper heuristics on whether it makes sense performance-wise to extract a temporary. Example transformation: ``` __field_operator_composed_shift_unstructured_flat(__sym_1, out, ____sym_1_size_0, __out_size_0) { out ← ( λ(inp) → ·(↑(λ(it) → ·⟪C2Eₒ, 0ₒ⟫(it)))((↑(λ(it) → ·⟪E2Vₒ, 0ₒ⟫(it)))(inp)) )(__sym_1) @ u⟨ Cell: [0, __out_size_0) ⟩; } ``` is transformed into: ``` __field_operator_composed_shift_unstructured_flat(__sym_1, out, ____sym_1_size_0, __out_size_0) { _tmp_1 = temporary(domain=u⟨ Edge: [0, 18) ⟩, dtype=int32); __field_operator_composed_shift_unstructured_flat( __sym_1, out, ____sym_1_size_0, __out_size_0, _tmp_1 ) { _tmp_1 ← (λ(it) → ·⟪E2Vₒ, 0ₒ⟫(it))(__sym_1) @ u⟨ Edge: [0, 18) ⟩; out ← (λ(_tmp_1) → ·⟪C2Eₒ, 0ₒ⟫(_tmp_1))(_tmp_1) @ u⟨ Cell: [0, __out_size_0) ⟩; }; __field_operator_composed_shift_unstructured_flat( __sym_1, out, ____sym_1_size_0, __out_size_0, _tmp_1 ); } ``` --- pyproject.toml | 4 - src/gt4py/next/iterator/pretty_printer.py | 1 + .../iterator/transforms/collect_shifts.py | 104 ---- .../transforms/common_pattern_matcher.py | 26 + .../next/iterator/transforms/global_tmps.py | 579 ++++++++++-------- .../iterator/transforms/inline_into_scan.py | 2 +- .../iterator/transforms/inline_lambdas.py | 56 +- .../next/iterator/transforms/inline_lifts.py | 75 ++- .../next/iterator/transforms/pass_manager.py | 23 +- .../next/iterator/transforms/popup_tmps.py | 184 ------ .../next/iterator/transforms/remap_symbols.py | 4 + .../next/iterator/transforms/trace_shifts.py | 5 +- .../next/iterator/transforms/unroll_reduce.py | 13 +- src/gt4py/next/iterator/type_inference.py | 3 + .../codegens/gtfn/codegen.py | 18 +- .../codegens/gtfn/gtfn_backend.py | 9 +- .../codegens/gtfn/gtfn_module.py | 18 +- .../program_processors/runners/gtfn_cpu.py | 11 + .../ffront_tests/ffront_test_utils.py | 1 + .../ffront_tests/test_arg_call_interface.py | 3 +- .../ffront_tests/test_execution.py | 29 +- .../ffront_tests/test_gt4py_builtins.py | 6 +- .../ffront_tests/test_math_unary_builtins.py | 2 + .../ffront_tests/test_scalar_if.py | 13 +- .../test_horizontal_indirection.py | 10 +- .../test_strided_offset_provider.py | 9 +- .../ffront_tests/test_icon_like_scan.py | 15 +- .../iterator_tests/test_anton_toy.py | 8 +- .../iterator_tests/test_column_stencil.py | 9 +- .../iterator_tests/test_fvm_nabla.py | 37 +- .../iterator_tests/test_hdiff.py | 11 +- .../iterator_tests/test_vertical_advection.py | 67 +- .../test_with_toy_connectivity.py | 6 +- tests/next_tests/unit_tests/conftest.py | 2 + .../iterator_tests/test_type_inference.py | 27 + .../transforms_tests/test_collect_shifts.py | 57 -- .../transforms_tests/test_global_tmps.py | 419 ++++++------- .../transforms_tests/test_popup_tmps.py | 318 ---------- .../transforms_tests/test_trace_shifts.py | 4 +- 39 files changed, 905 insertions(+), 1283 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/collect_shifts.py create mode 100644 src/gt4py/next/iterator/transforms/common_pattern_matcher.py delete mode 100644 src/gt4py/next/iterator/transforms/popup_tmps.py delete mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collect_shifts.py delete mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_popup_tmps.py diff --git a/pyproject.toml b/pyproject.toml index 9422dd4448..e915622857 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -321,10 +321,6 @@ module = 'gt4py.next.type_system.type_translation' ignore_errors = true module = 'gt4py.next.iterator.runtime' -[[tool.mypy.overrides]] -ignore_errors = true -module = 'gt4py.next.iterator.transforms.global_tmps' - # -- pytest -- [tool.pytest] diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 2d6c0b5cae..786b91bcc5 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -16,6 +16,7 @@ Inspired by P. Yelland, “A New Approach to Optimal Code Formatting”, 2015 """ +# TODO(tehrengruber): add support for printing the types of itir.Sym, itir.Literal nodes from __future__ import annotations from collections.abc import Iterable, Sequence diff --git a/src/gt4py/next/iterator/transforms/collect_shifts.py b/src/gt4py/next/iterator/transforms/collect_shifts.py deleted file mode 100644 index f36b6024da..0000000000 --- a/src/gt4py/next/iterator/transforms/collect_shifts.py +++ /dev/null @@ -1,104 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import boltons.typeutils - -from gt4py.eve import NodeVisitor -from gt4py.next.iterator import ir - - -ALL_NEIGHBORS = boltons.typeutils.make_sentinel("ALL_NEIGHBORS") - - -class CollectShifts(NodeVisitor): - """Collects shifts applied to symbol references. - - Fills the provided `shifts` keyword argument (of type `dict[str, list[tuple]]`) - with a list of offset tuples. E.g., if there is just `deref(x)` and a - `deref(shift(a, b)(x))` in the node tree, the result will be - `{"x": [(), (a, b)]}`. - - For reductions, the special value `ALL_NEIGHBORS` is used. E.g, - `reduce(f, 0.0)(shift(V2E)(x))` will return `{"x": [(V2E, ALL_NEIGHBORS)]}`. - - Limitations: - - Nested shift calls like `deref(shift(c, d)(shift(a, b)(x)))` are not supported. - That is, all shifts must be normalized (that is, `deref(shift(a, b, c, d)(x))` - works in the given example). - - Calls to lift and scan are not supported. - """ - - @staticmethod - def _as_deref(node: ir.FunCall): - if node.fun == ir.SymRef(id="deref"): - (arg,) = node.args - if isinstance(arg, ir.SymRef): - return arg.id - - @staticmethod - def _as_shift(node: ir.Expr): - if isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="shift"): - return tuple(node.args) - - @classmethod - def _as_shift_call(cls, node: ir.Expr): - if ( - isinstance(node, ir.FunCall) - and (offsets := cls._as_shift(node.fun)) - and isinstance(sym := node.args[0], ir.SymRef) - ): - return sym.id, offsets - - @classmethod - def _as_deref_shift(cls, node: ir.FunCall): - if node.fun == ir.SymRef(id="deref"): - (arg,) = node.args - if sym_and_offsets := cls._as_shift_call(arg): - return sym_and_offsets - - @staticmethod - def _as_reduce(node: ir.FunCall): - if isinstance(node.fun, ir.FunCall) and node.fun.fun == ir.SymRef(id="reduce"): - assert len(node.fun.args) == 2 - return node.args - - def visit_FunCall(self, node: ir.FunCall, *, shifts: dict[str, list[tuple]]): - if sym_id := self._as_deref(node): - # direct deref of a symbol: deref(sym) - shifts.setdefault(sym_id, []).append(()) - return - if sym_and_offsets := self._as_deref_shift(node): - # deref of a shifted symbol: deref(shift(...)(sym)) - sym, offsets = sym_and_offsets - shifts.setdefault(sym, []).append(offsets) - return - if sym_and_offsets := self._as_shift_call(node): - # just shifting: shift(...)(sym) - # required to catch ‘underefed’ shifts in reduction calls - sym, offsets = sym_and_offsets - shifts.setdefault(sym, []).append(offsets) - return - if reduction_args := self._as_reduce(node): - # reduce(..., ...)(args...) - nested_shifts = dict[str, list[tuple]]() - self.visit(reduction_args, shifts=nested_shifts) - for sym, offset_list in nested_shifts.items(): - for offsets in offset_list: - shifts.setdefault(sym, []).append(offsets + (ALL_NEIGHBORS,)) - return - - if not isinstance(node.fun, ir.SymRef) or node.fun.id in ("lift", "scan"): - raise ValueError(f"Unsupported node: {node}") - - self.generic_visit(node, shifts=shifts) diff --git a/src/gt4py/next/iterator/transforms/common_pattern_matcher.py b/src/gt4py/next/iterator/transforms/common_pattern_matcher.py new file mode 100644 index 0000000000..8df4723502 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/common_pattern_matcher.py @@ -0,0 +1,26 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +from typing import TypeGuard + +from gt4py.next.iterator import ir as itir + + +def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `lift(λ(...) → ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "lift" + ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index b7ab6c2c8b..e1b697e0bc 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -11,19 +11,26 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import copy +import dataclasses +import functools +from collections.abc import Mapping +from typing import Any, Final, Iterable, Literal, Optional, Sequence -from collections.abc import Mapping, Sequence -from typing import Any, Optional - +import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait -from gt4py.next.iterator import ir, type_inference +from gt4py.eve.utils import UIDGenerator +from gt4py.next.iterator import ir, ir_makers as im, type_inference from gt4py.next.iterator.pretty_printer import PrettyPrinter +from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.transforms.cse import extract_subexpression from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.popup_tmps import PopupTmps +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs -from gt4py.next.iterator.transforms.trace_shifts import TraceShifts +from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs """Iterator IR extension for global temporaries. @@ -39,7 +46,7 @@ """ -AUTO_DOMAIN = ir.FunCall(fun=ir.SymRef(id="_gtmp_auto_domain"), args=[]) +AUTO_DOMAIN: Final = ir.FunCall(fun=ir.SymRef(id="_gtmp_auto_domain"), args=[]) # Iterator IR extension nodes @@ -48,7 +55,7 @@ class Temporary(ir.Node): """Iterator IR extension: declaration of a temporary buffer.""" - id: Coerced[ir.SymbolName] # noqa: A003 + id: Coerced[eve.SymbolName] # noqa: A003 domain: Optional[ir.Expr] = None dtype: Optional[Any] = None @@ -113,9 +120,78 @@ def pformat_FencilWithTemporaries( # Main implementation +def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir.FunCall: + """ + Canonicalize applied lift expressions. + + Transform lift such that the arguments to the applied lift are only symbols. + + >>> expr = im.lift(im.lambda_("a")(im.deref("a")))(im.lift("deref")("inp")) + >>> print(expr) + (↑(λ(a) → ·a))((↑deref)(inp)) + >>> print(canonicalize_applied_lift(["inp"], expr)) + (↑(λ(inp) → (λ(a) → ·a)((↑deref)(inp))))(inp) + """ + assert ( + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.FunCall) + and node.fun.fun == ir.SymRef(id="lift") + ) + stencil = node.fun.args[0] + it_args = node.args + if any(not isinstance(it_arg, ir.SymRef) for it_arg in it_args): + used_closure_params = collect_symbol_refs(node) + assert not (set(used_closure_params) - set(closure_params)) + return im.lift(im.lambda_(*used_closure_params)(im.call(stencil)(*it_args)))( + *used_closure_params + ) + return node + + +def temporary_extraction_predicate(expr: ir.Node, num_occurences: int) -> bool: + """Determine if `expr` is an applied lift that should be extracted as a temporary.""" + if not is_applied_lift(expr): + return False + # do not extract when the result is a list as we can not create temporaries for + # these stencils + if isinstance(expr.annex.type.dtype, type_inference.List): + return False + stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` + used_symbols = collect_symbol_refs(stencil) + # do not extract when the stencil is capturing + if used_symbols: + return False + return True + + +def _closure_parameter_argument_mapping(closure: ir.StencilClosure): + """ + Create a mapping from the closures parameters to the closure arguments. + + E.g. for the closure `out ← (λ(param) → ...)(arg) @ u⟨ ... ⟩;` we get a mapping from `param` + to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first + arg is ignored) is returned. + """ + is_scan = isinstance(closure.stencil, ir.FunCall) and closure.stencil.fun == im.ref("scan") + + if is_scan: + stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan + return { + param.id: arg for param, arg in zip(stencil.params[1:], closure.inputs, strict=True) + } + else: + assert isinstance(closure.stencil, ir.Lambda) + return { + param.id: arg for param, arg in zip(closure.stencil.params, closure.inputs, strict=True) + } + +def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> None: + used_symbol_refs = collect_symbol_refs(expr) + assert not (set(used_symbol_refs) - {param.id for param in whitelist}) -def split_closures(node: ir.FencilDefinition) -> FencilWithTemporaries: + +def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries: """Split closures on lifted function calls and introduce new temporary buffers for return values. Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the @@ -127,48 +203,99 @@ def split_closures(node: ir.FencilDefinition) -> FencilWithTemporaries: 3. Extract lifted function class as new closures with the previously created temporary as output. The closures are processed in reverse order to properly respect the dependencies. """ - tmps: list[ir.SymRef] = [] - - def handle_arg(arg): - """Handle arguments of closure calls: extract lifted function calls. - - Lifted function calls, do: - 1. Replace the call by a new symbol ref, put this into `tmps`. - 2. Put the ‘unlifted’ function call to the stack of stencil calls that still have to be - processed. - """ - if isinstance(arg, ir.SymRef): - return arg - if ( - isinstance(arg, ir.FunCall) - and isinstance(arg.fun, ir.FunCall) - and arg.fun.fun == ir.SymRef(id="lift") - ): - assert len(arg.fun.args) == 1 - ref = ir.SymRef(id=f"_gtmp_{len(tmps)}") - tmps.append(ir.Sym(id=ref.id)) - unlifted = ir.FunCall(fun=arg.fun.args[0], args=arg.args) - stencil_stack.append((ref, unlifted)) - return ref - raise AssertionError() + uid_gen_tmps = UIDGenerator(prefix="_tmp") + + type_inference.infer_all(node, offset_provider=offset_provider, save_to_annex=True) - closures = [] + tmps: list[ir.Sym] = [] + + closures: list[ir.StencilClosure] = [] for closure in reversed(node.closures): - wrapped_stencil = ir.FunCall(fun=closure.stencil, args=closure.inputs) - popped_stencil = PopupTmps().visit(wrapped_stencil) + closure_stack: list[ir.StencilClosure] = [closure] + while closure_stack: + current_closure: ir.StencilClosure = closure_stack.pop() + + if current_closure.stencil == im.ref("deref"): + closures.append(current_closure) + continue + + is_scan: bool = isinstance( + current_closure.stencil, ir.FunCall + ) and current_closure.stencil.fun == im.ref("scan") + current_closure_stencil = ( + current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan + ) - stencil_stack = [(closure.output, popped_stencil)] - domain = closure.domain - while stencil_stack: - output, call = stencil_stack.pop() - closure = ir.StencilClosure( - domain=domain, - stencil=call.fun, - output=output, - inputs=[handle_arg(arg) for arg in call.args], + stencil_body, extracted_lifts, _ = extract_subexpression( + current_closure_stencil.expr, + temporary_extraction_predicate, + uid_gen_tmps, + once_only=True, + deepest_expr_first=True, ) - closures.append(closure) - domain = AUTO_DOMAIN + + if extracted_lifts: + for tmp_sym, lift_expr in extracted_lifts.items(): + # make sure the applied lift is not capturing anything except of closure params + _ensure_expr_does_not_capture(lift_expr, current_closure_stencil.params) + + assert isinstance(lift_expr, ir.FunCall) and isinstance( + lift_expr.fun, ir.FunCall + ) + + # make sure the arguments to the applied lift are only symbols + # (otherwise we would need to canonicalize using `canonicalize_applied_lift` + # this doesn't seem to be necessary right now as we extract the lifts + # in post-order of the tree) + assert all(isinstance(arg, ir.SymRef) for arg in lift_expr.args) + + # create a mapping from the closures parameters to the closure arguments + closure_param_arg_mapping = _closure_parameter_argument_mapping(current_closure) + + stencil: ir.Node = lift_expr.fun.args[0] # usually an ir.Lambda or scan + + # allocate a new temporary + tmps.append(tmp_sym) + + # create a new closure that executes the stencil of the applied lift and + # writes the result to the newly created temporary + closure_stack.append( + ir.StencilClosure( + domain=AUTO_DOMAIN, + stencil=stencil, + output=im.ref(tmp_sym.id), + inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + ) + ) + + new_stencil: ir.Lambda | ir.FunCall + # create a new stencil where all applied lifts that have been extracted are + # replaced by references to the respective temporary + new_stencil = ir.Lambda( + params=current_closure_stencil.params + list(extracted_lifts.keys()), + expr=stencil_body, + ) + # if we are extracting from an applied scan we have to wrap the scan pass again, + # i.e. transform `λ(state, ...) → ...` into `scan(λ(state, ...) → ..., ...)` + if is_scan: + new_stencil = im.call("scan")(new_stencil, current_closure.stencil.args[1:]) # type: ignore[attr-defined] # ensure by is_scan + # inline such that let statements which are just rebinding temporaries disappear + new_stencil = InlineLambdas.apply( + new_stencil, opcount_preserving=True, force_inline_lift_args=False + ) + # we're done with the current closure, add it back to the stack for further + # extraction. + closure_stack.append( + ir.StencilClosure( + domain=current_closure.domain, + stencil=new_stencil, + output=current_closure.output, + inputs=current_closure.inputs + + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], + ) + ) + else: + closures.append(current_closure) return FencilWithTemporaries( fencil=ir.FencilDefinition( @@ -176,7 +303,7 @@ def handle_arg(arg): function_definitions=node.function_definitions, params=node.params + [ir.Sym(id=tmp.id) for tmp in tmps] - + [ir.Sym(id=AUTO_DOMAIN.fun.id)], + + [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant closures=list(reversed(closures)), ), params=node.params, @@ -210,147 +337,6 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari ) -def _offset_limits( - offsets: Sequence[tuple[ir.OffsetLiteral, ...]], offset_provider: Mapping[str, gtx.Dimension] -): - offset_limits = {k: (0, 0) for k in offset_provider.keys()} - for o in offsets: - offset_sum = {k: 0 for k in offset_provider.keys()} - for k, v in zip(o[0::2], o[1::2]): - assert isinstance(v, ir.OffsetLiteral) and isinstance(v.value, int) - offset_sum[k.value] += v.value - for k, v in offset_sum.items(): - old_min, old_max = offset_limits[k] - offset_limits[k] = (min(old_min, v), max(old_max, v)) - - return {v.value: offset_limits[k] for k, v in offset_provider.items()} - - -def _named_range_with_offsets( - axis_literal: ir.AxisLiteral, - lower_bound: ir.Expr, - upper_bound: ir.Expr, - lower_offset: int, - upper_offset: int, -) -> ir.FunCall: - if lower_offset: - lower_bound = ir.FunCall( - fun=ir.SymRef(id="plus"), - args=[lower_bound, ir.Literal(value=str(lower_offset), type=ir.INTEGER_INDEX_BUILTIN)], - ) - if upper_offset: - upper_bound = ir.FunCall( - fun=ir.SymRef(id="plus"), - args=[upper_bound, ir.Literal(value=str(upper_offset), type=ir.INTEGER_INDEX_BUILTIN)], - ) - return ir.FunCall( - fun=ir.SymRef(id="named_range"), args=[axis_literal, lower_bound, upper_bound] - ) - - -def _extend_cartesian_domain( - domain: ir.FunCall, offsets: Sequence[tuple], offset_provider: Mapping[str, gtx.Dimension] -): - if not any(offsets): - return domain - assert isinstance(domain, ir.FunCall) and domain.fun == ir.SymRef(id="cartesian_domain") - assert all(isinstance(axis, gtx.Dimension) for axis in offset_provider.values()) - - offset_limits = _offset_limits(offsets, offset_provider) - - named_ranges = [] - for named_range in domain.args: - assert ( - isinstance(named_range, ir.FunCall) - and isinstance(named_range.fun, ir.SymRef) - and named_range.fun.id == "named_range" - ) - axis_literal, lower_bound, upper_bound = named_range.args - assert isinstance(axis_literal, ir.AxisLiteral) - - lower_offset, upper_offset = offset_limits.get(axis_literal.value, (0, 0)) - named_ranges.append( - _named_range_with_offsets( - axis_literal, lower_bound, upper_bound, lower_offset, upper_offset - ) - ) - - return ir.FunCall(fun=domain.fun, args=named_ranges) - - -def update_cartesian_domains( - node: FencilWithTemporaries, offset_provider: Mapping[str, Any] -) -> FencilWithTemporaries: - """Replace appearances of `AUTO_DOMAIN` by concrete domain sizes. - - Naive extent analysis, does not handle boundary conditions etc. in a smart way. - """ - closures = [] - domains = dict[str, ir.Expr]() - for closure in reversed(node.fencil.closures): - if closure.domain == AUTO_DOMAIN: - domain = domains[closure.output.id] - closure = ir.StencilClosure( - domain=domain, stencil=closure.stencil, output=closure.output, inputs=closure.inputs - ) - else: - domain = closure.domain - - closures.append(closure) - - if closure.stencil == ir.SymRef(id="deref"): - domains[closure.inputs[0].id] = domain - continue - - local_shifts = TraceShifts.apply(closure) - for param, shifts in local_shifts.items(): - domains[param] = _extend_cartesian_domain(domain, shifts, offset_provider) - - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=node.fencil.params[:-1], - closures=list(reversed(closures)), - ), - params=node.params, - tmps=node.tmps, - ) - - -def _location_type_from_offsets( - domain: ir.FunCall, offsets: Sequence, offset_provider: Mapping[str, Any] -): - """Derive the location type of an iterator from given offsets relative to an initial domain.""" - location = domain.args[0].args[0].value - for o in offsets: - if isinstance(o, ir.OffsetLiteral) and isinstance(o.value, str): - provider = offset_provider[o.value] - if isinstance(provider, gtx.NeighborTableOffsetProvider): - location = provider.neighbor_axis.value - return location - - -def _unstructured_domain( - axis: ir.AxisLiteral, size: int, vertical_ranges: Sequence[ir.FunCall] -) -> ir.FunCall: - """Create an unstructured domain expression.""" - return ir.FunCall( - fun=ir.SymRef(id="unstructured_domain"), - args=[ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value=axis), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.Literal(value=str(size), type=ir.INTEGER_INDEX_BUILTIN), - ], - ) - ] - + list(vertical_ranges), - ) - - def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: """Extract horizontal domain sizes from an `offset_provider`. @@ -364,7 +350,7 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL sizes[provider.origin_axis.value] = max( sizes.get(provider.origin_axis.value, 0), - provider.table.shape[0], # TODO properly expose the size + provider.table.shape[0], ) sizes[provider.neighbor_axis.value] = max( sizes.get(provider.neighbor_axis.value, 0), @@ -373,36 +359,98 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di return sizes -def _domain_ranges(closures: Sequence[ir.StencilClosure]): - """Extract all `named_ranges` from the given closures.""" - ranges = dict[str, list[ir.Expr]]() - for closure in closures: - domain = closure.domain - if isinstance(domain, ir.FunCall) and domain.fun == ir.SymRef(id="unstructured_domain"): - for arg in domain.args: - assert isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="named_range") - axis = arg.args[0].value - ranges.setdefault(axis, []).append(arg) - return ranges +@dataclasses.dataclass +class SymbolicRange: + start: ir.Expr + stop: ir.Expr + def translate(self, distance: int) -> "SymbolicRange": + return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance)) -def update_unstructured_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): - """Replace appearances of `AUTO_DOMAIN` by concrete domain sizes. - Note: the domain sizes are extracted from the `offset_provider` and are thus compile time! - """ +@dataclasses.dataclass +class SymbolicDomain: + grid_type: Literal["unstructured_domain", "cartesian_domain"] + ranges: dict[str, SymbolicRange] + + @classmethod + def from_expr(cls, node: ir.Node): + assert isinstance(node, ir.FunCall) and node.fun in [ + im.ref("unstructured_domain"), + im.ref("cartesian_domain"), + ] + + ranges: dict[str, SymbolicRange] = {} + for named_range in node.args: + assert ( + isinstance(named_range, ir.FunCall) + and isinstance(named_range.fun, ir.SymRef) + and named_range.fun.id == "named_range" + ) + axis_literal, lower_bound, upper_bound = named_range.args + assert isinstance(axis_literal, ir.AxisLiteral) + + ranges[axis_literal.value] = SymbolicRange(lower_bound, upper_bound) + return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above + + def as_expr(self): + return im.call(self.grid_type)( + *[ + im.call("named_range")(ir.AxisLiteral(value=d), r.start, r.stop) + for d, r in self.ranges.items() + ] + ) + + +def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain: + """Return the (set) union of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + new_domain_ranges[dim] = SymbolicRange(start, stop) + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def _group_offsets( + offset_literals: Sequence[ir.OffsetLiteral], +) -> Sequence[tuple[str, int | Literal[trace_shifts.Sentinel.ALL_NEIGHBORS]]]: + tags = [tag.value for tag in offset_literals[::2]] + offsets = [ + offset.value if isinstance(offset, ir.OffsetLiteral) else offset + for offset in offset_literals[1::2] + ] + assert all(isinstance(tag, str) for tag in tags) + assert all( + isinstance(offset, int) or offset == trace_shifts.Sentinel.ALL_NEIGHBORS + for offset in offsets + ) + return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly + + +def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - vertical_ranges = _domain_ranges(node.fencil.closures) - for k in horizontal_sizes: - vertical_ranges.pop(k, None) - closures = [] - domains = dict[str, ir.Expr]() + closures: list[ir.StencilClosure] = [] + domains = dict[str, ir.FunCall]() for closure in reversed(node.fencil.closures): if closure.domain == AUTO_DOMAIN: + # every closure with auto domain should have a single out field + assert isinstance(closure.output, ir.SymRef) domain = domains[closure.output.id] closure = ir.StencilClosure( - domain=domain, stencil=closure.stencil, output=closure.output, inputs=closure.inputs + domain=copy.deepcopy(domain), + stencil=closure.stencil, + output=closure.output, + inputs=closure.inputs, ) else: domain = closure.domain @@ -410,24 +458,50 @@ def update_unstructured_domains(node: FencilWithTemporaries, offset_provider: Ma closures.append(closure) if closure.stencil == ir.SymRef(id="deref"): - domains[closure.inputs[0].id] = domain + # all closure inputs inherit the domain + for input_arg in _tuple_constituents(closure.inputs[0]): + assert isinstance(input_arg, ir.SymRef) + assert domains.get(input_arg.id, domain) == domain + domains[input_arg.id] = domain continue - local_shifts = TraceShifts.apply(closure) - for param, shifts in local_shifts.items(): - loctypes = {_location_type_from_offsets(domain, s, offset_provider) for s in shifts} - assert len(loctypes) == 1 - loctype = loctypes.pop() - horizontal_size = horizontal_sizes[loctype] - domains[param] = _unstructured_domain( - loctype, horizontal_size, vertical_ranges.values() + local_shifts = trace_shifts.TraceShifts.apply(closure) + for param, shift_chains in local_shifts.items(): + assert isinstance(param, str) + consumed_domains: list[SymbolicDomain] = ( + [SymbolicDomain.from_expr(domains[param])] if param in domains else [] ) + for shift_chain in shift_chains: + consumed_domain = SymbolicDomain.from_expr(domain) + for offset_name, offset in _group_offsets(shift_chain): + if isinstance(offset_provider[offset_name], gtx.Dimension): + # cartesian shift + dim = offset_provider[offset_name].value + consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) + elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider): + # unstructured shift + nbt_provider = offset_provider[offset_name] + old_axis = nbt_provider.origin_axis.value + new_axis = nbt_provider.neighbor_axis.value + consumed_domain.ranges.pop(old_axis) + assert new_axis not in consumed_domain.ranges + consumed_domain.ranges[new_axis] = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), + ) + else: + raise NotImplementedError + consumed_domains.append(consumed_domain) + + # compute the bounds of all consumed domains + if consumed_domains: + domains[param] = domain_union(consumed_domains).as_expr() return FencilWithTemporaries( fencil=ir.FencilDefinition( id=node.fencil.id, function_definitions=node.fencil.function_definitions, - params=node.fencil.params[:-1], + params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again closures=list(reversed(closures)), ), params=node.params, @@ -435,41 +509,46 @@ def update_unstructured_domains(node: FencilWithTemporaries, offset_provider: Ma ) +def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: + if isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple"): + for arg in node.args: + yield from _tuple_constituents(arg) + else: + yield node + + def collect_tmps_info(node: FencilWithTemporaries, *, offset_provider) -> FencilWithTemporaries: """Perform type inference for finding the types of temporaries and sets the temporary size.""" tmps = {tmp.id for tmp in node.tmps} - domains: dict[str, ir.Expr] = { - closure.output.id: closure.domain - for closure in node.fencil.closures - if closure.output.id in tmps - } + domains: dict[str, ir.Expr] = {} + for closure in node.fencil.closures: + for output_field in _tuple_constituents(closure.output): + assert isinstance(output_field, ir.SymRef) + if output_field.id not in tmps: + continue + + assert output_field.id not in domains or domains[output_field.id] == closure.domain + domains[output_field.id] = closure.domain def convert_type(dtype): if isinstance(dtype, type_inference.Primitive): return dtype.name - if isinstance(dtype, type_inference.TypeVar): - return dtype.idx - if isinstance(dtype, type_inference.List): - return convert_type(dtype.dtype) - assert isinstance(dtype, type_inference.Tuple) - dtypes = [] - while isinstance(dtype, type_inference.Tuple): - dtypes.append(convert_type(dtype.front)) - dtype = dtype.others - return tuple(dtypes) - - fencil_type = type_inference.infer(node.fencil, offset_provider=offset_provider) + elif isinstance(dtype, type_inference.Tuple): + return tuple(convert_type(el) for el in dtype) + elif isinstance(dtype, type_inference.List): + raise NotImplementedError("Temporaries with dtype list not supported.") + raise AssertionError() + + all_types = type_inference.infer_all(node.fencil, offset_provider=offset_provider) + fencil_type = all_types[id(node.fencil)] assert isinstance(fencil_type, type_inference.FencilDefinitionType) assert isinstance(fencil_type.params, type_inference.Tuple) - all_types = [] types = dict[str, ir.Expr]() - for param, dtype in zip(node.fencil.params, fencil_type.params): - assert isinstance(dtype, type_inference.Val) - all_types.append(convert_type(dtype.dtype)) + for param in node.fencil.params: if param.id in tmps: - assert param.id not in types - t = all_types[-1] - types[param.id] = all_types.index(t) if isinstance(t, int) else t + dtype = all_types[id(param)] + assert isinstance(dtype, type_inference.Val) + types[param.id] = convert_type(dtype.dtype) return FencilWithTemporaries( fencil=node.fencil, @@ -480,6 +559,9 @@ def convert_type(dtype): ) +# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be +# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore +# and hence also not extract as a temporary. class CreateGlobalTmps(NodeTranslator): """Main entry point for introducing global temporaries. @@ -490,7 +572,7 @@ def visit_FencilDefinition( self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries - res = split_closures(node) + res = split_closures(node, offset_provider=offset_provider) # Prune unreferences closure inputs introduced in the previous step res = PruneClosureInputs().visit(res) # Prune unused temporaries possibly introduced in the previous step @@ -498,9 +580,6 @@ def visit_FencilDefinition( # Perform an eta-reduction which should put all calls at the highest level of a closure res = EtaReduction().visit(res) # Perform a naive extent analysis to compute domain sizes of closures and temporaries - if all(isinstance(o, gtx.Dimension) for o in offset_provider.values()): - res = update_cartesian_domains(res, offset_provider) - else: - res = update_unstructured_domains(res, offset_provider) + res = update_domains(res, offset_provider) # Use type inference to determine the data type of the temporaries return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index 317214e28c..fe1eae6e07 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -48,7 +48,7 @@ def _should_inline(node: ir.FunCall) -> bool: def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: - inlined = inline_lambda(node, opcount_preserving=False, force_inline_lift=True) + inlined = inline_lambda(node, opcount_preserving=False, force_inline_lift_args=True) inlined = InlineLifts().visit(inlined) return inlined diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 2e680c6da5..fc268f85e3 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -17,14 +17,18 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir +from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs -def inline_lambda( +# TODO(tehrengruber): Reduce complexity of the function by removing the different options here +# and introduce a generic predicate argument for the `eligible_params` instead. +def inline_lambda( # noqa: C901 # see todo above node: ir.FunCall, opcount_preserving=False, - force_inline_lift=False, + force_inline_lift_args=False, + force_inline_trivial_lift_args=False, eligible_params: Optional[list[bool]] = None, ): assert isinstance(node.fun, ir.Lambda) @@ -43,14 +47,16 @@ def inline_lambda( ): eligible_params[i] = False - if force_inline_lift: + # inline lifts, i.e. `lift(λ(...) → ...)(...)` + if force_inline_lift_args: for i, arg in enumerate(node.args): - if ( - isinstance(arg, ir.FunCall) - and isinstance(arg.fun, ir.FunCall) - and isinstance(arg.fun.fun, ir.SymRef) - and arg.fun.fun.id == "lift" - ): + if is_applied_lift(arg): + eligible_params[i] = True + + # inline trivial lifts, i.e. `lift(λ() → 1)()` + if force_inline_trivial_lift_args: + for i, arg in enumerate(node.args): + if is_applied_lift(arg) and len(arg.args) == 0: eligible_params[i] = True if node.fun.params and not any(eligible_params): @@ -110,14 +116,24 @@ def new_name(name): class InlineLambdas(NodeTranslator): """Inline lambda calls by substituting every argument by its value.""" + PRESERVED_ANNEX_ATTRS = ("type",) + opcount_preserving: bool - force_inline_lift: bool + force_inline_lift_args: bool + + force_inline_trivial_lift_args: bool @classmethod - def apply(cls, node: ir.Node, opcount_preserving=False, force_inline_lift=False): + def apply( + cls, + node: ir.Node, + opcount_preserving=False, + force_inline_lift_args=False, + force_inline_trivial_lift_args=False, + ): """ - Inline lambda calls by substituting every arguments by its value. + Inline lambda calls by substituting every argument by its value. Examples: `(λ(x) → x)(y)` to `y` @@ -126,13 +142,20 @@ def apply(cls, node: ir.Node, opcount_preserving=False, force_inline_lift=False) `(λ(x) → x+x)(y+y)` stays as is if opcount_preserving Arguments: + node: The function call node to inline into. opcount_preserving: Preserve the number of operations, i.e. only - inline lambda call if the resulting call has the same number of - operations. + inline lambda call if the resulting call has the same number of + operations. + force_inline_lift_args: Inline all arguments that are applied lifts, i.e. + `lift(λ(...) → ...)(...)`. + force_inline_trivial_lift_args: Inline all arguments that are trivial + applied lifts, e.g. `lift(λ() → 1)()`. + """ return cls( opcount_preserving=opcount_preserving, - force_inline_lift=force_inline_lift, + force_inline_lift_args=force_inline_lift_args, + force_inline_trivial_lift_args=force_inline_trivial_lift_args, ).visit(node) def visit_FunCall(self, node: ir.FunCall): @@ -141,7 +164,8 @@ def visit_FunCall(self, node: ir.FunCall): return inline_lambda( node, opcount_preserving=self.opcount_preserving, - force_inline_lift=self.force_inline_lift, + force_inline_lift_args=self.force_inline_lift_args, + force_inline_trivial_lift_args=self.force_inline_trivial_lift_args, ) return node diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index aafb5ab276..8d62450e67 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses +import enum from collections.abc import Callable from typing import Optional @@ -104,19 +105,45 @@ def _transform_and_extract_lift_args( return (im.lift(inner_stencil)(*new_args), extracted_args) +# TODO(tehrengruber): This pass has many different options that should be written as dedicated +# passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without +# performance degradation we leave everything as one pass for now. @dataclasses.dataclass class InlineLifts(traits.VisitorWithSymbolTableTrait, NodeTranslator): """Inline lifted function calls. - Optionally a predicate function can be passed which can enable or disable inlining of specific function nodes. + Optionally a predicate function can be passed which can enable or disable inlining of specific + function nodes. """ - def __init__(self, predicate: Optional[Callable[[ir.Expr, bool], bool]] = None) -> None: - super().__init__() - if predicate is None: - self.predicate = lambda _1, _2: True - else: - self.predicate = predicate + class Flag(enum.IntEnum): + #: `shift(...)(lift(f)(args...))` -> `lift(f)(shift(...)(args)...)` + PROPAGATE_SHIFT = 1 + #: `deref(lift(f)())` -> `f()` + INLINE_TRIVIAL_DEREF_LIFT = 2 + #: `deref(lift(f)(args...))` -> `f(args...)` + INLINE_DEREF_LIFT = 2 + 4 + #: `can_deref(lift(f)(args...))` -> `and(can_deref(arg[0]), and(can_deref(arg[1]), ...))` + PROPAGATE_CAN_DEREF = 8 + #: Inline arguments to lifted stencil calls, e.g.: + #: lift(λ(a) → inner_ex(a))(lift(λ(b) → outer_ex(b))(arg)) + #: is transformed into: + #: lift(λ(b) → inner_ex(outer_ex(b)))(arg) + #: Note: This option is only needed when there is no outer `deref` by which the previous + #: branches eliminate the lift calls. This occurs for example for the `reduce` builtin + #: or when a better readable expression of a lift statement is needed during debugging. + #: Due to its complexity we might want to remove this option at some point again, + #: when we see that it is not required. + INLINE_LIFTED_ARGS = 16 + + predicate: Callable[[ir.Expr, bool], bool] = lambda _1, _2: True + + flags: int = ( + Flag.PROPAGATE_SHIFT + | Flag.INLINE_DEREF_LIFT + | Flag.PROPAGATE_CAN_DEREF + | Flag.INLINE_LIFTED_ARGS + ) def visit_FunCall( self, node: ir.FunCall, *, is_scan_pass_context=False, recurse=True, **kwargs @@ -132,8 +159,7 @@ def visit_FunCall( else node ) - if _is_shift_lift(node): - # shift(...)(lift(f)(args...)) -> lift(f)(shift(...)(args)...) + if self.flags & self.Flag.PROPAGATE_SHIFT and _is_shift_lift(node): shift = node.fun assert len(node.args) == 1 lift_call = node.args[0] @@ -143,10 +169,15 @@ def visit_FunCall( ] result = ir.FunCall(fun=lift_call.fun, args=new_args) # type: ignore[attr-defined] # lift_call already asserted to be of type ir.FunCall return self.visit(result, recurse=False, **kwargs) - elif node.fun == ir.SymRef(id="deref"): + elif self.flags & self.Flag.INLINE_DEREF_LIFT and node.fun == ir.SymRef(id="deref"): assert len(node.args) == 1 - if _is_lift(node.args[0]) and self.predicate(node.args[0], is_scan_pass_context): - # deref(lift(f)(args...)) -> f(args...) + is_lift = _is_lift(node.args[0]) + is_eligible = is_lift and self.predicate(node.args[0], is_scan_pass_context) + is_trivial = is_lift and len(node.args[0].args) == 0 # type: ignore[attr-defined] # mypy not smart enough + if ( + self.flags & self.Flag.INLINE_DEREF_LIFT + or (self.flags & self.Flag.INLINE_TRIVIAL_DEREF_LIFT and is_trivial) + ) and is_eligible: assert isinstance(node.args[0], ir.FunCall) and isinstance( node.args[0].fun, ir.FunCall ) @@ -157,12 +188,11 @@ def visit_FunCall( if isinstance(f, ir.Lambda): new_node = inline_lambda(new_node, opcount_preserving=True) return self.visit(new_node, **kwargs) - elif node.fun == ir.SymRef(id="can_deref"): + elif self.flags & self.Flag.PROPAGATE_CAN_DEREF and node.fun == ir.SymRef(id="can_deref"): # TODO(havogt): this `can_deref` transformation doesn't look into lifted functions, # this need to be changed to be 100% compliant assert len(node.args) == 1 if _is_lift(node.args[0]) and self.predicate(node.args[0], is_scan_pass_context): - # can_deref(lift(f)(args...)) -> and(can_deref(arg[0]), and(can_deref(arg[1]), ...)) assert isinstance(node.args[0], ir.FunCall) and isinstance( node.args[0].fun, ir.FunCall ) @@ -178,17 +208,12 @@ def visit_FunCall( args=[res, ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[arg])], ) return res - elif _is_lift(node) and len(node.args) > 0 and self.predicate(node, is_scan_pass_context): - # Inline arguments to lifted stencil calls, e.g.: - # lift(λ(a) → inner_ex(a))(lift(λ(b) → outer_ex(b))(arg)) - # is transformed into: - # lift(λ(b) → inner_ex(outer_ex(b)))(arg) - # lift(λ(a) → inner_ex(shift(...)(a)))(lift(λ(b) → outer_ex(b))(arg)) - # Note: This branch is only needed when there is no outer `deref` by which the previous - # branches eliminate the lift calls. This occurs for example for the `reduce` builtin - # or when a better readable expression of a lift statement is needed during debugging. - # Due to its complexity we might want to remove this branch at some point again, - # when we see that it is not required. + elif ( + self.flags & self.Flag.INLINE_LIFTED_ARGS + and _is_lift(node) + and len(node.args) > 0 + and self.predicate(node, is_scan_pass_context) + ): stencil = node.fun.args[0] # type: ignore[attr-defined] # node already asserted to be of type ir.FunCall eligible_lifted_args = [ _is_lift(arg) and self.predicate(arg, is_scan_pass_context) for arg in node.args diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 62251a3e43..0ff3ec25c7 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -44,9 +44,18 @@ class LiftMode(enum.Enum): def _inline_lifts(ir, lift_mode): if lift_mode == LiftMode.FORCE_INLINE: return InlineLifts().visit(ir) - if lift_mode == LiftMode.SIMPLE_HEURISTIC: + elif lift_mode == LiftMode.SIMPLE_HEURISTIC: return InlineLifts(simple_inline_heuristic.is_eligible_for_inlining).visit(ir) - assert lift_mode == LiftMode.FORCE_TEMPORARIES + elif lift_mode == LiftMode.FORCE_TEMPORARIES: + return InlineLifts( + flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. + | InlineLifts.Flag.INLINE_LIFTED_ARGS + # needed for UnrollReduce and lift args like `(↑(λ() → constant)` + ).visit(ir) + else: + raise ValueError() + return ir @@ -54,7 +63,7 @@ def _inline_into_scan(ir, *, max_iter=10): for _ in range(10): # in case there are multiple levels of lambdas around the scan we have to do multiple iterations inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift=True) + inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) if inlined == ir: break ir = inlined @@ -84,13 +93,15 @@ def apply_common_transforms( for _ in range(10): inlined = ir - if lift_mode != LiftMode.FORCE_TEMPORARIES: - inlined = _inline_lifts(inlined, lift_mode) + inlined = _inline_lifts(inlined, lift_mode) inlined = InlineLambdas.apply( inlined, opcount_preserving=True, - force_inline_lift=(lift_mode == LiftMode.FORCE_INLINE), + force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), + # If trivial lifts are not inlined we might create temporaries for constants. In all + # other cases we want it anyway. + force_inline_trivial_lift_args=True, ) inlined = ConstantFolding.apply(inlined) # This pass is required to be in the loop such that when an `if_` call with tuple arguments diff --git a/src/gt4py/next/iterator/transforms/popup_tmps.py b/src/gt4py/next/iterator/transforms/popup_tmps.py deleted file mode 100644 index e76d70833b..0000000000 --- a/src/gt4py/next/iterator/transforms/popup_tmps.py +++ /dev/null @@ -1,184 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import dataclasses -from collections.abc import Callable -from functools import partial -from typing import Optional, Union, cast - -from gt4py.eve import NodeTranslator -from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs - - -@dataclasses.dataclass(frozen=True) -class PopupTmps(NodeTranslator): - """Transformation for “popping up” nested lifts to lambda arguments. - - In the simplest case, `(λ(x) → deref(lift(deref)(x)))(y)` is translated to - `(λ(x, tmp) → deref(tmp))(y, lift(deref)(y))` (where `tmp` is an arbitrary - new symbol name). - - Note that there are edge cases of lifts which can not be popped up; for - example, popping up of a lift call that references a closure argument - (like `lift(deref)(x)` where `x` is a closure argument) is not possible - as we can not pop the expression to be a closure input (because closures - just take unmodified fencil arguments as inputs). - """ - - # we use one UID generator per instance such that the generated ids are - # stable across multiple runs (required for caching to properly work) - uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) - - @staticmethod - def _extract_lambda( - node: ir.FunCall, - ) -> Optional[ - tuple[ir.Lambda, list[ir.Sym], bool, Callable[[ir.Lambda, list[ir.Expr]], ir.FunCall]] - ]: - """Extract the lambda function which is relevant for popping up lifts. - - Further, returns a bool indicating if the given function call was as a - lift expression and a wrapper function that undos the extraction. - - So The behavior is the following: - - For `lift(f)(args...)` it returns `(f, True, wrap)`. - - For `lift(scan(f, dir, init))(args...)` it returns `(f, True, wrap)`. - - For `lift(reduce(f, init))(args...)` it returns `(None, True, wrap)`. - - For `f(args...)` it returns `(f, False, wrap)`. - - For any other expression, it returns `None`. - - The returned `wrap` function undos the extraction in all cases; for example, - `wrap(f, args...)` returns `lift(f)(args...)` in the first case. - """ - if isinstance(node.fun, ir.FunCall) and node.fun.fun == ir.SymRef(id="lift"): - # lifted lambda call or lifted scan - assert len(node.fun.args) == 1 - fun = node.fun.args[0] - - is_scan = isinstance(fun, ir.FunCall) and fun.fun == ir.SymRef(id="scan") - is_reduce = isinstance(fun, ir.FunCall) and fun.fun == ir.SymRef(id="reduce") - if is_scan: - fun = fun.args[0] # type: ignore[attr-defined] # fun already asserted to be of type ir.FunCall - assert isinstance(fun, ir.Lambda) - params = fun.params[1:] - elif is_reduce: - fun = fun.args[0] # type: ignore[attr-defined] # fun already asserted to be of type ir.FunCall - assert isinstance(fun, ir.Lambda) - params = fun.params[1:] - else: - assert isinstance(fun, ir.Lambda) - params = fun.params - - def wrap(fun: ir.Lambda, args: list[ir.Expr]) -> ir.FunCall: - if is_scan: - assert isinstance(node.fun, ir.FunCall) and isinstance( - node.fun.args[0], ir.FunCall - ) # TODO(fthaler): first part of the assertion already checked above, however mypy does not catch it - scan_args = [cast(ir.Expr, fun)] + node.fun.args[0].args[1:] - f: Union[ir.Lambda, ir.FunCall] = ir.FunCall( - fun=ir.SymRef(id="scan"), args=scan_args - ) - elif is_reduce: - assert isinstance(node.fun, ir.FunCall) and isinstance( - node.fun.args[0], ir.FunCall - ) # TODO(fthaler): first part of the assertion already checked above, however mypy does not catch it - assert fun == node.fun.args[0].args[0], "Unexpected lift in reduction function." - f = node.fun.args[0] - else: - f = fun - return ir.FunCall(fun=ir.FunCall(fun=ir.SymRef(id="lift"), args=[f]), args=args) - - assert isinstance(fun, ir.Lambda) - return fun, params, True, wrap - if isinstance(node.fun, ir.Lambda): - # direct lambda call - - def wrap(fun: ir.Lambda, args: list[ir.Expr]) -> ir.FunCall: - return ir.FunCall(fun=fun, args=args) - - return node.fun, node.fun.params, False, wrap - - return None - - def visit_FunCall( - self, node: ir.FunCall, *, lifts: Optional[dict[ir.Expr, ir.SymRef]] = None - ) -> Union[ir.SymRef, ir.FunCall]: - if call_info := self._extract_lambda(node): - fun, params, is_lift, wrap = call_info - - nested_lifts = dict[ir.Expr, ir.SymRef]() - fun = self.visit(fun, lifts=nested_lifts) - # Note: lifts in arguments are just passed to the parent node - args = self.visit(node.args, lifts=lifts) - - if is_lift: - assert lifts is not None - - # check if the lifted expression captures symbols from the outer scope - symrefs = fun.walk_values().if_isinstance(ir.SymRef).getattr("id").to_set() - captured = ( - symrefs - - {p.id for p in fun.params} - - {n.id for n in nested_lifts.values()} - - ir.BUILTINS - ) - if captured: - # if symbols from an outer scope are captured, the lift has to - # be handled at that scope, so skip here and pass nested lifts on - lifts |= nested_lifts - return wrap(fun, args) - - # remap referenced function parameters in lift expression to passed argument values - assert len(params) == len(args) - symbol_map = {str(param.id): arg for param, arg in zip(params, args)} - remap = partial(RemapSymbolRefs().visit, symbol_map=symbol_map) - - nested_lifts = {remap(expr): ref for expr, ref in nested_lifts.items()} - if lifts: - # lifts have to be updated in place as they are passed to parent node - lifted = list(lifts.items()) - lifts.clear() - for expr, ref in lifted: - lifts[remap(expr)] = remap(ref) - - # extend parameter list of the function with popped lifts - new_params = [ir.Sym(id=p.id) for p in nested_lifts.values()] - fun = ir.Lambda(params=fun.params + new_params, expr=fun.expr) - # for the arguments, we have to resolve possible cross-references of lifts - symbol_map = {str(v.id): k for k, v in nested_lifts.items()} - new_args = [ - RemapSymbolRefs().visit(a, symbol_map=symbol_map) for a in nested_lifts.keys() - ] - - # updated function call, having lifts passed as arguments - call = wrap(fun, args + new_args) - - if not is_lift: - # if this is not a lift expression, we are done... - return call - - # ... otherwise we check if the same expression has already been - # lifted before, then we reference that one - assert lifts is not None - if (previous_ref := lifts.get(call)) is not None: - return previous_ref - - # if this is the first time we lift that expression, create a new - # symbol for it and register it so the parent node knows about it - ref = ir.SymRef(id=self.uids.sequential_id(prefix="_lift")) - lifts[call] = ref - return ref - return self.generic_visit(node, lifts=lifts) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index c2ba75e1dd..cdf3d76173 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -19,6 +19,8 @@ class RemapSymbolRefs(NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("type",) + def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): return symbol_map.get(str(node.id), node) @@ -38,6 +40,8 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] class RenameSymbols(NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("type",) + def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index b5697ca321..5c607e7df1 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -18,13 +18,14 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.collect_shifts import ALL_NEIGHBORS class Sentinel(enum.Enum): VALUE = object() TYPE = object() + ALL_NEIGHBORS = object() + @dataclasses.dataclass(frozen=True) class ShiftRecorder: @@ -150,7 +151,7 @@ def _map(f): def _neighbors(o, x): - return _deref(_shift(o, ALL_NEIGHBORS)(x)) + return _deref(_shift(o, Sentinel.ALL_NEIGHBORS)(x)) def _scan(f, forward, init): diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 85d7b26056..e3084eaba5 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -20,6 +20,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift def _is_shifted(arg: itir.Expr) -> TypeGuard[itir.FunCall]: @@ -34,17 +35,9 @@ def _is_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunCall]: return isinstance(arg, itir.FunCall) and arg.fun == itir.SymRef(id="neighbors") -def _is_applied_lift(arg: itir.Expr) -> TypeGuard[itir.FunCall]: - return ( - isinstance(arg, itir.FunCall) - and isinstance(arg.fun, itir.FunCall) - and arg.fun.fun == itir.SymRef(id="lift") - ) - - def _is_neighbors_or_lifted_and_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunCall]: return _is_neighbors(arg) or ( - _is_applied_lift(arg) + is_applied_lift(arg) and any(_is_neighbors_or_lifted_and_neighbors(nested_arg) for nested_arg in arg.args) ) @@ -67,7 +60,7 @@ def _get_partial_offset_tag(arg: itir.FunCall) -> str: assert isinstance(offset.value, str) return offset.value else: - assert _is_applied_lift(arg) + assert is_applied_lift(arg) assert _is_list_of_funcalls(arg.args) partial_offsets = [_get_partial_offset_tag(arg) for arg in arg.args] assert all(o == partial_offsets[0] for o in partial_offsets) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index c9aa73a178..14f3e95e10 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -848,6 +848,8 @@ def visit_FunCall( if isinstance(node.fun, ir.SymRef) and node.fun.id in ir.GRAMMAR_BUILTINS: # builtins that are treated as part of the grammar are handled in `_visit_` return getattr(self, f"_visit_{node.fun.id}")(node, **kwargs) + elif isinstance(node.fun, ir.SymRef) and node.fun.id in ir.TYPEBUILTINS: + return Val(kind=Value(), dtype=Primitive(name=node.fun.id)) fun = self.visit(node.fun, **kwargs) args = Tuple.from_elems(*self.visit(node.args, **kwargs)) @@ -911,6 +913,7 @@ def visit_StencilClosure( size=stencil_param.size, # closure input and stencil param differ in `current_loc` current_loc=ANYWHERE, + # TODO(tehrengruber): Seems to break for scalars. Use `TypeVar.fresh()`? defined_loc=stencil_param.defined_loc, ), ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 5d463a02bf..8cd910e40f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -223,8 +223,24 @@ def visit_FencilDefinition( **kwargs, ) + def visit_TemporaryAllocation(self, node, **kwargs): + # TODO(tehrengruber): Revisit. We are currently converting an itir.NamedRange with + # start and stop values into an gtfn_ir.(Cartesian|Unstructured)Domain with + # size and offset values, just to here convert back in order to obtain stop values again. + # TODO(tehrengruber): Fix memory alignment. + assert node.domain.tagged_offsets.tags == node.domain.tagged_sizes.tags + tags = node.domain.tagged_offsets.tags + new_sizes = [] + for size, offset in zip(node.domain.tagged_offsets.values, node.domain.tagged_sizes.values): + new_sizes.append(gtfn_ir.BinaryExpr(op="+", lhs=size, rhs=offset)) + return self.generic_visit( + node, + tmp_sizes=self.visit(gtfn_ir.TaggedValues(tags=tags, values=new_sizes), **kwargs), + **kwargs, + ) + TemporaryAllocation = as_fmt( - "auto {id} = gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {domain}.sizes());" + "auto {id} = gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {tmp_sizes});" ) FencilDefinition = as_mako( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py index 9cbae90864..4183f52550 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py @@ -47,18 +47,17 @@ def _lower( def generate( program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any ) -> str: - do_unroll = not ("imperative" in kwargs and kwargs["imperative"]) - if "imperative" in kwargs and kwargs["imperative"]: + if kwargs.get("imperative", False): try: gtfn_ir = _lower( program=program, enable_itir_transforms=enable_itir_transforms, - do_unroll=do_unroll, + do_unroll=False, **kwargs, ) except EveValueError: - # if we don't unroll, there may be lifts left in the itir which can't be lowered to gtfn. in this case - # case, just retry with unrolled reductions + # if we don't unroll, there may be lifts left in the itir which can't be lowered to + # gtfn. In this case, just retry with unrolled reductions. gtfn_ir = _lower( program=program, enable_itir_transforms=enable_itir_transforms, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 38ddc321db..5e24e855b5 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -15,7 +15,8 @@ from __future__ import annotations import dataclasses -from typing import Any, Final, TypeVar +import warnings +from typing import Any, Final, Optional, TypeVar import numpy as np @@ -24,6 +25,7 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import LiftMode from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn import gtfn_backend @@ -50,6 +52,7 @@ class GTFNTranslationStep( language_settings: languages.LanguageWithHeaderFilesSettings = cpp_interface.CPP_DEFAULT enable_itir_transforms: bool = True # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 use_imperative_backend: bool = False + lift_mode: Optional[LiftMode] = None def _process_regular_arguments( self, @@ -172,6 +175,18 @@ def __call__( inp.kwargs["offset_provider"] ) + # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added + # to the interface of all (or at least all of concern) backends, but instead should be + # configured in the backend itself (like it is here), until then we respect the argument + # here and warn the user if it differs from the one configured. + runtime_lift_mode = inp.kwargs.pop("lift_mode", None) + lift_mode = runtime_lift_mode or self.lift_mode + if runtime_lift_mode != self.lift_mode: + warnings.warn( + f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " + "overriden to be {str(runtime_lift_mode)} at runtime." + ) + # combine into a format that is aligned with what the backend expects parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters args_expr: list[str] = ["gridtools::fn::backend::naive{}", *regular_args_expr] @@ -185,6 +200,7 @@ def __call__( stencil_src = gtfn_backend.generate( program, enable_itir_transforms=self.enable_itir_transforms, + lift_mode=lift_mode, imperative=self.use_imperative_backend, **inp.kwargs, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn_cpu.py b/src/gt4py/next/program_processors/runners/gtfn_cpu.py index 72f2637d87..31b8323474 100644 --- a/src/gt4py/next/program_processors/runners/gtfn_cpu.py +++ b/src/gt4py/next/program_processors/runners/gtfn_cpu.py @@ -18,6 +18,7 @@ from gt4py.eve.utils import content_hash from gt4py.next import common +from gt4py.next.iterator.transforms import LiftMode from gt4py.next.otf import languages, recipes, stages, workflow from gt4py.next.otf.binding import cpp_interface, nanobind from gt4py.next.otf.compilation import cache, compiler @@ -126,3 +127,13 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: name="run_gtfn_cached", otf_workflow=workflow.CachedStep(step=run_gtfn.otf_workflow, hash_function=compilation_hash), ) # todo(ricoh): add API for converting an executor to a cached version of itself and vice versa + + +run_gtfn_with_temporaries = otf_compile_executor.OTFCompileExecutor[ + languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any +]( + name="run_gtfn_with_temporaries", + otf_workflow=run_gtfn.otf_workflow.replace( + translation=run_gtfn.otf_workflow.translation.replace(lift_mode=LiftMode.FORCE_TEMPORARIES), + ), +) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 1f2e88c3b1..a8c35cc28f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -37,6 +37,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non roundtrip.executor, gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ], ids=lambda p: next_tests.get_processor_id(p), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index ade410ef23..71e31542f7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -171,9 +171,10 @@ def testee( def test_call_scan_operator_from_field_operator(cartesian_case): if cartesian_case.backend in [ - dace_iterator.run_dace_iterator, gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + dace_iterator.run_dace_iterator, ]: pytest.xfail("Calling scan from field operator not fully supported.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index e425457224..f50f16ea0f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -273,7 +273,11 @@ def testee(qc: cases.IKFloatField, scalar: float): def test_tuple_scalar_scan(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: + if cartesian_case.backend in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: pytest.xfail("Scalar tuple arguments are not supported in gtfn yet.") if cartesian_case.backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: tuple arguments") @@ -379,8 +383,11 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], np.float32]: def test_offset_field(cartesian_case): + if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: + pytest.xfail("Dynamic offsets not supported in gtfn") if cartesian_case.backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: offset fields") + ref = np.full( (cartesian_case.default_sizes[IDim], cartesian_case.default_sizes[KDim]), True, dtype=bool ) @@ -549,8 +556,14 @@ def simple_scan_operator(carry: float) -> float: def test_solve_triag(cartesian_case): - if cartesian_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: + if cartesian_case.backend in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: pytest.xfail("Nested `scan`s requires creating temporaries.") + if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: + pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if cartesian_case.backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: scans") @@ -654,6 +667,9 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: def test_ternary_scan(cartesian_case): + if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]: + pytest.xfail("Temporary extraction does not work correctly in combination with scans.") + @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 @@ -673,8 +689,11 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) def test_scan_nested_tuple_output(forward, cartesian_case): + if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]: + pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if cartesian_case.backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: tuple returns") + init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] expected = np.arange(1, 1 + k_size, 1, dtype=int32) @@ -777,7 +796,11 @@ def program_domain(a: cases.IField, out: cases.IField): def test_domain_input_bounds(cartesian_case): - if cartesian_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: + if cartesian_case.backend in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: pytest.xfail("FloorDiv not fully supported in gtfn.") lower_i = 1 diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index de694eefcd..7acc0e1447 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -48,7 +48,11 @@ def test_maxover_execution_(unstructured_case, strategy): if unstructured_case.backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: reductions") - if unstructured_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: + if unstructured_case.backend in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: pytest.xfail("`maxover` broken in gtfn, see #1289.") @gtx.field_operator diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index e8de3d9264..54374077b4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -71,6 +71,7 @@ def test_floordiv(cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, ]: pytest.xfail( "FloorDiv not yet supported." @@ -87,6 +88,7 @@ def test_mod(cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 17a1ea11cb..a49dd1fdcf 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -50,6 +50,7 @@ def test_simple_if(condition, cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -74,6 +75,7 @@ def test_simple_if_conditional(condition1, condition2, cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -114,6 +116,7 @@ def test_local_if(cartesian_case, condition): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -139,6 +142,7 @@ def test_temporary_if(cartesian_case, condition): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -167,6 +171,7 @@ def test_if_return(cartesian_case, condition): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -195,6 +200,7 @@ def test_if_stmt_if_branch_returns(cartesian_case, condition): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -220,6 +226,7 @@ def test_if_stmt_else_branch_returns(cartesian_case, condition): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -247,6 +254,7 @@ def test_if_stmt_both_branches_return(cartesian_case, condition): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -270,10 +278,11 @@ def both_branches_return(a: cases.IField, b: cases.IField, condition: bool) -> c @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) -def test_nested_if_stmt_conditional(cartesian_case, condition1, condition2): +def test_nested_if_stmt_conditinal(cartesian_case, condition1, condition2): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -317,6 +326,7 @@ def test_nested_if(cartesian_case, condition): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") @@ -358,6 +368,7 @@ def test_if_without_else(cartesian_case, condition1, condition2): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, dace_iterator.run_dace_iterator, ]: pytest.xfail("If-stmts are not supported yet.") diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index a22fef0d49..f4ebc596e5 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -34,8 +34,8 @@ from gt4py.next.program_processors.formatters.gtfn import ( format_sourcecode as gtfn_format_sourcecode, ) +from gt4py.next.program_processors.runners import gtfn_cpu from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative from next_tests.integration_tests.cases import IDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -59,8 +59,9 @@ def test_simple_indirection(program_processor): if program_processor in [ type_check.check, - run_gtfn, - run_gtfn_imperative, + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, gtfn_format_sourcecode, run_dace_iterator, ]: @@ -101,6 +102,9 @@ def test_direct_offset_for_indirection(program_processor): if program_processor == run_dace_iterator: pytest.xfail("Not supported in DaCe backend: shift offsets not literals") + if program_processor == gtfn_cpu.run_gtfn_with_temporaries: + pytest.xfail("Dynamic offsets not supported in temporaries pass.") + shape = [4] inp = gtx.np_as_located_field(IDim)(np.asarray(range(shape[0]), dtype=np.float64)) cond = gtx.np_as_located_field(IDim)(np.asarray([2, 1, -1, -2], dtype=np.int32)) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 2b94d93af9..7bfaa7f643 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -18,8 +18,8 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.program_processors.runners import gtfn_cpu from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative from next_tests.unit_tests.conftest import program_processor, run_processor @@ -51,7 +51,12 @@ def fencil(size, out, inp): def test_strided_offset_provider(program_processor): program_processor, validate = program_processor - if program_processor in [run_dace_iterator, run_gtfn, run_gtfn_imperative]: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + run_dace_iterator, + ]: pytest.xfail("gtx.StridedNeighborOffsetProvider not implemented in bindings.") LocA_size = 2 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 5f0f273d0f..2580c6ba7f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -212,7 +212,11 @@ class setup: def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): - if fieldview_backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: + if fieldview_backend in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: pytest.xfail("Needs implementation of scan projector.") if fieldview_backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: scans") @@ -230,6 +234,11 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): + if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: + pytest.xfail( + "Needs implementation of scan projector. Breaks in type inference as executed" + "again after CollapseTuple." + ) if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") if fieldview_backend == dace_iterator.run_dace_iterator: @@ -248,6 +257,8 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): + if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: + pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if fieldview_backend == dace_iterator.run_dace_iterator: pytest.xfail("Not supported in DaCe backend: scans") solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( @@ -264,6 +275,8 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): + if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: + pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") if fieldview_backend == dace_iterator.run_dace_iterator: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 26ac49aa57..14d929e822 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -18,8 +18,8 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.program_processors.runners import gtfn_cpu from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -78,7 +78,11 @@ def naive_lap(inp): def test_anton_toy(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [run_gtfn, run_gtfn_imperative]: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: from gt4py.next.iterator import transforms if lift_mode != transforms.LiftMode.FORCE_INLINE: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 5211f2184d..2446d6664f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -317,14 +317,7 @@ def kdoublesum_fencil(i_size, k_start, k_end, inp0, inp1, out): ) def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): program_processor, validate = program_processor - if ( - program_processor == run_dace_iterator - or program_processor == run_gtfn - or program_processor == run_gtfn_imperative - or program_processor == gtfn_format_sourcecode - ): - pytest.xfail("structured dtype input/output currently unsupported") - + pytest.xfail("structured dtype input/output currently unsupported") shape = [1, 7] inp0 = gtx.np_as_located_field(IDim, KDim)(np.asarray([list(range(7))], dtype=np.float64)) inp1 = gtx.np_as_located_field(IDim, KDim)(np.asarray([list(range(7))], dtype=np.int32)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index ab22e2b360..2d35fb1e50 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -39,7 +39,7 @@ ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.iterator.transforms.pass_manager import LiftMode -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.multi_feature_tests.iterator_tests.fvm_nabla_setup import ( assert_close, @@ -138,7 +138,12 @@ def nabla( def test_compute_zavgS(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [run_dace_iterator, run_gtfn, run_gtfn_imperative]: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + run_dace_iterator, + ]: pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() @@ -198,7 +203,12 @@ def compute_zavgS2_fencil( def test_compute_zavgS2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [run_dace_iterator, run_gtfn, run_gtfn_imperative]: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + run_dace_iterator, + ]: pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() @@ -236,7 +246,12 @@ def test_compute_zavgS2(program_processor, lift_mode): def test_nabla(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [run_dace_iterator, run_gtfn, run_gtfn_imperative]: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + run_dace_iterator, + ]: pytest.xfail("TODO: bindings don't support Atlas tables") if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") @@ -297,7 +312,12 @@ def nabla2( def test_nabla2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [run_dace_iterator, run_gtfn, run_gtfn_imperative]: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + run_dace_iterator, + ]: pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() @@ -380,7 +400,12 @@ def test_nabla_sign(program_processor, lift_mode): program_processor, validate = program_processor if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("test is broken due to bad lift semantics in iterator IR") - if program_processor in [run_dace_iterator, run_gtfn, run_gtfn_imperative]: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + run_dace_iterator, + ]: pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 4291f5938a..1dfad40e48 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.cases import IDim, JDim from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import ( @@ -78,11 +78,16 @@ def hdiff(inp, coeff, out, x, y): def test_hdiff(hdiff_reference, program_processor_no_dace_exec, lift_mode): program_processor, validate = program_processor_no_dace_exec - if program_processor == run_gtfn or program_processor == run_gtfn_imperative: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: + # TODO(tehrengruber): check if still true from gt4py.next.iterator import transforms if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("there is an issue with temporaries that crashes the application") + pytest.xfail("Temporaries are not compatible with origins.") inp, coeff, out = hdiff_reference shape = (out.shape[0], out.shape[1]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 04edf68919..4474121876 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -22,7 +22,7 @@ from gt4py.next.program_processors.formatters.gtfn import ( format_sourcecode as gtfn_format_sourcecode, ) -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import ( @@ -119,15 +119,25 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): ) -@pytest.fixture -def tridiag_test(tridiag_reference, program_processor_no_dace_exec, lift_mode): +@pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) +def test_tridiag(fencil, tridiag_reference, program_processor_no_dace_exec, lift_mode): program_processor, validate = program_processor_no_dace_exec if ( - program_processor == run_gtfn - or program_processor == run_gtfn_imperative - or program_processor == gtfn_format_sourcecode - ) and lift_mode == LiftMode.FORCE_INLINE: - pytest.xfail("gtfn does only support lifted scans when using temporaries") + program_processor + in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + gtfn_format_sourcecode, + ] + and lift_mode == LiftMode.FORCE_INLINE + ): + pytest.skip("gtfn does only support lifted scans when using temporaries") + if ( + program_processor == gtfn_cpu.run_gtfn_with_temporaries + or lift_mode == LiftMode.FORCE_TEMPORARIES + ): + pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference shape = a.shape as_3d_field = gtx.np_as_located_field(IDim, JDim, KDim) @@ -137,32 +147,21 @@ def tridiag_test(tridiag_reference, program_processor_no_dace_exec, lift_mode): d_s = as_3d_field(d) x_s = as_3d_field(np.zeros_like(x)) - def run(fencil): - run_processor( - fencil, - program_processor, - shape[0], - shape[1], - shape[2], - a_s, - b_s, - c_s, - d_s, - x_s, - offset_provider={}, - column_axis=KDim, - lift_mode=lift_mode, - ) - - yield run + run_processor( + fencil, + program_processor, + shape[0], + shape[1], + shape[2], + a_s, + b_s, + c_s, + d_s, + x_s, + offset_provider={}, + column_axis=KDim, + lift_mode=lift_mode, + ) if validate: assert np.allclose(x, x_s) - - -def test_tridiag(tridiag_test): - tridiag_test(fen_solve_tridiag) - - -def test_tridiag2(tridiag_test): - tridiag_test(fen_solve_tridiag2) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index ded65bceaa..e781014c0c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -415,7 +415,11 @@ def shift_sparse_stencil2(inp): def test_shift_sparse_input_field2(program_processor_no_dace_exec, lift_mode): program_processor, validate = program_processor_no_dace_exec - if program_processor == gtfn_cpu.run_gtfn or program_processor == gtfn_cpu.run_gtfn_imperative: + if program_processor in [ + gtfn_cpu.run_gtfn, + gtfn_cpu.run_gtfn_imperative, + gtfn_cpu.run_gtfn_with_temporaries, + ]: pytest.xfail( "Bug in bindings/compilation/caching: only the first program seems to be compiled." ) # observed in `cache.Strategy.PERSISTENT` mode diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 7e34629073..09d58a4376 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -71,6 +71,7 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str (double_roundtrip.executor, True), (gtfn_cpu.run_gtfn, True), (gtfn_cpu.run_gtfn_imperative, True), + (gtfn_cpu.run_gtfn_with_temporaries, True), (gtfn.format_sourcecode, False), (dace_iterator.run_dace_iterator, True), ], @@ -92,6 +93,7 @@ def program_processor_no_gtfn_exec(program_processor): if ( program_processor[0] == gtfn_cpu.run_gtfn or program_processor[0] == gtfn_cpu.run_gtfn_imperative + or program_processor[0] == gtfn_cpu.run_gtfn_with_temporaries ): pytest.xfail("gtfn backend not yet supported.") return program_processor diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a3c53da5a9..1526e97d74 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -18,6 +18,33 @@ from gt4py.next.iterator import ir, ir_makers as im, type_inference as ti +def test_unsatisfiable_constraints(): + a = ir.Sym(id="a", dtype=("float32", False)) + b = ir.Sym(id="b", dtype=("int32", False)) + + testee = im.lambda_(a, b)(im.plus("a", "b")) + + # The type inference uses a set to store the constraints. Since the TypeVar indices use a + # global counter the constraint resolution order depends on previous runs of the inference. + # To avoid false positives we just ignore which way the constraints have been resolved. + # (The previous description has never been verified.) + expected_error = [ + ( + "Type inference failed: Can not satisfy constraints:\n" + " Primitive(name='int32') ≡ Primitive(name='float32')" + ), + ( + "Type inference failed: Can not satisfy constraints:\n" + " Primitive(name='float32') ≡ Primitive(name='int32')" + ), + ] + + try: + inferred = ti.infer(testee) + except ti.UnsatisfiableConstraintsError as e: + assert str(e) in expected_error + + def test_unsatisfiable_constraints(): a = ir.Sym(id="a", dtype=("float32", False)) b = ir.Sym(id="b", dtype=("int32", False)) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collect_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collect_shifts.py deleted file mode 100644 index 20a784cfc5..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collect_shifts.py +++ /dev/null @@ -1,57 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.collect_shifts import ALL_NEIGHBORS, CollectShifts - - -def test_trivial(): - testee = ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)], - ), - args=[ir.SymRef(id="x")], - ) - ], - ) - expected = {"x": [(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))]} - - actual = dict() - CollectShifts().visit(testee, shifts=actual) - assert actual == expected - - -def test_reduce(): - testee = ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="plus"), ir.Literal(value="0.0", type="float64")], - ), - args=[ - ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="shift"), args=[ir.OffsetLiteral(value="V2E")]), - args=[ir.SymRef(id="x")], - ) - ], - ) - - expected = {"x": [(ir.OffsetLiteral(value="V2E"), ALL_NEIGHBORS)]} - - actual = dict() - CollectShifts().visit(testee, shifts=actual) - assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 16fcd36bff..88f6ed517b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -11,17 +11,18 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import copy import gt4py.next as gtx from gt4py.eve.utils import UIDs -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, ir_makers as im from gt4py.next.iterator.transforms.global_tmps import ( AUTO_DOMAIN, FencilWithTemporaries, Temporary, collect_tmps_info, split_closures, - update_cartesian_domains, + update_domains, ) @@ -88,8 +89,8 @@ def test_split_closures(): ir.Sym(id="d"), ir.Sym(id="inp"), ir.Sym(id="out"), - ir.Sym(id="_gtmp_0"), - ir.Sym(id="_gtmp_1"), + ir.Sym(id="_tmp_1"), + ir.Sym(id="_tmp_2"), ir.Sym(id="_gtmp_auto_domain"), ], closures=[ @@ -102,7 +103,7 @@ def test_split_closures(): args=[ir.SymRef(id="foo_inp")], ), ), - output=ir.SymRef(id="_gtmp_1"), + output=ir.SymRef(id="_tmp_2"), inputs=[ir.SymRef(id="inp")], ), ir.StencilClosure( @@ -110,34 +111,104 @@ def test_split_closures(): stencil=ir.Lambda( params=[ ir.Sym(id="bar_inp"), - ir.Sym(id="_lift_1"), + ir.Sym(id="_tmp_2"), ], expr=ir.FunCall( fun=ir.SymRef(id="deref"), args=[ - ir.SymRef(id="_lift_1"), + ir.SymRef(id="_tmp_2"), ], ), ), - output=ir.SymRef(id="_gtmp_0"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_gtmp_1")], + output=ir.SymRef(id="_tmp_1"), + inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_2")], ), ir.StencilClosure( domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], + params=[ir.Sym(id="baz_inp"), ir.Sym(id="_tmp_1")], expr=ir.FunCall( fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="_lift_2")], + args=[ir.SymRef(id="_tmp_1")], ), ), output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_gtmp_0")], + inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_1")], ), ], ) - actual = split_closures(testee) - assert actual.tmps == [Temporary(id="_gtmp_0"), Temporary(id="_gtmp_1")] + actual = split_closures(testee, offset_provider={}) + assert actual.tmps == [Temporary(id="_tmp_1"), Temporary(id="_tmp_2")] + assert actual.fencil == expected + + +def test_split_closures_lifted_scan(): + UIDs.reset_sequence() + + testee = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + closures=[ + ir.StencilClosure( + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("a")( + im.call( + im.call("scan")( + im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), + True, + im.literal_from_value(0.0), + ) + )( + im.lift( + im.call("scan")( + im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), + False, + im.literal_from_value(0.0), + ) + )("a") + ) + ), + output=im.ref("out"), + inputs=[im.ref("inp")], + ) + ], + ) + + expected = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[im.sym("inp"), im.sym("out"), im.sym("_tmp_1"), im.sym("_gtmp_auto_domain")], + closures=[ + ir.StencilClosure( + domain=AUTO_DOMAIN, + stencil=im.call("scan")( + im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), + False, + im.literal_from_value(0.0), + ), + output=im.ref("_tmp_1"), + inputs=[im.ref("inp")], + ), + ir.StencilClosure( + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + stencil=im.lambda_("a", "_tmp_1")( + im.call( + im.call("scan")( + im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), + True, + im.literal_from_value(0.0), + ) + )("_tmp_1") + ), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_tmp_1")], + ), + ], + ) + + actual = split_closures(testee, offset_provider={}) + assert actual.tmps == [Temporary(id="_tmp_1")] assert actual.fencil == expected @@ -147,241 +218,127 @@ def test_update_cartesian_domains(): id="f", function_definitions=[], params=[ - ir.Sym(id="i"), - ir.Sym(id="j"), - ir.Sym(id="k"), - ir.Sym(id="inp"), - ir.Sym(id="out"), - ir.Sym(id="_gtmp_0"), - ir.Sym(id="_gtmp_1"), - ir.Sym(id="_gtmp_auto_domain"), + im.sym(name) + for name in ("i", "j", "k", "inp", "out", "_gtmp_0", "_gtmp_1", "_gtmp_auto_domain") ], closures=[ ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ), - output=ir.SymRef(id="_gtmp_1"), - inputs=[ir.SymRef(id="inp")], + stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), + output=im.ref("_gtmp_1"), + inputs=[im.ref("inp")], ), ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="_gtmp_0"), - inputs=[ir.SymRef(id="_gtmp_1")], + stencil=im.ref("deref"), + output=im.ref("_gtmp_0"), + inputs=[im.ref("_gtmp_1")], ), ir.StencilClosure( - domain=ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), - args=[ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id=s), - ], + domain=im.call("cartesian_domain")( + *( + im.call("named_range")( + ir.AxisLiteral(value=a), + ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.ref(s), ) for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ], + ) ), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="I"), - ir.OffsetLiteral(value=1), - ], - ), - args=[ir.SymRef(id="_lift_2")], - ) - ], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_gtmp_0")], + stencil=im.lambda_("baz_inp", "_lift_2")(im.deref(im.shift("I", 1)("_lift_2"))), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_gtmp_0")], ), ], ), params=[ - ir.Sym(id="i"), - ir.Sym(id="j"), - ir.Sym(id="k"), - ir.Sym(id="inp"), - ir.Sym(id="out"), + im.sym("i"), + im.sym("j"), + im.sym("k"), + im.sym("inp"), + im.sym("out"), ], tmps=[ Temporary(id="_gtmp_0"), Temporary(id="_gtmp_1"), ], ) - expected = FencilWithTemporaries( - fencil=ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - ir.Sym(id="i"), - ir.Sym(id="j"), - ir.Sym(id="k"), - ir.Sym(id="inp"), - ir.Sym(id="out"), - ir.Sym(id="_gtmp_0"), - ir.Sym(id="_gtmp_1"), - ], - closures=[ - ir.StencilClosure( - domain=ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), - args=[ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.FunCall( - fun=ir.SymRef(id="plus"), - args=[ - ir.SymRef(id="i"), - ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), - ], - ), - ], - ) - ] - + [ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id=s), - ], - ) - for a, s in (("JDim", "j"), ("KDim", "k")) - ], - ), - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ), - output=ir.SymRef(id="_gtmp_1"), - inputs=[ir.SymRef(id="inp")], - ), - ir.StencilClosure( - domain=ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), - args=[ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.FunCall( - fun=ir.SymRef(id="plus"), - args=[ - ir.SymRef(id="i"), - ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), - ], - ), - ], - ) - ] - + [ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id=s), - ], - ) - for a, s in (("JDim", "j"), ("KDim", "k")) - ], + expected = copy.deepcopy(testee) + assert expected.fencil.params.pop() == im.sym("_gtmp_auto_domain") + expected.fencil.closures[0].domain = ir.FunCall( + fun=im.ref("cartesian_domain"), + args=[ + ir.FunCall( + fun=im.ref("named_range"), + args=[ + ir.AxisLiteral(value="IDim"), + im.plus( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal("1", ir.INTEGER_INDEX_BUILTIN), ), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="_gtmp_0"), - inputs=[ir.SymRef(id="_gtmp_1")], - ), - ir.StencilClosure( - domain=ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), - args=[ - ir.FunCall( - fun=ir.SymRef(id="named_range"), - args=[ - ir.AxisLiteral(value=a), - ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id=s), - ], - ) - for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ], + im.plus(im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN)), + ], + ) + ] + + [ + ir.FunCall( + fun=im.ref("named_range"), + args=[ + ir.AxisLiteral(value=a), + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.ref(s), + ], + ) + for a, s in (("JDim", "j"), ("KDim", "k")) + ], + ) + expected.fencil.closures[1].domain = ir.FunCall( + fun=im.ref("cartesian_domain"), + args=[ + ir.FunCall( + fun=im.ref("named_range"), + args=[ + ir.AxisLiteral(value="IDim"), + im.plus( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal("1", ir.INTEGER_INDEX_BUILTIN), ), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="I"), - ir.OffsetLiteral(value=1), - ], - ), - args=[ir.SymRef(id="_lift_2")], - ) - ], - ), + im.plus( + im.ref("i"), + ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_gtmp_0")], - ), - ], - ), - params=[ - ir.Sym(id="i"), - ir.Sym(id="j"), - ir.Sym(id="k"), - ir.Sym(id="inp"), - ir.Sym(id="out"), - ], - tmps=[ - Temporary(id="_gtmp_0"), - Temporary(id="_gtmp_1"), + ], + ) + ] + + [ + ir.FunCall( + fun=im.ref("named_range"), + args=[ + ir.AxisLiteral(value=a), + ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), + im.ref(s), + ], + ) + for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - actual = update_cartesian_domains(testee, {"I": gtx.Dimension("IDim")}) + actual = update_domains(testee, {"I": gtx.Dimension("IDim")}) assert actual == expected def test_collect_tmps_info(): tmp_domain = ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), + fun=im.ref("cartesian_domain"), args=[ ir.FunCall( - fun=ir.SymRef(id="named_range"), + fun=im.ref("named_range"), args=[ ir.AxisLiteral(value="IDim"), ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), ir.FunCall( - fun=ir.SymRef(id="plus"), + fun=im.ref("plus"), args=[ - ir.SymRef(id="i"), + im.ref("i"), ir.Literal(value="1", type=ir.INTEGER_INDEX_BUILTIN), ], ), @@ -390,11 +347,11 @@ def test_collect_tmps_info(): ] + [ ir.FunCall( - fun=ir.SymRef(id="named_range"), + fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id=s), + im.ref(s), ], ) for a, s in (("JDim", "j"), ("KDim", "k")) @@ -408,8 +365,8 @@ def test_collect_tmps_info(): ir.Sym(id="i"), ir.Sym(id="j"), ir.Sym(id="k"), - ir.Sym(id="inp"), - ir.Sym(id="out"), + ir.Sym(id="inp", dtype=("float64", False)), + ir.Sym(id="out", dtype=("float64", False)), ir.Sym(id="_gtmp_0"), ir.Sym(id="_gtmp_1"), ], @@ -419,29 +376,29 @@ def test_collect_tmps_info(): stencil=ir.Lambda( params=[ir.Sym(id="foo_inp")], expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], + fun=im.ref("deref"), + args=[im.ref("foo_inp")], ), ), - output=ir.SymRef(id="_gtmp_1"), - inputs=[ir.SymRef(id="inp")], + output=im.ref("_gtmp_1"), + inputs=[im.ref("inp")], ), ir.StencilClosure( domain=tmp_domain, - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="_gtmp_0"), - inputs=[ir.SymRef(id="_gtmp_1")], + stencil=im.ref("deref"), + output=im.ref("_gtmp_0"), + inputs=[im.ref("_gtmp_1")], ), ir.StencilClosure( domain=ir.FunCall( - fun=ir.SymRef(id="cartesian_domain"), + fun=im.ref("cartesian_domain"), args=[ ir.FunCall( - fun=ir.SymRef(id="named_range"), + fun=im.ref("named_range"), args=[ ir.AxisLiteral(value=a), ir.Literal(value="0", type=ir.INTEGER_INDEX_BUILTIN), - ir.SymRef(id=s), + im.ref(s), ], ) for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) @@ -450,23 +407,23 @@ def test_collect_tmps_info(): stencil=ir.Lambda( params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], expr=ir.FunCall( - fun=ir.SymRef(id="deref"), + fun=im.ref("deref"), args=[ ir.FunCall( fun=ir.FunCall( - fun=ir.SymRef(id="shift"), + fun=im.ref("shift"), args=[ ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1), ], ), - args=[ir.SymRef(id="_lift_2")], + args=[im.ref("_lift_2")], ) ], ), ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_gtmp_0")], + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_gtmp_0")], ), ], ), @@ -486,8 +443,8 @@ def test_collect_tmps_info(): fencil=testee.fencil, params=testee.params, tmps=[ - Temporary(id="_gtmp_0", domain=tmp_domain, dtype=3), - Temporary(id="_gtmp_1", domain=tmp_domain, dtype=3), + Temporary(id="_gtmp_0", domain=tmp_domain, dtype="float64"), + Temporary(id="_gtmp_1", domain=tmp_domain, dtype="float64"), ], ) actual = collect_tmps_info(testee, offset_provider={}) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_popup_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_popup_tmps.py deleted file mode 100644 index 774861d061..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_popup_tmps.py +++ /dev/null @@ -1,318 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import pytest - -from gt4py.eve.utils import UIDs -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.popup_tmps import PopupTmps - - -@pytest.fixture -def fresh_uid_sequence(): - UIDs.reset_sequence() - - -def test_trivial_single_lift(fresh_uid_sequence): - testee = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="bar_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ) - ], - ), - args=[ir.SymRef(id="bar_inp")], - ) - ], - ), - ), - args=[ir.SymRef(id="inp")], - ) - expected = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="bar_inp"), ir.Sym(id="_lift_1")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="_lift_1")], - ), - ), - args=[ - ir.SymRef(id="inp"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ) - ], - ), - args=[ir.SymRef(id="inp")], - ), - ], - ) - actual = PopupTmps().visit(testee) - assert actual == expected - - -def test_trivial_multiple_lifts(fresh_uid_sequence): - testee = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="baz_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="bar_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ) - ], - ), - args=[ir.SymRef(id="bar_inp")], - ) - ], - ), - ) - ], - ), - args=[ir.SymRef(id="baz_inp")], - ) - ], - ), - ), - args=[ir.SymRef(id="inp")], - ) - expected = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="_lift_2")], - ), - ), - args=[ - ir.SymRef(id="inp"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ - ir.Sym(id="bar_inp"), - ir.Sym(id="_lift_1"), - ], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.SymRef(id="_lift_1"), - ], - ), - ) - ], - ), - args=[ - ir.SymRef(id="inp"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ) - ], - ), - args=[ir.SymRef(id="inp")], - ), - ], - ), - ], - ) - actual = PopupTmps().visit(testee) - assert actual == expected - - -def test_capture(fresh_uid_sequence): - testee = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="x")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="x")], - ), - ) - ], - ), - args=[], - ) - ], - ), - ), - args=[ir.SymRef(id="inp")], - ) - actual = PopupTmps().visit(testee) - assert actual == testee - - -def test_crossref(): - testee = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="x")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="x1")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="x1")]), - ), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="x2")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="x2")], - ), - ) - ], - ), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="x3")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="x3")], - ), - ) - ], - ), - args=[ir.SymRef(id="x")], - ) - ], - ) - ], - ), - ), - args=[ir.SymRef(id="x")], - ) - expected = ir.FunCall( - fun=ir.Lambda( - params=[ - ir.Sym(id="x"), - ir.Sym(id="_lift_1"), - ir.Sym(id="_lift_2"), - ], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="x1")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="x1")]), - ), - args=[ir.SymRef(id="_lift_2")], - ), - ), - args=[ - ir.SymRef(id="x"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="x3")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="x3")]), - ) - ], - ), - args=[ir.SymRef(id="x")], - ), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="x2")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="x2")]), - ) - ], - ), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="x3")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="x3")], - ), - ) - ], - ), - args=[ir.SymRef(id="x")], - ) - ], - ), - ], - ) - actual = PopupTmps().visit(testee) - assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 0e2fa22f05..2624a17ebd 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.next.iterator import ir, ir_makers as im -from gt4py.next.iterator.transforms.trace_shifts import ALL_NEIGHBORS, TraceShifts +from gt4py.next.iterator.transforms.trace_shifts import Sentinel, TraceShifts def test_trivial(): @@ -105,7 +105,7 @@ def test_neighbors(): "inp": { ( ir.OffsetLiteral(value="O"), - ALL_NEIGHBORS, + Sentinel.ALL_NEIGHBORS, ) } } From 34516a65da9d5d0de930e2efc082a436e9da2485 Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Tue, 19 Sep 2023 09:09:34 +0200 Subject: [PATCH 14/67] test[cartesian]: update dependencies and fix use of hypothesis decorators This PR tried to fix an error found in the Daily CI task after updating to hypothesis 6.82.1. The breaking change was fixed a couple of days later directly in in hypothesis, but the changes in this PR are likely to improve the quality of the code anyway. --- .pre-commit-config.yaml | 22 +++--- constraints.txt | 78 +++++++++---------- requirements-dev.txt | 78 +++++++++---------- src/gt4py/eve/extended_typing.py | 5 +- .../feature_tests/test_exec_info.py | 4 +- 5 files changed, 95 insertions(+), 92 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d70f335bef..b1092fafd0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.7.0' # version from constraints.txt + rev: '23.9.1' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -97,7 +97,7 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.7.10 + - flake8-bugbear==23.9.16 - flake8-builtins==2.1.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.5.0 ========= + #========= FROM constraints.txt: v1.5.1 ========= ##[[[end]]] - rev: v1.5.0 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -162,26 +162,26 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.1.0 - - black==23.7.0 + - black==23.9.1 - boltons==23.0.0 - cached-property==1.5.2 - - click==8.1.6 - - cmake==3.27.2 + - click==8.1.7 + - cmake==3.27.5 - cytoolz==0.12.2 - - deepdiff==6.3.1 - - devtools==0.11.0 + - deepdiff==6.5.0 + - devtools==0.12.2 - frozendict==2.3.8 - gridtools-cpp==2.3.1 - importlib-resources==6.0.1 - jinja2==3.1.2 - lark==1.1.7 - mako==1.2.4 - - nanobind==1.5.0 + - nanobind==1.5.2 - ninja==1.11.1 - numpy==1.24.4 - packaging==23.1 - pybind11==2.11.1 - - setuptools==68.1.0 + - setuptools==68.2.2 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index 35e3d9e330..b334851af1 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,14 +6,14 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.2.1 # via devtools +asttokens==2.4.0 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.12.1 # via sphinx -black==23.7.0 # via gt4py (pyproject.toml) +black==23.9.1 # via gt4py (pyproject.toml) blinker==1.6.2 # via flask boltons==23.0.0 # via gt4py (pyproject.toml) -build==0.10.0 # via pip-tools +build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.1 # via tox certifi==2023.7.22 # via requests @@ -22,17 +22,17 @@ cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.2 # via gt4py (pyproject.toml) +click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools +cmake==3.27.5 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.0 # via -r requirements-dev.in, pytest-cov +coverage==7.3.1 # via -r requirements-dev.in, pytest-cov cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.3.1 # via gt4py (pyproject.toml) -devtools==0.11.0 # via gt4py (pyproject.toml) +deepdiff==6.5.0 # via gt4py (pyproject.toml) +devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme @@ -41,11 +41,11 @@ exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.3.0 # via factory-boy +faker==19.6.1 # via factory-boy fastjsonschema==2.18.0 # via nbformat -filelock==3.12.2 # via tox, virtualenv +filelock==3.12.4 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.7.10 # via -r requirements-dev.in +flake8-bugbear==23.9.16 # via -r requirements-dev.in flake8-builtins==2.1.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -53,14 +53,14 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.2 # via dace +flask==2.3.3 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.26 # via pre-commit +hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.29 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via flask, sphinx +importlib-metadata==6.8.0 # via build, flask, sphinx importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest @@ -70,7 +70,7 @@ jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat -jupytext==1.15.0 # via -r requirements-dev.in +jupytext==1.15.2 # via -r requirements-dev.in lark==1.1.7 # via gt4py (pyproject.toml) mako==1.2.4 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins @@ -79,9 +79,9 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.0 # via -r requirements-dev.in +mypy==1.5.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.0 # via gt4py (pyproject.toml) +nanobind==1.5.2 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) @@ -94,36 +94,36 @@ pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv -pluggy==1.2.0 # via pytest, tox +pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.3.3 # via -r requirements-dev.in +pre-commit==3.4.0 # via -r requirements-dev.in psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.0 # via flake8, flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx -pyproject-api==1.5.3 # via tox +pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.5.1 # via -r requirements-dev.in pytest-xdist==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker -pytz==2023.3 # via babel +pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.284 # via -r requirements-dev.in +rpds-py==0.10.3 # via jsonschema, referencing +ruff==0.0.290 # via -r requirements-dev.in six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis -sphinx==6.2.1 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.2.2 # via -r requirements-dev.in +sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -136,8 +136,8 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.9.0 # via -r requirements-dev.in -traitlets==5.9.0 # via jupyter-core, nbformat +tox==4.11.3 # via -r requirements-dev.in +traitlets==5.10.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all @@ -182,14 +182,14 @@ types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.1 # via types-all +types-mock==5.1.0.2 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all types-paramiko==3.3.0.0 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.2 # via types-all +types-pillow==10.0.0.3 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all types-protobuf==4.24.0.1 # via types-all @@ -205,17 +205,17 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.1 # via types-all, types-tzlocal +types-pytz==2023.3.1.0 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.4 # via types-all +types-redis==4.6.0.6 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.1.0.0 # via types-cffi +types-setuptools==68.2.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all -types-singledispatch==4.0.0.2 # via types-all +types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all types-tabulate==0.9.0.3 # via types-all types-termcolor==1.1.6.2 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.3 # via pre-commit, tox +virtualenv==20.24.5 # via pre-commit, tox websockets==11.0.3 # via dace werkzeug==2.3.7 # via flask -wheel==0.41.1 # via astunparse, pip-tools +wheel==0.41.2 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools +setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/requirements-dev.txt b/requirements-dev.txt index a167b2979a..d6dcc12d21 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,14 +6,14 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.2.1 # via devtools +asttokens==2.4.0 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.12.1 # via sphinx -black==23.7.0 # via gt4py (pyproject.toml) +black==23.9.1 # via gt4py (pyproject.toml) blinker==1.6.2 # via flask boltons==23.0.0 # via gt4py (pyproject.toml) -build==0.10.0 # via pip-tools +build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.1 # via tox certifi==2023.7.22 # via requests @@ -22,17 +22,17 @@ cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.2 # via gt4py (pyproject.toml) +click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools +cmake==3.27.5 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.0 # via -r requirements-dev.in, pytest-cov +coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.3.1 # via gt4py (pyproject.toml) -devtools==0.11.0 # via gt4py (pyproject.toml) +deepdiff==6.5.0 # via gt4py (pyproject.toml) +devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme @@ -41,11 +41,11 @@ exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.3.0 # via factory-boy +faker==19.6.1 # via factory-boy fastjsonschema==2.18.0 # via nbformat -filelock==3.12.2 # via tox, virtualenv +filelock==3.12.4 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.7.10 # via -r requirements-dev.in +flake8-bugbear==23.9.16 # via -r requirements-dev.in flake8-builtins==2.1.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -53,14 +53,14 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.2 # via dace +flask==2.3.3 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.26 # via pre-commit +hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.29 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via flask, sphinx +importlib-metadata==6.8.0 # via build, flask, sphinx importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest @@ -70,7 +70,7 @@ jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat -jupytext==1.15.0 # via -r requirements-dev.in +jupytext==1.15.2 # via -r requirements-dev.in lark==1.1.7 # via gt4py (pyproject.toml) mako==1.2.4 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins @@ -79,9 +79,9 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.0 # via -r requirements-dev.in +mypy==1.5.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.0 # via gt4py (pyproject.toml) +nanobind==1.5.2 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) @@ -94,36 +94,36 @@ pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv -pluggy==1.2.0 # via pytest, tox +pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.3.3 # via -r requirements-dev.in +pre-commit==3.4.0 # via -r requirements-dev.in psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.0 # via flake8, flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx -pyproject-api==1.5.3 # via tox +pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.5.1 # via -r requirements-dev.in pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker -pytz==2023.3 # via babel +pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.284 # via -r requirements-dev.in +rpds-py==0.10.3 # via jsonschema, referencing +ruff==0.0.290 # via -r requirements-dev.in six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis -sphinx==6.2.1 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.2.2 # via -r requirements-dev.in +sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -136,8 +136,8 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.9.0 # via -r requirements-dev.in -traitlets==5.9.0 # via jupyter-core, nbformat +tox==4.11.3 # via -r requirements-dev.in +traitlets==5.10.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all @@ -182,14 +182,14 @@ types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.1 # via types-all +types-mock==5.1.0.2 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all types-paramiko==3.3.0.0 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.2 # via types-all +types-pillow==10.0.0.3 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all types-protobuf==4.24.0.1 # via types-all @@ -205,17 +205,17 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.1 # via types-all, types-tzlocal +types-pytz==2023.3.1.0 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.4 # via types-all +types-redis==4.6.0.6 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.1.0.0 # via types-cffi +types-setuptools==68.2.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all -types-singledispatch==4.0.0.2 # via types-all +types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all types-tabulate==0.9.0.3 # via types-all types-termcolor==1.1.6.2 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.3 # via pre-commit, tox +virtualenv==20.24.5 # via pre-commit, tox websockets==11.0.3 # via dace werkzeug==2.3.7 # via flask -wheel==0.41.1 # via astunparse, pip-tools +wheel==0.41.2 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools +setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 34829317d6..3b8373ade1 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -552,11 +552,14 @@ def is_value_hashable_typing( return type_annotation is None -def is_protocol(type_: Type) -> bool: +def _is_protocol(type_: type, /) -> bool: """Check if a type is a Protocol definition.""" return getattr(type_, "_is_protocol", False) +is_protocol = getattr(_typing_extensions, "is_protocol", _is_protocol) + + def get_partial_type_hints( obj: Union[ object, diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py index 2934e48e7a..6b8c02e41c 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py @@ -194,8 +194,8 @@ def subtest_stencil_info(self, exec_info, stencil_info, last_called_stencil=Fals else: assert stencil_info["total_run_cpp_time"] > stencil_info["run_cpp_time"] - @given(data=hyp_st.data()) @pytest.mark.parametrize("backend", ALL_BACKENDS) + @given(data=hyp_st.data()) def test_backcompatibility(self, data, backend, worker_id): # set backend as instance attribute self.backend = backend @@ -237,8 +237,8 @@ def test_backcompatibility(self, data, backend, worker_id): assert type(self.advection).__name__ not in exec_info assert type(self.diffusion).__name__ not in exec_info - @given(data=hyp_st.data()) @pytest.mark.parametrize("backend", ALL_BACKENDS) + @given(data=hyp_st.data()) def test_aggregate(self, data, backend, worker_id): # set backend as instance attribute self.backend = backend From ac6bf945d8b6e7677e3a247339b7698efc8806bd Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Sep 2023 10:47:44 +0200 Subject: [PATCH 15/67] feat[next]: extend DaCe support of reduction operator (#1332) Adding generic implementation of neighbor-reduction to DaCe backend based on map with Write-Conflict Resolution (WCR) on output memlet. This PR enables use of lambdas as reduction function. --- .../runners/dace_iterator/itir_to_sdfg.py | 73 +---- .../runners/dace_iterator/itir_to_tasklet.py | 310 +++++++++++++++--- .../runners/dace_iterator/utility.py | 77 ++++- .../ffront_tests/test_external_local_field.py | 3 - .../ffront_tests/test_gt4py_builtins.py | 7 +- .../test_with_toy_connectivity.py | 32 +- 6 files changed, 351 insertions(+), 151 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 4f93777215..56031d8555 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -32,6 +32,7 @@ is_scan, ) from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, create_memlet_at, @@ -321,7 +322,7 @@ def visit_StencilClosure( array_mapping = {**input_mapping, **conn_mapping} symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) - nsdfg_node, map_entry, map_exit = self._add_mapped_nested_sdfg( + nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, map_ranges=map_domain or {"__dummy": "0"}, @@ -584,76 +585,6 @@ def _visit_parallel_stencil_closure( return context.body, map_domain, [r.value.data for r in results] - def _add_mapped_nested_sdfg( - self, - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] - | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, - ) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: - if not symbol_mapping: - symbol_mapping = {sym: sym for sym in sdfg.free_symbols} - - nsdfg_node = state.add_nested_sdfg( - sdfg, - None, - set(inputs.keys()), - set(outputs.keys()), - symbol_mapping, - name=sdfg.name, - schedule=schedule, - location=location, - debuginfo=debuginfo, - ) - - map_entry, map_exit = state.add_map( - f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo - ) - - if input_nodes is None: - input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() - } - if output_nodes is None: - output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() - } - if not inputs: - state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) - for name, memlet in inputs.items(): - state.add_memlet_path( - input_nodes[memlet.data], - map_entry, - nsdfg_node, - memlet=memlet, - src_conn=None, - dst_conn=name, - propagate=True, - ) - if not outputs: - state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) - for name, memlet in outputs.items(): - state.add_memlet_path( - nsdfg_node, - map_exit, - output_nodes[memlet.data], - memlet=memlet, - src_conn=name, - dst_conn=None, - propagate=True, - ) - - return nsdfg_node, map_entry, map_exit - def _visit_domain( self, node: itir.FunCall, context: Context ) -> tuple[tuple[str, tuple[ValueExpr, ValueExpr]], ...]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d301c3e3cf..2e7a598d9a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,6 +19,8 @@ import dace import numpy as np +from dace.transformation.dataflow import MapFusion +from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen from gt4py.next import Dimension, type_inference as next_typing @@ -29,12 +31,14 @@ from gt4py.next.type_system import type_specifications as ts from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, + unique_name, unique_var_name, ) @@ -56,6 +60,21 @@ def itir_type_as_dace_type(type_: next_typing.Type): raise NotImplementedError() +def get_reduce_identity_value(op_name_: str, type_: Any): + if op_name_ == "plus": + init_value = type_(0) + elif op_name_ == "multiplies": + init_value = type_(1) + elif op_name_ == "minimum": + init_value = type_("inf") + elif op_name_ == "maximum": + init_value = type_("-inf") + else: + raise NotImplementedError() + + return init_value + + _MATH_BUILTINS_MAPPING = { "abs": "abs({})", "sin": "math.sin({})", @@ -136,6 +155,21 @@ class Context: body: dace.SDFG state: dace.SDFGState symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] + # if we encounter a reduction node, the reduction state needs to be pushed to child nodes + reduce_limit: int + reduce_wcr: Optional[str] + + def __init__( + self, + body: dace.SDFG, + state: dace.SDFGState, + symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr], + ): + self.body = body + self.state = state + self.symbol_map = symbol_map + self.reduce_limit = 0 + self.reduce_wcr = None def builtin_neighbors( @@ -167,13 +201,15 @@ def builtin_neighbors( table_name = connectivity_identifier(offset_dim) table_array = sdfg.arrays[table_name] + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__neigh_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", - ndrange={"neigh_idx": f"0:{table.max_neighbors}"}, + ndrange={index_name: f"0:{table.max_neighbors}"}, ) shift_tasklet = state.add_tasklet( "shift", - code="__result = __table[__idx, neigh_idx]", + code=f"__result = __table[__idx, {index_name}]", inputs={"__table", "__idx"}, outputs={"__result"}, ) @@ -227,7 +263,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset="neigh_idx"), + memlet=dace.Memlet(data=result_name, subset=index_name), src_conn="__result", ) @@ -349,6 +385,8 @@ def visit_Lambda( value = IteratorExpr(field, indices, arg.dtype, arg.dimensions) symbol_map[param] = value context = Context(context_sdfg, context_state, symbol_map) + context.reduce_limit = prev_context.reduce_limit + context.reduce_wcr = prev_context.reduce_wcr self.context = context # Add input parameters as arrays @@ -395,7 +433,12 @@ def visit_Lambda( self.context.body.add_scalar(result_name, result.dtype, transient=True) result_access = self.context.state.add_access(result_name) self.context.state.add_edge( - result.value, None, result_access, None, dace.Memlet(f"{result.value.data}[0]") + result.value, + None, + result_access, + None, + # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution + dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), ) result = ValueExpr(value=result_access, dtype=result.dtype) else: @@ -531,15 +574,71 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr return iterator - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) - flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions - ] - args: list[ValueExpr] = [ValueExpr(iterator.field, iterator.dtype), *flat_index] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + args: list[ValueExpr] + if self.context.reduce_limit: + # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing + result_name = unique_var_name() + self.context.body.add_array( + result_name, + dtype=iterator.dtype, + shape=(self.context.reduce_limit,), + transient=True, + ) + result_access = self.context.state.add_access(result_name) + + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__deref_idx") + me, mx = self.context.state.add_map( + "deref_map", + ndrange={index_name: f"0:{self.context.reduce_limit}"}, + ) + + # if dim is not found in iterator indices, we take the neighbor index over the reduction domain + array_index = [ + f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name + for dim in sorted(iterator.dimensions) + ] + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices + ] + internals = [f"{arg.value.data}_v" for arg in args] + + deref_tasklet = self.context.state.add_tasklet( + name="deref", + inputs=set(internals), + outputs={"__result"}, + code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", + ) + + for arg, internal in zip(args, internals): + input_memlet = create_memlet_full( + arg.value.data, self.context.body.arrays[arg.value.data] + ) + self.context.state.add_memlet_path( + arg.value, me, deref_tasklet, memlet=input_memlet, dst_conn=internal + ) + + self.context.state.add_memlet_path( + deref_tasklet, + mx, + result_access, + memlet=dace.Memlet(data=result_name, subset=index_name), + src_conn="__result", + ) + + return [ValueExpr(value=result_access, dtype=iterator.dtype)] + + else: + sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) + flat_index = [ + ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions + ] + + args = [ValueExpr(iterator.field, int), *flat_index] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{', '.join(internals[1:])}]" + return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") def _split_shift_args( self, args: list[itir.Expr] @@ -626,47 +725,156 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) def _visit_reduce(self, node: itir.FunCall): - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - args = self.visit(node.args) - assert len(args) == 1 - args = args[0] - assert len(args) == 1 - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - init = node.fun.args[1] - - nreduce = self.context.body.arrays[args[0].value.data].shape[0] - result_name = unique_var_name() result_access = self.context.state.add_access(result_name) - self.context.body.add_scalar(result_name, args[0].dtype, transient=True) - op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") - reduce_tasklet = self.context.state.add_tasklet( - "reduce", - code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", - inputs={"__values"}, - outputs={"__result"}, - ) - self.context.state.add_edge( - args[0].value, - None, - reduce_tasklet, - "__values", - dace.Memlet(data=args[0].value.data, subset=f"0:{nreduce}"), - ) - self.context.state.add_edge( - reduce_tasklet, - "__result", - result_access, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_access, args[0].dtype)] + + if len(node.args) == 1: + assert ( + isinstance(node.args[0], itir.FunCall) + and isinstance(node.args[0].fun, itir.SymRef) + and node.args[0].fun.id == "neighbors" + ) + args = self.visit(node.args) + assert len(args) == 1 + args = args[0] + assert len(args) == 1 + neighbors_expr = args[0] + result_dtype = neighbors_expr.dtype + assert isinstance(node.fun, itir.FunCall) + op_name = node.fun.args[0] + assert isinstance(op_name, itir.SymRef) + init = node.fun.args[1] + + nreduce = self.context.body.arrays[neighbors_expr.value.data].shape[0] + + self.context.body.add_scalar(result_name, result_dtype, transient=True) + op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") + reduce_tasklet = self.context.state.add_tasklet( + "reduce", + code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", + inputs={"__values"}, + outputs={"__result"}, + ) + self.context.state.add_edge( + args[0].value, + None, + reduce_tasklet, + "__values", + dace.Memlet(data=neighbors_expr.value.data, subset=f"0:{nreduce}"), + ) + self.context.state.add_edge( + reduce_tasklet, + "__result", + result_access, + None, + dace.Memlet(data=result_name, subset="0"), + ) + else: + assert isinstance(node.fun, itir.FunCall) + assert isinstance(node.fun.args[0], itir.Lambda) + fun_node = node.fun.args[0] + + args = [] + for node_arg in node.args: + if ( + isinstance(node_arg, itir.FunCall) + and isinstance(node_arg.fun, itir.SymRef) + and node_arg.fun.id == "neighbors" + ): + expr = self.visit(node_arg) + args.append(*expr) + else: + args.append(None) + + # first visit only arguments for neighbor selection, all other arguments are none + neighbor_args = [arg for arg in args if arg] + + # check that all neighbors expression have the same range + assert ( + len( + set([self.context.body.arrays[expr.value.data].shape for expr in neighbor_args]) + ) + == 1 + ) + + nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0] + nreduce_domain = {"__idx": f"0:{nreduce}"} + + result_dtype = neighbor_args[0].dtype + self.context.body.add_scalar(result_name, result_dtype, transient=True) + + assert isinstance(fun_node.expr, itir.FunCall) + op_name = fun_node.expr.fun + assert isinstance(op_name, itir.SymRef) + + # initialize the reduction result based on type of operation + init_value = get_reduce_identity_value(op_name.id, result_dtype) + init_state = self.context.body.add_state_before(self.context.state, "init") + init_tasklet = init_state.add_tasklet( + "init_reduce", {}, {"__out"}, f"__out = {init_value}" + ) + init_state.add_edge( + init_tasklet, + "__out", + init_state.add_access(result_name), + None, + dace.Memlet.simple(result_name, "0"), + ) + + # set reduction state to enable dereference of neighbors in input fields and to set WCR on reduce tasklet + self.context.reduce_limit = nreduce + self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format( + "x", "y" + ) + + # visit child nodes for input arguments + for i, node_arg in enumerate(node.args): + if not args[i]: + args[i] = self.visit(node_arg)[0] + + lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) + + # clear context + self.context.reduce_limit = 0 + self.context.reduce_wcr = None + + # the connectivity arrays (neighbor tables) are not needed inside the reduce lambda SDFG + neighbor_tables = filter_neighbor_tables(self.offset_provider) + for conn, _ in neighbor_tables: + var = connectivity_identifier(conn) + lambda_context.body.remove_data(var) + # cleanup symbols previously used for shape and stride of connectivity arrays + p = RemoveUnusedSymbols() + p.apply_pass(lambda_context.body, {}) + + input_memlets = [ + create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) + ] + output_memlet = dace.Memlet.simple(result_name, "0") + + input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)} + output_mapping = {inner_outputs[0].value.data: output_memlet} + symbol_mapping = map_nested_sdfg_symbols( + self.context.body, lambda_context.body, input_mapping + ) + + nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( + self.context.state, + sdfg=lambda_context.body, + map_ranges=nreduce_domain, + inputs=input_mapping, + outputs=output_mapping, + symbol_mapping=symbol_mapping, + input_nodes={arg.value.data: arg.value for arg in args}, + output_nodes={result_name: result_access}, + ) + + # we apply map fusion only to the nested-SDFG which is generated for the reduction operator + # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG + self.context.body.apply_transformations_repeated([MapFusion], validate=False) + + return [ValueExpr(result_access, result_dtype)] def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 85b1445dd9..889a1ab150 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -81,10 +81,83 @@ def map_nested_sdfg_symbols( return symbol_mapping +def add_mapped_nested_sdfg( + state: dace.SDFGState, + map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], + inputs: dict[str, dace.Memlet], + outputs: dict[str, dace.Memlet], + sdfg: dace.SDFG, + symbol_mapping: dict[str, Any] | None = None, + schedule: Any = dace.dtypes.ScheduleType.Default, + unroll_map: bool = False, + location: Any = None, + debuginfo: Any = None, + input_nodes: dict[str, dace.nodes.AccessNode] | None = None, + output_nodes: dict[str, dace.nodes.AccessNode] | None = None, +) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: + if not symbol_mapping: + symbol_mapping = {sym: sym for sym in sdfg.free_symbols} + + nsdfg_node = state.add_nested_sdfg( + sdfg, + None, + set(inputs.keys()), + set(outputs.keys()), + symbol_mapping, + name=sdfg.name, + schedule=schedule, + location=location, + debuginfo=debuginfo, + ) + + map_entry, map_exit = state.add_map( + f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo + ) + + if input_nodes is None: + input_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + } + if output_nodes is None: + output_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + } + if not inputs: + state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) + for name, memlet in inputs.items(): + state.add_memlet_path( + input_nodes[memlet.data], + map_entry, + nsdfg_node, + memlet=memlet, + src_conn=None, + dst_conn=name, + propagate=True, + ) + if not outputs: + state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) + for name, memlet in outputs.items(): + state.add_memlet_path( + nsdfg_node, + map_exit, + output_nodes[memlet.data], + memlet=memlet, + src_conn=name, + dst_conn=None, + propagate=True, + ) + + return nsdfg_node, map_entry, map_exit + + _unique_id = 0 -def unique_var_name(): +def unique_name(prefix): global _unique_id _unique_id += 1 - return f"__var_{_unique_id}" + return f"{prefix}_{_unique_id}" + + +def unique_var_name(): + return unique_name("__var") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index f2c8525346..7f2b11afff 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -28,9 +28,6 @@ def test_external_local_field(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over non-field expressions") - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 7acc0e1447..ee88b3764e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -101,9 +101,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): def test_reduction_expression_in_call(unstructured_case): if unstructured_case.backend == dace_iterator.run_dace_iterator: - # -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir, - # so in addition to the skipped reason, currently itir is a lambda instead of the 'plus' operation - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") + pytest.xfail("Not supported in DaCe backend: make_const_list") @gtx.field_operator def reduce_expr(edge_f: cases.EField) -> cases.VField: @@ -124,9 +122,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): def test_reduction_with_common_expression(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") - @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index e781014c0c..ee07372731 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -93,8 +93,8 @@ def sum_edges_to_vertices_reduce(in_edges): "stencil", [sum_edges_to_vertices, sum_edges_to_vertices_list_get_neighbors, sum_edges_to_vertices_reduce], ) -def test_sum_edges_to_vertices(program_processor_no_dace_exec, lift_mode, stencil): - program_processor, validate = program_processor_no_dace_exec +def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(sum(row) for row in v2e_arr)) @@ -116,10 +116,8 @@ def map_neighbors(in_edges): return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) -def test_map_neighbors(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: map_ builtin, neighbors, reduce") +def test_map_neighbors(program_processor, lift_mode): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -144,9 +142,7 @@ def map_make_const_list(in_edges): def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): program_processor, validate = program_processor_no_gtfn_exec if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: map_ builtin, neighbors, reduce, make_const_list" - ) + pytest.xfail("Not supported in DaCe backend: make_const_list") inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -194,10 +190,10 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) -def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2EDim)(np.asarray([[1, 2, 3, 4]] * 9, dtype=np.int32)) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -217,10 +213,10 @@ def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): assert np.allclose(out, ref) -def test_sparse_input_field_v2v(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field_v2v(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -276,10 +272,10 @@ def slice_twice_sparse_stencil(sparse): @pytest.mark.xfail(reason="Field with more than one sparse dimension is not implemented.") -def test_slice_twice_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_slice_twice_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim, V2VDim)(v2v_arr[v2v_arr]) - out = gtx.np_as_located_field(Vertex)(np.zeros([9])) + out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = v2v_arr[v2v_arr][:, 2, 1] run_processor( From d03ef4f15c63f7abb388c646c0eca33e1749c352 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 26 Sep 2023 12:52:56 +0200 Subject: [PATCH 16/67] test[next]: check for DaCe dependency in test execution (#1336) Expanding the pytest fixture for unit tests with markers to exclude tests based on feature support in the selected backend. In addition, a check is added to the DaCe backend so that tests are skipped if dace module is not installed. This is required for Spack build of icon4py, which uses the base installation of gt4py, where dace module is optional. --- .../ADRs/0015-Test_Exclusion_Matrices.md | 80 ++++++++++++++ docs/development/ADRs/Index.md | 4 +- pyproject.toml | 22 +++- tests/next_tests/exclusion_matrices.py | 89 +++++++++++++++ .../ffront_tests/ffront_test_utils.py | 31 +++++- .../ffront_tests/test_arg_call_interface.py | 11 +- .../ffront_tests/test_execution.py | 86 ++++----------- .../ffront_tests/test_external_local_field.py | 1 - .../ffront_tests/test_gt4py_builtins.py | 15 +-- .../test_math_builtin_execution.py | 1 - .../ffront_tests/test_math_unary_builtins.py | 13 +-- .../ffront_tests/test_program.py | 13 +-- .../ffront_tests/test_scalar_if.py | 102 +++--------------- .../iterator_tests/test_builtins.py | 12 +-- .../iterator_tests/test_conditional.py | 8 +- .../test_horizontal_indirection.py | 13 +-- .../iterator_tests/test_implicit_fencil.py | 5 - .../feature_tests/iterator_tests/test_scan.py | 4 +- .../test_strided_offset_provider.py | 10 +- .../iterator_tests/test_trivial.py | 1 - .../iterator_tests/test_tuple.py | 43 +++----- .../ffront_tests/test_icon_like_scan.py | 14 +-- .../iterator_tests/test_anton_toy.py | 4 +- .../iterator_tests/test_column_stencil.py | 26 +---- .../iterator_tests/test_fvm_nabla.py | 41 +------ .../iterator_tests/test_hdiff.py | 12 +-- .../iterator_tests/test_vertical_advection.py | 12 +-- .../test_with_toy_connectivity.py | 37 +++---- tests/next_tests/unit_tests/conftest.py | 39 ++++--- 29 files changed, 346 insertions(+), 403 deletions(-) create mode 100644 docs/development/ADRs/0015-Test_Exclusion_Matrices.md create mode 100644 tests/next_tests/exclusion_matrices.py diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md new file mode 100644 index 0000000000..920504db9a --- /dev/null +++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md @@ -0,0 +1,80 @@ +--- +tags: [] +--- + +# Test-Exclusion Matrices + +- **Status**: valid +- **Authors**: Edoardo Paone (@edopao), Enrique G. Paredes (@egparedes) +- **Created**: 2023-09-21 +- **Updated**: 2023-09-21 + +In the context of Field View testing, lacking support for specific ITIR features while a certain backend +is being developed, we decided to use `pytest` fixtures to exclude unsupported tests. + +## Context + +It should be possible to run Field View tests on different backends. However, specific tests could be unsupported +on a certain backend, or the backend implementation could be only partially ready. +Therefore, we need a mechanism to specify the features required by each test and selectively enable +the supported backends, while keeping the test code clean. + +## Decision + +It was decided to apply fixtures and markers from `pytest` module. The fixture is the same used to execute the test +on different backends (`fieldview_backend` and `program_processor`), but it is extended with a check on the available feature markers. +If a test is annotated with a feature marker, the fixture will check if this feature is supported on the selected backend. +If no marker is specified, the test is supposed to run on all backends. + +In the example below, `test_offset_field` requires the backend to support dynamic offsets in the translation from ITIR: + +```python +@pytest.mark.uses_dynamic_offsets +def test_offset_field(cartesian_case): +``` + +In order to selectively enable the backends, the dictionary `next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX` +lists for each backend the features that are not supported. +The fixture will check if the annotated feature is present in the exclusion-matrix for the selected backend. +If so, the exclusion matrix will also specify the action `pytest` should take (e.g. `SKIP` or `XFAIL`). + +The test-exclusion matrix is a dictionary, where `key` is the backend name and each entry is a tuple with the following fields: + +`(, , )` + +The backend string, used both as dictionary key and as string formatter in the skip message, is retrieved +by calling `tests.next_tests.get_processor_id()`, which returns the so-called processor name. +The following backend processors are defined: + +```python +DACE = "dace_iterator.run_dace_iterator" +GTFN_CPU = "otf_compile_executor.run_gtfn" +GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" +GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" +``` + +Following the previous example, the GTFN backend with temporaries does not support yet dynamic offsets in ITIR: + +```python +BACKEND_SKIP_TEST_MATRIX = { + GTFN_CPU_WITH_TEMPORARIES: [ + ("uses_dynamic_offsets", pytest.XFAIL, "'{marker}' tests not supported by '{backend}' backend"), + ] +} +``` + +## Consequences + +Positive outcomes of this decision: + +- The solution provides a central place to specify test exclusion. +- The test code remains clean from if-statements for backend exclusion. +- The exclusion matrix gives an overview of the feature-readiness of different backends. + +Negative outcomes: + +- There is not (yet) any code-style check to enforce this solution, so code reviews should be aware of the ADR. + +## References + +- [pytest - Using markers to pass data to fixtures](https://docs.pytest.org/en/6.2.x/fixture.html#using-markers-to-pass-data-to-fixtures) diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/Index.md index 1bbfd62d81..09d2273ee9 100644 --- a/docs/development/ADRs/Index.md +++ b/docs/development/ADRs/Index.md @@ -51,9 +51,9 @@ _None_ - [0011 - On The Fly Compilation](0011-On_The_Fly_Compilation.md) - [0012 - GridTools C++ OTF](0011-_GridTools_Cpp_OTF.md) -### Miscellanea +### Testing -_None_ +- [0015 - Exclusion Matrices](0015-Test_Exclusion_Matrices.md) ### Superseded diff --git a/pyproject.toml b/pyproject.toml index e915622857..e2d2a7dfe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -326,9 +326,25 @@ module = 'gt4py.next.iterator.runtime' [tool.pytest.ini_options] markers = [ - 'requires_atlas', # mark tests that require 'atlas4py' bindings package - 'requires_dace', # mark tests that require 'dace' package - 'requires_gpu:' # mark tests that require a NVidia GPU (assume 'cupy' and 'cudatoolkit' are installed) + 'requires_atlas: tests that require `atlas4py` bindings package', + 'requires_dace: tests that require `dace` package', + 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', + 'uses_applied_shifts: tests that require backend support for applied-shifts', + 'uses_can_deref: tests that require backend support for can_deref', + 'uses_constant_fields: tests that require backend support for constant fields', + 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', + 'uses_if_stmts: tests that require backend support for if-statements', + 'uses_index_fields: tests that require backend support for index fields', + 'uses_lift_expressions: tests that require backend support for lift expressions', + 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', + 'uses_origin: tests that require backend support for domain origin', + 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', + 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', + 'uses_sparse_fields: tests that require backend support for sparse fields', + 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', + 'uses_tuple_args: tests that require backend support for tuple arguments', + 'uses_tuple_returns: tests that require backend support for tuple results', + 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py new file mode 100644 index 0000000000..d0a44080ad --- /dev/null +++ b/tests/next_tests/exclusion_matrices.py @@ -0,0 +1,89 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +import pytest + + +""" +Contains definition of test-exclusion matrices, see ADR 15. +""" + +# Skip definitions +XFAIL = pytest.xfail +SKIP = pytest.skip + +# Skip messages (available format keys: 'marker', 'backend') +UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" +BINDINGS_UNSUPPORTED_MESSAGE = "'{marker}' not supported by '{backend}' bindings" + +# Processor ids as returned by next_tests.get_processor_id() +DACE = "dace_iterator.run_dace_iterator" +GTFN_CPU = "otf_compile_executor.run_gtfn" +GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" +GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" + +# Test markers +REQUIRES_ATLAS = "requires_atlas" +USES_APPLIED_SHIFTS = "uses_applied_shifts" +USES_CAN_DEREF = "uses_can_deref" +USES_CONSTANT_FIELDS = "uses_constant_fields" +USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" +USES_IF_STMTS = "uses_if_stmts" +USES_INDEX_FIELDS = "uses_index_fields" +USES_LIFT_EXPRESSIONS = "uses_lift_expressions" +USES_NEGATIVE_MODULO = "uses_negative_modulo" +USES_ORIGIN = "uses_origin" +USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" +USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" +USES_SPARSE_FIELDS = "uses_sparse_fields" +USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" +USES_TUPLE_ARGS = "uses_tuple_args" +USES_TUPLE_RETURNS = "uses_tuple_returns" +USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" + +# Common list of feature markers to skip +GTFN_SKIP_TEST_LIST = [ + (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), +] + +""" +Skip matrix, contains for each backend processor a list of tuples with following fields: +(, ) +""" +BACKEND_SKIP_TEST_MATRIX = { + DACE: GTFN_SKIP_TEST_LIST + + [ + (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), + (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + ], + GTFN_CPU: GTFN_SKIP_TEST_LIST, + GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST, + GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + + [ + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + ], +} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index a8c35cc28f..d3863f5a28 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -22,7 +22,17 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator from gt4py.next.iterator import embedded, ir as itir -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu, roundtrip +from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip +from tests.next_tests import exclusion_matrices + + +try: + from gt4py.next.program_processors.runners import dace_iterator +except ModuleNotFoundError as e: + if "dace" in str(e): + dace_iterator = None + else: + raise e import next_tests @@ -32,20 +42,33 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non raise ValueError("No backend selected! Backend selection is mandatory in tests.") +OPTIONAL_PROCESSORS = [] +if dace_iterator: + OPTIONAL_PROCESSORS.append(dace_iterator.run_dace_iterator) + + @pytest.fixture( params=[ roundtrip.executor, gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ], + ] + + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p), ) def fieldview_backend(request): + backend = request.param + backend_id = next_tests.get_processor_id(backend) + + """See ADR 15.""" + for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) + backup_backend = decorator.DEFAULT_BACKEND decorator.DEFAULT_BACKEND = no_backend - yield request.param + yield backend decorator.DEFAULT_BACKEND = backup_backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 71e31542f7..1402649127 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -24,7 +24,7 @@ from gt4py.next.errors.exceptions import TypeError_ from gt4py.next.ffront.decorator import field_operator, program, scan_operator from gt4py.next.ffront.fbuiltins import broadcast, int32, int64 -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -169,15 +169,8 @@ def testee( ) +@pytest.mark.uses_scan_in_field_operator def test_call_scan_operator_from_field_operator(cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("Calling scan from field operator not fully supported.") - @scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan(state: float, x: float, y: float) -> float: return state + x + 2.0 * y diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f50f16ea0f..865950eeab 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -33,7 +33,7 @@ where, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -68,10 +68,8 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) +@pytest.mark.uses_tuple_returns def test_multicopy(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases.IJKField]: return a, b @@ -161,10 +159,8 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=a.ndarray[1:] + b.ndarray[2:]) +@pytest.mark.uses_tuple_returns def test_tuples(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.IJKFloatField, b: cases.IJKFloatField) -> cases.IJKFloatField: inps = a, b @@ -211,10 +207,8 @@ def testee(a: int32) -> cases.VField: ) +@pytest.mark.uses_index_fields def test_scalar_arg_with_field(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: index fields, constant fields") - @gtx.field_operator def testee(a: cases.IJKField, b: int32) -> cases.IJKField: tmp = b * a @@ -272,16 +266,8 @@ def testee(qc: cases.IKFloatField, scalar: float): cases.verify(cartesian_case, testee, qc, scalar, inout=qc, ref=expected) +@pytest.mark.uses_scan_in_field_operator def test_tuple_scalar_scan(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - ]: - pytest.xfail("Scalar tuple arguments are not supported in gtfn yet.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple arguments") - @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan( state: float, qc_in: float, tuple_scalar: tuple[float, tuple[float, float]] @@ -301,10 +287,8 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") - @gtx.scan_operator(axis=KDim, forward=True, init=(0.0)) def testee_scan(state: float, inp: float) -> float: return inp @@ -382,12 +366,8 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], np.float32]: ) +@pytest.mark.uses_dynamic_offsets def test_offset_field(cartesian_case): - if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: - pytest.xfail("Dynamic offsets not supported in gtfn") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: offset fields") - ref = np.full( (cartesian_case.default_sizes[IDim], cartesian_case.default_sizes[KDim]), True, dtype=bool ) @@ -420,10 +400,8 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD assert np.allclose(out, ref) +@pytest.mark.uses_tuple_returns def test_nested_tuple_return(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IField, b: cases.IField @@ -438,10 +416,8 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b) +@pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over lift expressions") - @gtx.field_operator def testee(a: cases.EField) -> cases.EField: tmp = neighbor_sum(a(V2E), axis=V2EDim) @@ -481,10 +457,8 @@ def testee(inp: cases.EField) -> cases.EField: ) +@pytest.mark.uses_tuple_returns def test_tuple_return_2(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField]: tmp = neighbor_sum(a(V2E), axis=V2EDim) @@ -502,10 +476,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField ) +@pytest.mark.uses_tuple_returns def test_tuple_with_local_field_in_reduction_shifted(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples") - @gtx.field_operator def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: tup = e(V2E), v @@ -522,10 +494,8 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: ) +@pytest.mark.uses_tuple_args def test_tuple_arg(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple args") - @gtx.field_operator def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.IField: return 3 * a[0][0] + a[0][1] + a[1] @@ -555,6 +525,7 @@ def simple_scan_operator(carry: float) -> float: cases.verify(cartesian_case, simple_scan_operator, out=out, ref=expected) +@pytest.mark.uses_lift_expressions def test_solve_triag(cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, @@ -564,8 +535,6 @@ def test_solve_triag(cartesian_case): pytest.xfail("Nested `scan`s requires creating temporaries.") if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) def tridiag_forward( @@ -627,10 +596,8 @@ def testee(left: int32, right: int32) -> cases.IField: @pytest.mark.parametrize("left, right", [(2, 3), (3, 2)]) +@pytest.mark.uses_tuple_returns def test_ternary_operator_tuple(cartesian_case, left, right): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee( a: cases.IField, b: cases.IField, left: int32, right: int32 @@ -646,10 +613,8 @@ def testee( ) +@pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over lift expressions") - @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) @@ -688,11 +653,10 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) +@pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] @@ -720,9 +684,8 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): ) +@pytest.mark.uses_tuple_args def test_scan_nested_tuple_input(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple args") init = 1.0 k_size = cartesian_case.default_sizes[KDim] inp1 = gtx.np_as_located_field(KDim)(np.ones((k_size,))) @@ -877,10 +840,8 @@ def program_domain( ) +@pytest.mark.uses_tuple_returns def test_domain_tuple(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def fieldop_domain_tuple( a: cases.IJField, b: cases.IJField @@ -939,10 +900,8 @@ def return_undefined(): return undefined_symbol +@pytest.mark.uses_zero_dimensional_fields def test_zero_dims_fields(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: zero-dimensional fields") - @gtx.field_operator def implicit_broadcast_scalar(inp: cases.EmptyField): return inp @@ -970,10 +929,8 @@ def fieldop_implicit_broadcast_2(inp: cases.IField) -> cases.IField: ) +@pytest.mark.uses_tuple_returns def test_tuple_unpacking(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def unpack( inp: cases.IField, @@ -986,9 +943,8 @@ def unpack( ) +@pytest.mark.uses_tuple_returns def test_tuple_unpacking_star_multi(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") OutType = tuple[ cases.IField, cases.IField, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 7f2b11afff..dbc35ddfdf 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -17,7 +17,6 @@ import gt4py.next as gtx from gt4py.next import int32, neighbor_sum -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import V2E, Edge, V2EDim, Vertex, unstructured_case diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index ee88b3764e..0ae874f3a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next import broadcast, float64, int32, int64, max_over, min_over, neighbor_sum, where -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -46,8 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") if unstructured_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, @@ -69,9 +67,6 @@ def testee(edge_f: cases.EField) -> cases.VField: def test_minover_execution(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -99,10 +94,8 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_constant_fields def test_reduction_expression_in_call(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: make_const_list") - @gtx.field_operator def reduce_expr(edge_f: cases.EField) -> cases.VField: tmp_nbh_tup = edge_f(V2E), edge_f(V2E) @@ -133,10 +126,8 @@ def testee(flux: cases.EField) -> cases.VField: ) +@pytest.mark.uses_tuple_returns def test_conditional_nested_tuple(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def conditional_nested_tuple( mask: cases.IBoolField, a: cases.IFloatField, b: cases.IFloatField diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 9ceab7f2d0..f7121dc82f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -22,7 +22,6 @@ from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.decorator import FieldOperator from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction -from gt4py.next.program_processors.runners import dace_iterator from gt4py.next.type_system import type_translation from next_tests.integration_tests import cases diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 54374077b4..85826c1ac0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -37,7 +37,7 @@ tanh, trunc, ) -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case @@ -84,17 +84,8 @@ def floorDiv(inp1: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, floorDiv, ref=lambda inp1: inp1 // 2) +@pytest.mark.uses_negative_modulo def test_mod(cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail( - "Modulo not properly supported for negative numbers." - ) # see https://github.com/GridTools/gt4py/issues/1219 - @gtx.field_operator def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index d7c50e83f0..f489126fa7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,7 +20,6 @@ import pytest import gt4py.next as gtx -from gt4py.next.program_processors.runners import dace_iterator from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend @@ -129,10 +128,8 @@ def fo_from_fo_program(in_field: cases.IFloatField, out: cases.IFloatField): ) +@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField @@ -158,10 +155,8 @@ def prog( assert np.allclose((a, b), (out_a, out_b)) +@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField @@ -188,10 +183,8 @@ def prog( assert out_a[0] == 0 and out_b[0] == 0 +@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_nested(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField, c: cases.IFloatField diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index a49dd1fdcf..f9fd2c1353 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -19,7 +19,6 @@ import pytest from gt4py.next import Field, errors, field_operator, float64, index_field, np_as_located_field -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -46,15 +45,8 @@ @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_simple_if(condition, cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -71,15 +63,8 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) +@pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def simple_if( a: cases.IField, @@ -112,15 +97,8 @@ def simple_if( @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_local_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def local_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -138,15 +116,8 @@ def local_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_temporary_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -167,15 +138,8 @@ def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IFi @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_return(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -196,15 +160,8 @@ def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IFi @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_if_branch_returns(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def if_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -222,15 +179,8 @@ def if_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> case @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_else_branch_returns(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def else_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -250,15 +200,8 @@ def else_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> ca @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_both_branches_return(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def both_branches_return(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -278,15 +221,8 @@ def both_branches_return(a: cases.IField, b: cases.IField, condition: bool) -> c @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) -def test_nested_if_stmt_conditinal(cartesian_case, condition1, condition2): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - +@pytest.mark.uses_if_stmts +def test_nested_if_stmt_conditional(cartesian_case, condition1, condition2): @field_operator def nested_if_conditional_return( inp: cases.IField, condition1: bool, condition2: bool @@ -322,15 +258,8 @@ def nested_if_conditional_return( @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_nested_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def nested_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -364,15 +293,8 @@ def nested_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) +@pytest.mark.uses_if_stmts def test_if_without_else(cartesian_case, condition1, condition2): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def if_without_else( a: cases.IField, b: cases.IField, condition1: bool, condition2: bool diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 13fcf3b87f..ca29c5b18b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -52,7 +52,6 @@ xor_, ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -171,10 +170,6 @@ def arithmetic_and_logical_test_data(): @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, expected, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) inps = asfield(*asarray(*inputs)) out = asfield((np.zeros_like(*asarray(expected))))[0] @@ -207,10 +202,6 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(program_processor, builtin_name, inputs, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) if builtin_name == "gamma": # numpy has no gamma function @@ -254,10 +245,9 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) +@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: can_deref") Node = gtx.Dimension("Node") diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index d20ec2ee3d..c2517f1a07 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -18,7 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -27,15 +26,14 @@ @fundef -def test_conditional(inp): +def stencil_conditional(inp): tmp = if_(eq(deref(inp), 0), make_tuple(1.0, 2.0), make_tuple(3.0, 4.0)) return tuple_get(0, tmp) + tuple_get(1, tmp) +@pytest.mark.uses_tuple_returns def test_conditional_w_tuple(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5] @@ -46,7 +44,7 @@ def test_conditional_w_tuple(program_processor): IDim: range(0, shape[0]), } run_processor( - test_conditional[dom], + stencil_conditional[dom], program_processor, inp, out=out, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index f4ebc596e5..75b935677b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -34,8 +34,6 @@ from gt4py.next.program_processors.formatters.gtfn import ( format_sourcecode as gtfn_format_sourcecode, ) -from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.integration_tests.cases import IDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -54,16 +52,13 @@ def conditional_indirection(inp, cond): return deref(compute_shift(cond)(inp)) +@pytest.mark.uses_applied_shifts def test_simple_indirection(program_processor): program_processor, validate = program_processor if program_processor in [ type_check.check, - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, gtfn_format_sourcecode, - run_dace_iterator, ]: pytest.xfail( "We only support applied shifts in type_inference." @@ -97,13 +92,9 @@ def direct_indirection(inp, cond): return deref(shift(I, deref(cond))(inp)) +@pytest.mark.uses_dynamic_offsets def test_direct_offset_for_indirection(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: shift offsets not literals") - - if program_processor == gtfn_cpu.run_gtfn_with_temporaries: - pytest.xfail("Dynamic offsets not supported in temporaries pass.") shape = [4] inp = gtx.np_as_located_field(IDim)(np.asarray(range(shape[0]), dtype=np.float64)) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 2076cdd864..d0dc8ec475 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -18,7 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -59,10 +58,6 @@ def test_single_argument(program_processor, dom): def test_2_arguments(program_processor, dom): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) @fundef def fun(inp0, inp1): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index e0460b67b1..e02dab0a72 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -18,16 +18,14 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import cartesian_domain, deref, named_range, scan, shift from gt4py.next.iterator.runtime import fundef, offset -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +@pytest.mark.uses_index_fields def test_scan_in_stencil(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: shift inside lambda") isize = 1 ksize = 3 diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bfaa7f643..0ac38e9b9f 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -18,8 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -49,15 +47,9 @@ def fencil(size, out, inp): ) +@pytest.mark.uses_strided_neighbor_offset def test_strided_offset_provider(program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("gtx.StridedNeighborOffsetProvider not implemented in bindings.") LocA_size = 2 max_neighbors = LocA2LocAB_offset_provider.max_neighbors diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 7cc4e95949..cc12183a24 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -19,7 +19,6 @@ from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 5a6ffe2891..bd5a717bb2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -18,13 +18,8 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from next_tests.unit_tests.conftest import ( - program_processor, - program_processor_no_gtfn_exec, - run_processor, -) +from next_tests.unit_tests.conftest import program_processor, run_processor IDim = gtx.Dimension("IDim") @@ -54,10 +49,9 @@ def tuple_output2(inp1, inp2): "stencil", [tuple_output1, tuple_output2], ) +@pytest.mark.uses_tuple_returns def test_tuple_output(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -94,10 +88,9 @@ def tuple_of_tuple_output2(inp1, inp2, inp3, inp4): return make_tuple(deref(inp1), deref(inp2)), make_tuple(deref(inp3), deref(inp4)) +@pytest.mark.uses_tuple_returns def test_tuple_of_tuple_of_field_output(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fundef def stencil(inp1, inp2, inp3, inp4): @@ -155,10 +148,9 @@ def stencil(inp1, inp2, inp3, inp4): "stencil", [tuple_output1, tuple_output2], ) +@pytest.mark.uses_tuple_returns def test_tuple_of_field_output_constructed_inside(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): @@ -202,10 +194,9 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): assert np.allclose(inp2, out2) +@pytest.mark.uses_tuple_returns def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fundef def stencil(inp1, inp2, inp3): @@ -265,10 +256,8 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): "stencil", [tuple_output1, tuple_output2], ) -def test_field_of_extra_dim_output(program_processor_no_gtfn_exec, stencil): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_extra_dim_output(program_processor, stencil): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() @@ -299,10 +288,9 @@ def tuple_input(inp): return tuple_get(0, inp_deref) + tuple_get(1, inp_deref) +@pytest.mark.uses_tuple_returns def test_tuple_field_input(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -326,10 +314,8 @@ def test_tuple_field_input(program_processor): @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") -def test_field_of_extra_dim_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_extra_dim_input(program_processor): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() @@ -362,10 +348,9 @@ def tuple_tuple_input(inp): ) +@pytest.mark.uses_tuple_returns def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -404,10 +389,8 @@ def test_tuple_of_tuple_of_field_input(program_processor): @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") -def test_field_of_2_extra_dim_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_2_extra_dim_input(program_processor): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 2580c6ba7f..8db9a4c36e 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -18,7 +18,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu, roundtrip +from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, @@ -211,6 +211,7 @@ class setup: return setup() +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): if fieldview_backend in [ gtfn_cpu.run_gtfn, @@ -218,8 +219,6 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): gtfn_cpu.run_gtfn_with_temporaries, ]: pytest.xfail("Needs implementation of scan projector.") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)( test_setup.z_alpha, @@ -233,6 +232,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail( @@ -241,8 +241,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): ) if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples, scans") solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)( test_setup.z_alpha, @@ -256,11 +254,10 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( test_setup.z_alpha, test_setup.z_beta, @@ -274,13 +271,12 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): assert np.allclose(test_setup.w_ref, test_setup.w) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples, scans") solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)( test_setup.z_alpha, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 14d929e822..16d839a8ab 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -19,7 +19,6 @@ from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -75,6 +74,7 @@ def naive_lap(inp): return out +@pytest.mark.uses_origin def test_anton_toy(program_processor, lift_mode): program_processor, validate = program_processor @@ -87,8 +87,6 @@ def test_anton_toy(program_processor, lift_mode): if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("TODO: issue with temporaries that crashes the application") - if program_processor == run_dace_iterator: - pytest.xfail("TODO: not supported in DaCe backend") shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 2446d6664f..41d6c8f0f9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -18,11 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.formatters.gtfn import ( - format_sourcecode as gtfn_format_sourcecode, -) -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -79,11 +74,10 @@ def basic_stencils(request): return request.param +@pytest.mark.uses_origin def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): program_processor, validate = program_processor stencil, ref_fun, inp_fun = basic_stencils - if program_processor == run_dace_iterator and inp_fun: - pytest.xfail("Not supported in DaCe backend: origin") shape = [5, 7] inp = ( @@ -95,13 +89,6 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ref = ref_fun(inp) - if ( - program_processor == run_dace_iterator - and stencil.__name__ == "shift_stencil" - and inp.origin - ): - pytest.xfail("Not supported in DaCe backend: origin") - run_processor( stencil[{IDim: range(0, shape[0]), KDim: range(0, shape[1])}], program_processor, @@ -162,12 +149,10 @@ def k_level_condition_upper_tuple(k_idx, k_level): ), ], ) +@pytest.mark.uses_tuple_returns def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple arguments") - k_size = 5 inp = inp_function(k_size) ref = ref_function(inp) @@ -361,10 +346,6 @@ def sum_shifted_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) @@ -401,10 +382,9 @@ def sum_fencil(out, inp0, inp1, k_size): ) +@pytest.mark.uses_origin def test_different_vertical_sizes_with_origin(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: origin") k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 2d35fb1e50..42de13ef44 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -15,8 +15,6 @@ import numpy as np import pytest -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator - pytest.importorskip("atlas4py") @@ -136,15 +134,9 @@ def nabla( ) +@pytest.mark.requires_atlas def test_compute_zavgS(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() pp = gtx.np_as_located_field(Vertex)(setup.input_field) @@ -201,15 +193,9 @@ def compute_zavgS2_fencil( ) +@pytest.mark.requires_atlas def test_compute_zavgS2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() pp = gtx.np_as_located_field(Vertex)(setup.input_field) @@ -244,15 +230,9 @@ def test_compute_zavgS2(program_processor, lift_mode): assert_close(1000788897.3202186, np.max(zavgS[1])) +@pytest.mark.requires_atlas def test_nabla(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") setup = nabla_setup() @@ -310,15 +290,9 @@ def nabla2( ) +@pytest.mark.requires_atlas def test_nabla2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() sign = gtx.np_as_located_field(Vertex, V2EDim)(setup.sign_field) @@ -400,13 +374,6 @@ def test_nabla_sign(program_processor, lift_mode): program_processor, validate = program_processor if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("test is broken due to bad lift semantics in iterator IR") - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() is_pole_edge = gtx.np_as_located_field(Edge)(setup.is_pole_edge_field) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 1dfad40e48..7bd028b7c3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -24,12 +24,7 @@ from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import ( hdiff_reference, ) -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_dace_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor I = offset("I") @@ -76,8 +71,9 @@ def hdiff(inp, coeff, out, x, y): ) -def test_hdiff(hdiff_reference, program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_origin +def test_hdiff(hdiff_reference, program_processor, lift_mode): + program_processor, validate = program_processor if program_processor in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 4474121876..f11046cb5d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -25,12 +25,7 @@ from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.cases import IDim, JDim, KDim -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_dace_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @fundef @@ -120,8 +115,9 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) -def test_tridiag(fencil, tridiag_reference, program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_lift_expressions +def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): + program_processor, validate = program_processor if ( program_processor in [ diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index ee07372731..27c9f6d124 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -32,7 +32,6 @@ from gt4py.next.iterator.runtime import fundef from gt4py.next.program_processors.formatters import gtfn from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.toy_connectivity import ( C2E, @@ -54,7 +53,6 @@ from next_tests.unit_tests.conftest import ( lift_mode, program_processor, - program_processor_no_dace_exec, program_processor_no_gtfn_exec, run_processor, ) @@ -139,10 +137,9 @@ def map_make_const_list(in_edges): return reduce(plus, 0)(map_(multiplies)(neighbors(V2E, in_edges), make_const_list(2))) +@pytest.mark.uses_constant_fields def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: make_const_list") inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -244,8 +241,9 @@ def slice_sparse_stencil(sparse): return list_get(1, deref(sparse)) -def test_slice_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_slice_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -298,8 +296,9 @@ def shift_sliced_sparse_stencil(sparse): return list_get(1, deref(shift(V2V, 0)(sparse))) -def test_shift_sliced_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sliced_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -325,8 +324,9 @@ def slice_shifted_sparse_stencil(sparse): return list_get(1, deref(shift(V2V, 0)(sparse))) -def test_slice_shifted_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_slice_shifted_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -357,8 +357,8 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) -def test_lift(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_lift(program_processor, lift_mode): + program_processor, validate = program_processor inp = vertex_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) @@ -380,8 +380,9 @@ def sparse_shifted_stencil(inp): return list_get(2, list_get(0, neighbors(V2V, inp))) -def test_shift_sparse_input_field(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sparse_input_field(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) @@ -409,8 +410,9 @@ def shift_sparse_stencil2(inp): return list_get(1, list_get(3, neighbors(V2E, inp))) -def test_shift_sparse_input_field2(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sparse_input_field2(program_processor, lift_mode): + program_processor, validate = program_processor if program_processor in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, @@ -459,13 +461,12 @@ def sum_(a, b): return reduce(sum_, 0)(neighbors(V2V, lift(lambda x: reduce(sum_, 0)(deref(x)))(inp))) +@pytest.mark.uses_sparse_fields def test_sparse_shifted_stencil_reduce(program_processor_no_gtfn_exec, lift_mode): program_processor, validate = program_processor_no_gtfn_exec if program_processor == gtfn.format_sourcecode: pytest.xfail("We cannot unroll a reduction on a sparse field only.") # With our current understanding, this iterator IR program is illegal, however we might want to fix it and therefore keep the test for now. - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: illegal iterator IR") if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 09d58a4376..04c34dfaab 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -23,12 +23,17 @@ from gt4py.next.iterator import ir as itir, pretty_parser, pretty_printer, runtime, transforms from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.formatters import gtfn, lisp, type_check -from gt4py.next.program_processors.runners import ( - dace_iterator, - double_roundtrip, - gtfn_cpu, - roundtrip, -) +from gt4py.next.program_processors.runners import double_roundtrip, gtfn_cpu, roundtrip +from tests.next_tests import exclusion_matrices + + +try: + from gt4py.next.program_processors.runners import dace_iterator +except ModuleNotFoundError as e: + if "dace" in str(e): + dace_iterator = None + else: + raise e import next_tests @@ -60,6 +65,11 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str return pretty +OPTIONAL_PROCESSORS = [] +if dace_iterator: + OPTIONAL_PROCESSORS.append((dace_iterator.run_dace_iterator, True)) + + @pytest.fixture( params=[ # (processor, do_validate) @@ -73,19 +83,20 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str (gtfn_cpu.run_gtfn_imperative, True), (gtfn_cpu.run_gtfn_with_temporaries, True), (gtfn.format_sourcecode, False), - (dace_iterator.run_dace_iterator, True), - ], + ] + + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p[0]), ) def program_processor(request): - return request.param + backend, _ = request.param + backend_id = next_tests.get_processor_id(backend) + """See ADR 15.""" + for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) -@pytest.fixture -def program_processor_no_dace_exec(program_processor): - if program_processor[0] == dace_iterator.run_dace_iterator: - pytest.xfail("DaCe backend not yet supported.") - return program_processor + return request.param @pytest.fixture From 54bca831100455741e5bed459ef8378a1418b5ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 4 Oct 2023 15:57:35 +0200 Subject: [PATCH 17/67] Fixes and additions to test exclusion matrices functionality. (#1345) Fixes and additions to test exclusion matrices. Changes: - Fix import path of exclusion matrices. - Fix wrong locations of docstrings. - Remove deprecated fixtures. - Add missing marker to parametrize forgotten custom case. --- .../ADRs/0015-Test_Exclusion_Matrices.md | 2 +- tests/next_tests/__init__.py | 5 ++++ tests/next_tests/exclusion_matrices.py | 26 +++++++++++-------- .../ffront_tests/ffront_test_utils.py | 12 ++++++--- .../test_with_toy_connectivity.py | 20 +++++--------- tests/next_tests/unit_tests/conftest.py | 23 +++++++--------- 6 files changed, 45 insertions(+), 43 deletions(-) diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md index 920504db9a..6c6a043560 100644 --- a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md +++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md @@ -43,7 +43,7 @@ The test-exclusion matrix is a dictionary, where `key` is the backend name and e `(, , )` The backend string, used both as dictionary key and as string formatter in the skip message, is retrieved -by calling `tests.next_tests.get_processor_id()`, which returns the so-called processor name. +by calling `next_tests.get_processor_id()`, which returns the so-called processor name. The following backend processors are defined: ```python diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index bd9b968948..54bc4d9c69 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -12,6 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from . import exclusion_matrices + + +__all__ = ["exclusion_matrices", "get_processor_id"] + def get_processor_id(processor): if hasattr(processor, "__module__") and hasattr(processor, "__name__"): diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index d0a44080ad..27ccb29095 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -14,23 +14,18 @@ import pytest -""" -Contains definition of test-exclusion matrices, see ADR 15. -""" +"""Contains definition of test-exclusion matrices, see ADR 15.""" # Skip definitions XFAIL = pytest.xfail SKIP = pytest.skip -# Skip messages (available format keys: 'marker', 'backend') -UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" -BINDINGS_UNSUPPORTED_MESSAGE = "'{marker}' not supported by '{backend}' bindings" - # Processor ids as returned by next_tests.get_processor_id() DACE = "dace_iterator.run_dace_iterator" GTFN_CPU = "otf_compile_executor.run_gtfn" GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" +GTFN_FORMAT_SOURCECODE = "gtfn.format_sourcecode" # Test markers REQUIRES_ATLAS = "requires_atlas" @@ -46,25 +41,31 @@ USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SPARSE_FIELDS = "uses_sparse_fields" +USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" +# Skip messages (available format keys: 'marker', 'backend') +UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" +BINDINGS_UNSUPPORTED_MESSAGE = "'{marker}' not supported by '{backend}' bindings" +REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE = ( + "We cannot unroll a reduction on a sparse field only (not clear if it is legal ITIR)" +) # Common list of feature markers to skip GTFN_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ] -""" -Skip matrix, contains for each backend processor a list of tuples with following fields: -(, ) -""" +#: Skip matrix, contains for each backend processor a list of tuples with following fields: +#: (, ) BACKEND_SKIP_TEST_MATRIX = { DACE: GTFN_SKIP_TEST_LIST + [ @@ -86,4 +87,7 @@ + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), ], + GTFN_FORMAT_SOURCECODE: [ + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), + ], } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index d3863f5a28..383716484e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -23,7 +23,6 @@ from gt4py.next.ffront import decorator from gt4py.next.iterator import embedded, ir as itir from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip -from tests.next_tests import exclusion_matrices try: @@ -58,11 +57,18 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non ids=lambda p: next_tests.get_processor_id(p), ) def fieldview_backend(request): + """ + Fixture creating field-view operator backend on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. + """ backend = request.param backend_id = next_tests.get_processor_id(backend) - """See ADR 15.""" - for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): if request.node.get_closest_marker(marker): skip_mark(msg.format(marker=marker, backend=backend_id)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 27c9f6d124..92b93ddb63 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -50,12 +50,7 @@ v2e_arr, v2v_arr, ) -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_gtfn_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor def edge_index_field(): # TODO replace by gtx.index_field once supported in bindings @@ -138,8 +133,8 @@ def map_make_const_list(in_edges): @pytest.mark.uses_constant_fields -def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec +def test_map_make_const_list(program_processor, lift_mode): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -462,12 +457,9 @@ def sum_(a, b): @pytest.mark.uses_sparse_fields -def test_sparse_shifted_stencil_reduce(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == gtfn.format_sourcecode: - pytest.xfail("We cannot unroll a reduction on a sparse field only.") - # With our current understanding, this iterator IR program is illegal, however we might want to fix it and therefore keep the test for now. - +@pytest.mark.uses_reduction_with_only_sparse_fields +def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): + program_processor, validate = program_processor if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 04c34dfaab..7a62778be1 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -24,7 +24,6 @@ from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.formatters import gtfn, lisp, type_check from gt4py.next.program_processors.runners import double_roundtrip, gtfn_cpu, roundtrip -from tests.next_tests import exclusion_matrices try: @@ -88,28 +87,24 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str ids=lambda p: next_tests.get_processor_id(p[0]), ) def program_processor(request): + """ + Fixture creating program processors on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. + """ backend, _ = request.param backend_id = next_tests.get_processor_id(backend) - """See ADR 15.""" - for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): if request.node.get_closest_marker(marker): skip_mark(msg.format(marker=marker, backend=backend_id)) return request.param -@pytest.fixture -def program_processor_no_gtfn_exec(program_processor): - if ( - program_processor[0] == gtfn_cpu.run_gtfn - or program_processor[0] == gtfn_cpu.run_gtfn_imperative - or program_processor[0] == gtfn_cpu.run_gtfn_with_temporaries - ): - pytest.xfail("gtfn backend not yet supported.") - return program_processor - - def run_processor( program: runtime.FendefDispatcher, processor: ppi.ProgramExecutor | ppi.ProgramFormatter, From 0d821b150177d8a805df37887fd74427a751e5af Mon Sep 17 00:00:00 2001 From: ninaburg <83002751+ninaburg@users.noreply.github.com> Date: Thu, 5 Oct 2023 15:16:48 +0200 Subject: [PATCH 18/67] feat[next]: Add support for using Type Aliases (#1335) * Add Type Alias replacement pass + tests * Fix: actual type not added in symbol list if already present * Address requested changes * Pre-commit fixes * Address requested changes * Prevent multiple float32 or float64 definitions in symtable * pre-commit run changes and 'returns' arg type modifications * Use 'from_type_hint' to avoid 'ScalarKind' construct --------- Co-authored-by: Nina Burgdorfer --- .../foast_passes/type_alias_replacement.py | 105 ++++++++++++++++++ src/gt4py/next/ffront/func_to_foast.py | 2 + .../test_type_alias_replacement.py | 44 ++++++++ 3 files changed, 151 insertions(+) create mode 100644 src/gt4py/next/ffront/foast_passes/type_alias_replacement.py create mode 100644 tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py diff --git a/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py new file mode 100644 index 0000000000..c5857999ee --- /dev/null +++ b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py @@ -0,0 +1,105 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from typing import Any, cast + +import gt4py.next.ffront.field_operator_ast as foast +from gt4py.eve import NodeTranslator, traits +from gt4py.eve.concepts import SourceLocation, SymbolName, SymbolRef +from gt4py.next.ffront import dialect_ast_enums +from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system.type_translation import from_type_hint + + +@dataclass +class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait): + """ + Replace Type Aliases with their actual type. + + After this pass, the type aliases used for explicit construction of literal + values and for casting field values are replaced by their actual types. + """ + + closure_vars: dict[str, Any] + + @classmethod + def apply( + cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any] + ) -> tuple[foast.FunctionDefinition, dict[str, Any]]: + foast_node = cls(closure_vars=closure_vars).visit(node) + new_closure_vars = closure_vars.copy() + for key, value in closure_vars.items(): + if isinstance(value, type) and key not in TYPE_BUILTIN_NAMES: + new_closure_vars[value.__name__] = closure_vars[key] + return foast_node, new_closure_vars + + def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool: + return ( + node_id in self.closure_vars + and isinstance(self.closure_vars[node_id], type) + and node_id not in TYPE_BUILTIN_NAMES + ) + + def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: + if self.is_type_alias(node.id): + return foast.Name( + id=self.closure_vars[node.id].__name__, location=node.location, type=node.type + ) + return node + + def _update_closure_var_symbols( + self, closure_vars: list[foast.Symbol], location: SourceLocation + ) -> list[foast.Symbol]: + new_closure_vars: list[foast.Symbol] = [] + existing_type_names: set[str] = set() + + for var in closure_vars: + if self.is_type_alias(var.id): + actual_type_name = self.closure_vars[var.id].__name__ + # Avoid multiple definitions of a type in closure_vars + if actual_type_name not in existing_type_names: + new_closure_vars.append( + foast.Symbol( + id=actual_type_name, + type=ts.FunctionType( + pos_or_kw_args={}, + kw_only_args={}, + pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)], + returns=cast( + ts.DataType, from_type_hint(self.closure_vars[var.id]) + ), + ), + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=location, + ) + ) + existing_type_names.add(actual_type_name) + elif var.id not in existing_type_names: + new_closure_vars.append(var) + existing_type_names.add(var.id) + + return new_closure_vars + + def visit_FunctionDefinition( + self, node: foast.FunctionDefinition, **kwargs + ) -> foast.FunctionDefinition: + return foast.FunctionDefinition( + id=node.id, + params=node.params, + body=self.visit(node.body, **kwargs), + closure_vars=self._update_closure_var_symbols(node.closure_vars, node.location), + location=node.location, + ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 082939c938..c7c4c3a23f 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -33,6 +33,7 @@ from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.foast_passes.dead_closure_var_elimination import DeadClosureVarElimination from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass +from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -91,6 +92,7 @@ def _postprocess_dialect_ast( closure_vars: dict[str, Any], annotations: dict[str, Any], ) -> foast.FunctionDefinition: + foast_node, closure_vars = TypeAliasReplacement.apply(foast_node, closure_vars) foast_node = ClosureVarFolding.apply(foast_node, closure_vars) foast_node = DeadClosureVarElimination.apply(foast_node) foast_node = ClosureVarTypeDeduction.apply(foast_node, closure_vars) diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py new file mode 100644 index 0000000000..e87f869352 --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py @@ -0,0 +1,44 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import ast +import typing +from typing import TypeAlias + +import pytest + +import gt4py.next as gtx +from gt4py.next import float32, float64 +from gt4py.next.ffront.fbuiltins import astype +from gt4py.next.ffront.func_to_foast import FieldOperatorParser + + +TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. +vpfloat: TypeAlias = float32 +wpfloat: TypeAlias = float64 + + +@pytest.mark.parametrize("test_input,expected", [(vpfloat, "float32"), (wpfloat, "float64")]) +def test_type_alias_replacement(test_input, expected): + def fieldop_with_typealias( + a: gtx.Field[[TDim], test_input], b: gtx.Field[[TDim], float32] + ) -> gtx.Field[[TDim], test_input]: + return test_input("3.1418") + astype(a, test_input) + + foast_tree = FieldOperatorParser.apply_to_function(fieldop_with_typealias) + + assert ( + foast_tree.body.stmts[0].value.left.func.id == expected + and foast_tree.body.stmts[0].value.right.args[1].id == expected + ) From 6c69398e576f5b8598f96dad617931dda32f62bf Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Oct 2023 10:56:00 +0200 Subject: [PATCH 19/67] feat[next-dace]: Add support for GPU execution (#1347) This PR adds support for GPU execution in DaCe Backend. Additionally, it also introduces a build cache for each visited ITIR program and corresponding binary DaCe program. --- .../runners/dace_iterator/__init__.py | 101 ++++++++++++++++-- 1 file changed, 91 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..25609b1035 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -16,6 +16,8 @@ import dace import numpy as np +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir from gt4py.next import common @@ -29,6 +31,14 @@ from .utility import connectivity_identifier, filter_neighbor_tables +""" Default build configuration in DaCe backend """ +_build_type = "Release" +# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins +_cpu_args = ( + "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label" +) + + def convert_arg(arg: Any): if common.is_field(arg): sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) @@ -85,17 +95,67 @@ def get_stride_args( return stride_args +_build_cache_cpu: dict[int, CompiledSDFG] = {} +_build_cache_gpu: dict[int, CompiledSDFG] = {} + + +def get_cache_id(*cache_args) -> int: + return sum([hash(str(arg)) for arg in cache_args]) + + @program_executor def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: + # build parameters + auto_optimize = kwargs.get("auto_optimize", False) + build_type = kwargs.get("build_type", "RelWithDebInfo") + run_on_gpu = kwargs.get("run_on_gpu", False) + build_cache = kwargs.get("build_cache", None) + # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] - neighbor_tables = filter_neighbor_tables(offset_provider) - program = preprocess_program(program, offset_provider) arg_types = [type_translation.from_value(arg) for arg in args] - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg: dace.SDFG = sdfg_genenerator.visit(program) - sdfg.simplify() + neighbor_tables = filter_neighbor_tables(offset_provider) + + cache_id = get_cache_id(program, *arg_types, column_axis) + if build_cache is not None and cache_id in build_cache: + # retrieve SDFG program from build cache + sdfg_program = build_cache[cache_id] + sdfg = sdfg_program.sdfg + else: + # visit ITIR and generate SDFG + program = preprocess_program(program, offset_provider) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + sdfg = sdfg_genenerator.visit(program) + sdfg.simplify() + + # set array storage for GPU execution + if run_on_gpu: + device = dace.DeviceType.GPU + sdfg._name = f"{sdfg.name}_gpu" + for _, _, array in sdfg.arrays_recursive(): + if not array.transient: + array.storage = dace.dtypes.StorageType.GPU_Global + else: + device = dace.DeviceType.CPU + + # run DaCe auto-optimization heuristics + if auto_optimize: + # TODO Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. + symbols: dict[str, int] = {} + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols) + + # compile SDFG and retrieve SDFG program + sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "build_type", value=build_type) + dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) + sdfg_program = sdfg.compile(validate=False) + + # store SDFG program in build cache + if build_cache is not None: + build_cache[cache_id] = sdfg_program dace_args = get_args(program.params, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} @@ -105,8 +165,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_strides = get_stride_args(sdfg.arrays, dace_field_args) dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" - all_args = { **dace_args, **dace_conn_args, @@ -120,9 +178,32 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: for key, value in all_args.items() if key in sdfg.signature_arglist(with_types=False) } + with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("compiler", "build_type", value="Debug") - dace.config.Config.set("compiler", "cpu", "args", value="-O0") dace.config.Config.set("frontend", "check_args", value=True) - sdfg(**expected_args) + sdfg_program(**expected_args) + + +@program_executor +def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_cpu, + build_type=_build_type, + run_on_gpu=False, + ) + + +@program_executor +def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_gpu, + build_type=_build_type, + run_on_gpu=True, + ) From d07104da19d0e3b467210e9afb940214cddad79a Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Oct 2023 11:07:46 +0200 Subject: [PATCH 20/67] fix[next-dace]: scan_dim consistent with canonical field domain (#1346) The DaCe backend is reordering the dimensions of field domain based on alphabetical order - we call this the canonical representation of field domain. Therefore, array strides, sizes and offsets need to be shuffled, everywhere, to be consistent with the alphabetical order of dimensions. This PR corrects indexing of field domain in get_scan_dim() which was not consistent with the canonical representation. Additional minor edit: * rename map_domain -> map_ranges * replace dace.Memlet() with dace.Memlet.simple() --- .../runners/dace_iterator/itir_to_sdfg.py | 55 ++++++++++--------- .../runners/dace_iterator/itir_to_tasklet.py | 12 ++-- .../runners/dace_iterator/utility.py | 9 ++- 3 files changed, 40 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 56031d8555..2b4ad721b8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,6 +38,7 @@ create_memlet_at, create_memlet_full, filter_neighbor_tables, + get_sorted_dims, map_nested_sdfg_symbols, unique_var_name, ) @@ -79,9 +80,10 @@ def get_scan_dim( - scan_dim_dtype: data type along the scan dimension """ output_type = cast(ts.FieldType, storage_types[output.id]) + sorted_dims = [dim for _, dim in get_sorted_dims(output_type.dims)] return ( column_axis.value, - output_type.dims.index(column_axis), + sorted_dims.index(column_axis), output_type.dtype, ) @@ -246,7 +248,7 @@ def visit_StencilClosure( ) access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) - memlet = create_memlet_at(out_name, ("0",)) + memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: @@ -274,7 +276,7 @@ def visit_StencilClosure( transient_to_arg_name_mapping[nsdfg_output_name] = output_name # scan operator should always be the first function call in a closure if is_scan(node.stencil): - nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure( + nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( node, closure_sdfg.arrays, closure_domain, nsdfg_output_name ) results = [nsdfg_output_name] @@ -294,13 +296,13 @@ def visit_StencilClosure( output_name, tuple( f"i_{dim}" - if f"i_{dim}" in map_domain + if f"i_{dim}" in map_ranges else f"0:{output_descriptor.shape[scan_dim_index]}" for dim, _ in closure_domain ), ) else: - nsdfg, map_domain, results = self._visit_parallel_stencil_closure( + nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) assert len(results) == 1 @@ -313,7 +315,7 @@ def visit_StencilClosure( transient=True, ) - output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) + output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} @@ -325,7 +327,7 @@ def visit_StencilClosure( nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, - map_ranges=map_domain or {"__dummy": "0"}, + map_ranges=map_ranges or {"__dummy": "0"}, inputs=array_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, @@ -341,10 +343,10 @@ def visit_StencilClosure( edge.src_conn, transient_access, None, - dace.Memlet(data=memlet.data, subset=output_subset), + dace.Memlet.simple(memlet.data, output_subset), ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset + inner_memlet = dace.Memlet.simple( + memlet.data, output_subset, other_subset_str=memlet.subset ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) @@ -360,7 +362,7 @@ def visit_StencilClosure( None, map_entry, b.value.data, - create_memlet_at(b.value.data, ("0",)), + dace.Memlet.simple(b.value.data, "0"), ) return closure_sdfg @@ -390,12 +392,12 @@ def _visit_scan_stencil_closure( connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value if not dim == scan_dim: - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" else: scan_lb_str = lb_str scan_ub_str = ub_str @@ -481,29 +483,28 @@ def _visit_scan_stencil_closure( "__result", carry_node1, None, - dace.Memlet(data=f"{scan_carry_name}", subset="0"), + dace.Memlet.simple(scan_carry_name, "0"), ) carry_node2 = lambda_state.add_access(scan_carry_name) lambda_state.add_memlet_path( carry_node2, scan_inner_node, - memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"), + memlet=dace.Memlet.simple(scan_carry_name, "0"), src_conn=None, dst_conn=lambda_carry_name, ) # connect access nodes to lambda inputs for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - data_subset = ( - ", ".join([f"i_{dim}" for dim, _ in closure_domain]) - if isinstance(self.storage_types[data_name], ts.FieldType) - else "0" - ) + if isinstance(self.storage_types[data_name], ts.FieldType): + memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain)) + else: + memlet = dace.Memlet.simple(data_name, "0") lambda_state.add_memlet_path( lambda_state.add_access(data_name), scan_inner_node, - memlet=dace.Memlet(data=f"{data_name}", subset=data_subset), + memlet=memlet, src_conn=None, dst_conn=inner_name, ) @@ -532,7 +533,7 @@ def _visit_scan_stencil_closure( lambda_state.add_memlet_path( scan_inner_node, lambda_state.add_access(data_name), - memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"), + memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), src_conn=lambda_connector.value.label, dst_conn=None, ) @@ -544,10 +545,10 @@ def _visit_scan_stencil_closure( lambda_update_state.add_memlet_path( result_node, carry_node3, - memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"), + memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) - return scan_sdfg, map_domain, scan_dim_index + return scan_sdfg, map_ranges, scan_dim_index def _visit_parallel_stencil_closure( self, @@ -562,11 +563,11 @@ def _visit_parallel_stencil_closure( conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" # Create an SDFG for the tasklet that computes a single item of the output domain. index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} @@ -583,7 +584,7 @@ def _visit_parallel_stencil_closure( self.node_types, ) - return context.body, map_domain, [r.value.data for r in results] + return context.body, map_ranges, [r.value.data for r in results] def _visit_domain( self, node: itir.FunCall, context: Context diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e7a598d9a..d3bfb5ff0e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -34,7 +34,6 @@ add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, - create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, @@ -595,7 +594,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ) # if dim is not found in iterator indices, we take the neighbor index over the reduction domain - array_index = [ + flat_index = [ f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name for dim in sorted(iterator.dimensions) ] @@ -608,7 +607,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: name="deref", inputs=set(internals), outputs={"__result"}, - code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", + code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", ) for arg, internal in zip(args, internals): @@ -634,8 +633,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: flat_index = [ ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions ] - - args = [ValueExpr(iterator.field, int), *flat_index] + args = [ValueExpr(iterator.field, iterator.dtype), *flat_index] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") @@ -849,7 +847,7 @@ def _visit_reduce(self, node: itir.FunCall): p.apply_pass(lambda_context.body, {}) input_memlets = [ - create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) + dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) ] output_memlet = dace.Memlet.simple(result_name, "0") @@ -928,7 +926,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = create_memlet_at(result_access.data, ("0",)) + memlet = dace.Memlet.simple(result_access.data, "0") self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 889a1ab150..7e6fe13ac7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,10 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any +from typing import Any, Sequence import dace +from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts @@ -49,7 +50,7 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): bounds = [(0, size) for size in source_array.shape] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet(data=source_identifier, subset=subset) + return dace.Memlet.simple(source_identifier, subset) def create_memlet_at(source_identifier: str, index: tuple[str, ...]): @@ -57,6 +58,10 @@ def create_memlet_at(source_identifier: str, index: tuple[str, ...]): return dace.Memlet(data=source_identifier, subset=subset) +def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: + return sorted(enumerate(dims), key=lambda v: v[1].value) + + def map_nested_sdfg_symbols( parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] ) -> dict[str, str]: From 45a6e6d10939b14580eb9c1c243f5779ccd3dc44 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Oct 2023 11:54:59 +0200 Subject: [PATCH 21/67] feat[next]: Add DaCe support for field arguments with domain offset (#1348) This PR adds support in DaCe backend for field arguments with domain offset. This feature is required by icon4py stencils. --- .../runners/dace_iterator/__init__.py | 32 +++++++++++++++---- .../runners/dace_iterator/itir_to_sdfg.py | 16 +++++++--- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 25609b1035..18e257d462 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -20,7 +20,7 @@ from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir -from gt4py.next import common +from gt4py.next.common import Domain, UnitRange, is_field from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache @@ -28,7 +28,12 @@ from gt4py.next.type_system import type_translation from .itir_to_sdfg import ItirToSDFG -from .utility import connectivity_identifier, filter_neighbor_tables +from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims + + +def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: + sorted_dims = get_sorted_dims(domain.dims) + return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] """ Default build configuration in DaCe backend """ @@ -40,10 +45,10 @@ def convert_arg(arg: Any): - if common.is_field(arg): - sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) + if is_field(arg): + sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) - dim_indices = [dim[0] for dim in sorted_dims] + dim_indices = [dim_index for dim_index, _ in sorted_dims] assert isinstance(arg.ndarray, np.ndarray) return np.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg @@ -79,6 +84,17 @@ def get_shape_args( } +def get_offset_args( + arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] +) -> Mapping[str, int]: + return { + str(sym): -drange.start + for param, arg in zip(params, args) + if is_field(arg) + for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + } + + def get_stride_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: @@ -163,7 +179,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg.arrays, program.params, args) all_args = { **dace_args, @@ -171,7 +188,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: **dace_shapes, **dace_conn_shapes, **dace_strides, - **dace_conn_stirdes, + **dace_conn_strides, + **dace_offsets, } expected_args = { key: value diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 2b4ad721b8..7017815688 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -107,12 +107,17 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec): + def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + offset = ( + [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + if has_offset + else None + ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: @@ -136,7 +141,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): scalar_kind = type_translation.get_scalar_kind(table.table.dtype) local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_) + self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False) # Create a nested SDFG for all stencil closures. for closure in node.closures: @@ -287,8 +292,8 @@ def visit_StencilClosure( closure_sdfg.add_array( nsdfg_output_name, dtype=output_descriptor.dtype, - shape=(array_table[output_name].shape[scan_dim_index],), - strides=(array_table[output_name].strides[scan_dim_index],), + shape=(output_descriptor.shape[scan_dim_index],), + strides=(output_descriptor.strides[scan_dim_index],), transient=True, ) @@ -528,6 +533,7 @@ def _visit_scan_stencil_closure( data_name, shape=(array_table[node.output.id].shape[scan_dim_index],), strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), dtype=array_table[node.output.id].dtype, ) lambda_state.add_memlet_path( From 90eea30a000d8b1780bb65457a73e9839e5b3e93 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 17 Oct 2023 10:11:49 +0200 Subject: [PATCH 22/67] feat[next]: DaCe support for neighbor strided offset (#1344) This PR adds support for neighbor strided offset in DaCe backend, another ITIR feature needed by icon4py stencils. The design choice has been to extract max_neighbors from offset_provider at compile-time and hard-code it in the SDFG. Additionally, the hash function to check the SDFG binary cache is modified to use SHA256, in order to reduce collision risk. --- .../runners/dace_iterator/__init__.py | 46 ++++++++++++++----- .../runners/dace_iterator/itir_to_tasklet.py | 35 +++++++++----- tests/next_tests/exclusion_matrices.py | 12 +++-- 3 files changed, 67 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 18e257d462..1c1bed9c5e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -11,8 +11,8 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any, Mapping, Sequence +import hashlib +from typing import Any, Mapping, Optional, Sequence import dace import numpy as np @@ -20,12 +20,12 @@ from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir -from gt4py.next.common import Domain, UnitRange, is_field -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.common import Dimension, Domain, UnitRange, is_field +from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.processor_interface import program_executor -from gt4py.next.type_system import type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims @@ -111,12 +111,34 @@ def get_stride_args( return stride_args -_build_cache_cpu: dict[int, CompiledSDFG] = {} -_build_cache_gpu: dict[int, CompiledSDFG] = {} - - -def get_cache_id(*cache_args) -> int: - return sum([hash(str(arg)) for arg in cache_args]) +_build_cache_cpu: dict[str, CompiledSDFG] = {} +_build_cache_gpu: dict[str, CompiledSDFG] = {} + + +def get_cache_id( + program: itir.FencilDefinition, + arg_types: Sequence[ts.TypeSpec], + column_axis: Optional[Dimension], + offset_provider: Mapping[str, Any], +) -> str: + max_neighbors = [ + (k, v.max_neighbors) + for k, v in offset_provider.items() + if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)) + ] + cache_id_args = [ + str(arg) + for arg in ( + program, + *arg_types, + column_axis, + *max_neighbors, + ) + ] + m = hashlib.sha256() + for s in cache_id_args: + m.update(s.encode()) + return m.hexdigest() @program_executor @@ -133,7 +155,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: arg_types = [type_translation.from_value(arg) for arg in args] neighbor_tables = filter_neighbor_tables(offset_provider) - cache_id = get_cache_id(program, *arg_types, column_axis) + cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d3bfb5ff0e..6acc39c50a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -23,7 +23,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen -from gt4py.next import Dimension, type_inference as next_typing +from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import FunCall, Lambda @@ -700,18 +700,31 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: element = tail[1].value assert isinstance(element, int) - table: NeighborTableOffsetProvider = self.offset_provider[offset] - shifted_dim = table.origin_axis.value - target_dim = table.neighbor_axis.value + if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): + table = self.offset_provider[offset] + shifted_dim = table.origin_axis.value + target_dim = table.neighbor_axis.value - conn = self.context.state.add_access(connectivity_identifier(offset)) + conn = self.context.state.add_access(connectivity_identifier(offset)) + + args = [ + ValueExpr(conn, table.table.dtype), + ValueExpr(iterator.indices[shifted_dim], dace.int64), + ] + + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{internals[1]}, {element}]" + else: + offset_provider = self.offset_provider[offset] + assert isinstance(offset_provider, StridedNeighborOffsetProvider) + + shifted_dim = offset_provider.origin_axis.value + target_dim = offset_provider.neighbor_axis.value + offset_value = iterator.indices[shifted_dim] + args = [ValueExpr(offset_value, dace.int64)] + internals = [f"{offset_value.data}_v"] + expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}" - args = [ - ValueExpr(conn, table.table.dtype), - ValueExpr(iterator.indices[shifted_dim], dace.int64), - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {element}]" shifted_value = self.add_expr_tasklet( list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" )[0].value diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 27ccb29095..98ac9352c3 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -61,7 +61,6 @@ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ] #: Skip matrix, contains for each backend processor a list of tuples with following fields: @@ -81,11 +80,18 @@ (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ], - GTFN_CPU: GTFN_SKIP_TEST_LIST, - GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST, + GTFN_CPU: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], + GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], GTFN_FORMAT_SOURCECODE: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), From f96ead5fbb0b7d9edfbe6c6e9c67a08ddb2ee50a Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 17 Oct 2023 14:19:33 +0200 Subject: [PATCH 23/67] fix[next]: DaCe field addressing in builtin_neighbors (#1349) Bugfix in DaCe backend to make field addressing in builtin_neighbors consistent with the canonical representation (field dimensions alphabetically sorted). --- .../runners/dace_iterator/itir_to_tasklet.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 6acc39c50a..610698646a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -243,10 +243,8 @@ def builtin_neighbors( ) # select full shape only in the neighbor-axis dimension field_subset = [ - f"0:{sdfg.arrays[iterator.field.data].shape[idx]}" - if dim == table.neighbor_axis.value - else f"i_{dim}" - for idx, dim in enumerate(iterator.dimensions) + f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) ] state.add_memlet_path( iterator.field, @@ -575,6 +573,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return iterator args: list[ValueExpr] + sorted_dims = sorted(iterator.dimensions) if self.context.reduce_limit: # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing result_name = unique_var_name() @@ -596,7 +595,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # if dim is not found in iterator indices, we take the neighbor index over the reduction domain flat_index = [ f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name - for dim in sorted(iterator.dimensions) + for dim in sorted_dims ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices @@ -629,11 +628,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] else: - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) - flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims ] - args = [ValueExpr(iterator.field, iterator.dtype), *flat_index] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") From d11246ec828acd6b904dadb80ac535ebd21b5359 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 20 Oct 2023 10:37:57 +0200 Subject: [PATCH 24/67] feat[next]: DaCe support for tuple returns (#1343) This PR adds support in DaCe backends for closures with tuple returns. The motivation is that tuple returns are used in icon4py stencils, although the internal expressions do not operate on tuples. Tuples are just a mean to aggregate multiple-outputs from one stencil. For that reason, this PR does not contain support for scan or conditional expressions with tuples. --- .../runners/dace_iterator/itir_to_sdfg.py | 200 +++++++++--------- .../runners/dace_iterator/itir_to_tasklet.py | 45 ++-- .../runners/dace_iterator/utility.py | 10 +- .../ffront_tests/test_execution.py | 5 +- .../ffront_tests/test_program.py | 3 - .../iterator_tests/test_tuple.py | 6 +- .../iterator_tests/test_column_stencil.py | 2 +- 7 files changed, 136 insertions(+), 135 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 7017815688..580486aa4a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,6 +38,7 @@ create_memlet_at, create_memlet_full, filter_neighbor_tables, + flatten_list, get_sorted_dims, map_nested_sdfg_symbols, unique_var_name, @@ -124,6 +125,13 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset raise NotImplementedError() self.storage_types[name] = type_ + def get_output_nodes( + self, closure: itir.StencilClosure, context: Context + ) -> dict[str, dace.nodes.AccessNode]: + translator = PythonTaskletCodegen(self.offset_provider, context, self.node_types) + output_nodes = flatten_list(translator.visit(closure.output)) + return {node.value.data: node.value for node in output_nodes} + def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) last_state = program_sdfg.add_state("program_entry") @@ -145,50 +153,29 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # Create a nested SDFG for all stencil closures. for closure in node.closures: - assert isinstance(closure.output, itir.SymRef) - - # filter out arguments with scalar type, because they are passed as symbols - input_names = [ - str(inp.id) - for inp in closure.inputs - if isinstance(self.storage_types[inp.id], ts.FieldType) - ] - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_names = [str(closure.output.id)] - # Translate the closure and its stencil's body to an SDFG. - closure_sdfg = self.visit(closure, array_table=program_sdfg.arrays) + closure_sdfg, input_names, output_names = self.visit( + closure, array_table=program_sdfg.arrays + ) # Create a new state for the closure. last_state = program_sdfg.add_state_after(last_state) # Create memlets to transfer the program parameters - input_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names - ] - connectivity_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in connectivity_names - ] - output_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names - ] - - input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) + input_mapping = { + name: create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names } output_mapping = { - param: arg_memlet for param, arg_memlet in zip(output_names, output_memlets) + name: create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, array_mapping) + symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping) # Insert the closure's SDFG as a nested SDFG of the program. nsdfg_node = last_state.add_nested_sdfg( sdfg=closure_sdfg, parent=program_sdfg, - inputs=set(input_names) | set(connectivity_names), + inputs=set(input_names), outputs=set(output_names), symbol_mapping=symbol_mapping, ) @@ -198,49 +185,78 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): access_node = last_state.add_access(inner_name) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - for inner_name, memlet in connectivity_mapping.items(): - access_node = last_state.add_access(inner_name) - last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - for inner_name, memlet in output_mapping.items(): access_node = last_state.add_access(inner_name) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) + program_sdfg.validate() return program_sdfg def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] - ) -> dace.SDFG: + ) -> tuple[dace.SDFG, list[str], list[str]]: assert ItirToSDFG._check_no_lifts(node) assert ItirToSDFG._check_shift_offsets_are_literals(node) - assert isinstance(node.output, itir.SymRef) - - neighbor_tables = filter_neighbor_tables(self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_name = str(node.output.id) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") - # Add DaCe arrays for inputs, output and connectivities to closure SDFG. - for name in [*input_names, *conn_names, output_name]: - assert name not in closure_sdfg.arrays or (name in input_names and name == output_name) + program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) + neighbor_tables = filter_neighbor_tables(self.offset_provider) + + input_names = [str(inp.id) for inp in node.inputs] + conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + + output_nodes = self.get_output_nodes(node, closure_ctx) + output_names = [k for k, _ in output_nodes.items()] + + # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. + input_transients_mapping = {} + for name in [*input_names, *conn_names, *output_names]: if name in closure_sdfg.arrays: - # in/out parameter, container already added for in parameter - continue - if isinstance(self.storage_types[name], ts.FieldType): + assert name in input_names and name in output_names + # In case of closures with in/out fields, there is risk of race condition + # between read/write access nodes in the (asynchronous) map tasklet. + transient_name = unique_var_name() + closure_sdfg.add_array( + transient_name, + shape=array_table[name].shape, + strides=array_table[name].strides, + dtype=array_table[name].dtype, + transient=True, + ) + closure_init_state.add_nedge( + closure_init_state.add_access(name), + closure_init_state.add_access(transient_name), + create_memlet_full(name, closure_sdfg.arrays[name]), + ) + input_transients_mapping[name] = transient_name + elif isinstance(self.storage_types[name], ts.FieldType): closure_sdfg.add_array( name, shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, ) + else: + assert isinstance(self.storage_types[name], ts.ScalarType) - # Get output domain of the closure - program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + input_field_names = [ + input_name + for input_name in input_names + if isinstance(self.storage_types[input_name], ts.FieldType) + ] + + # Closure outputs should all be fields + assert all( + isinstance(self.storage_types[output_name], ts.FieldType) + for output_name in output_names + ) + + # Update symbol table and get output domain of the closure for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): if name in input_names: @@ -258,73 +274,64 @@ def visit_StencilClosure( program_arg_syms[name] = value else: program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_)) - domain_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, domain_ctx) + closure_domain = self._visit_domain(node.domain, closure_ctx) # Map SDFG tasklet arguments to parameters input_access_names = [ - input_name - if isinstance(self.storage_types[input_name], ts.FieldType) + input_transients_mapping[input_name] + if input_name in input_transients_mapping + else input_name + if input_name in input_field_names else cast(ValueExpr, program_arg_syms[input_name]).value.data for input_name in input_names ] input_memlets = [ create_memlet_full(name, closure_sdfg.arrays[name]) for name in input_access_names ] - conn_memlet = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] + conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] - transient_to_arg_name_mapping = {} # create and write to transient that is then copied back to actual output array to avoid aliasing of # same memory in nested SDFG with different names - nsdfg_output_name = unique_var_name() - output_descriptor = closure_sdfg.arrays[output_name] - transient_to_arg_name_mapping[nsdfg_output_name] = output_name + output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} # scan operator should always be the first function call in a closure if is_scan(node.stencil): + assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" + transient_name, output_name = next(iter(output_connectors_mapping.items())) + nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( - node, closure_sdfg.arrays, closure_domain, nsdfg_output_name + node, closure_sdfg.arrays, closure_domain, transient_name ) - results = [nsdfg_output_name] + results = [transient_name] _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] output_subset = f"{scan_lb.value}:{scan_ub.value}" - closure_sdfg.add_array( - nsdfg_output_name, - dtype=output_descriptor.dtype, - shape=(output_descriptor.shape[scan_dim_index],), - strides=(output_descriptor.strides[scan_dim_index],), - transient=True, - ) - - output_memlet = create_memlet_at( - output_name, - tuple( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{output_descriptor.shape[scan_dim_index]}" - for dim, _ in closure_domain - ), - ) + output_memlets = [ + create_memlet_at( + output_name, + tuple( + f"i_{dim}" + if f"i_{dim}" in map_ranges + else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" + for dim, _ in closure_domain + ), + ) + ] else: nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) - assert len(results) == 1 output_subset = "0" - closure_sdfg.add_scalar( - nsdfg_output_name, - dtype=output_descriptor.dtype, - transient=True, - ) - - output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) + output_memlets = [ + create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) + for output_name in output_connectors_mapping.values() + ] input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} - conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlet)} + output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)} + conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlets)} array_mapping = {**input_mapping, **conn_mapping} symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) @@ -336,11 +343,12 @@ def visit_StencilClosure( inputs=array_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, + output_nodes=output_nodes, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): memlet = edge.data - if memlet.data not in transient_to_arg_name_mapping: + if memlet.data not in output_connectors_mapping: continue transient_access = closure_state.add_access(memlet.data) closure_state.add_edge( @@ -355,21 +363,9 @@ def visit_StencilClosure( ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) - access_nodes[memlet.data].data = transient_to_arg_name_mapping[memlet.data] - - for _, (lb, ub) in closure_domain: - for b in lb, ub: - if isinstance(b, SymbolExpr): - continue - map_entry.add_in_connector(b.value.data) - closure_state.add_edge( - b.value, - None, - map_entry, - b.value.data, - dace.Memlet.simple(b.value.data, "0"), - ) - return closure_sdfg + access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] + + return closure_sdfg, input_field_names + conn_names, output_names def _visit_scan_stencil_closure( self, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 610698646a..b28703feef 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -36,6 +36,7 @@ connectivity_identifier, create_memlet_full, filter_neighbor_tables, + flatten_list, map_nested_sdfg_symbols, unique_name, unique_var_name, @@ -423,32 +424,36 @@ def visit_Lambda( context.body.add_array(name, shape=shape, strides=strides, dtype=dtype) # Translate the function's body - result: ValueExpr | SymbolExpr = self.visit(node.expr)[0] - # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - if isinstance(result, ValueExpr): - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result.dtype, transient=True) - result_access = self.context.state.add_access(result_name) - self.context.state.add_edge( - result.value, - None, - result_access, - None, - # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), - ) - result = ValueExpr(value=result_access, dtype=result.dtype) - else: - result = self.add_expr_tasklet([], result.value, result.dtype, "forward")[0] - self.context.body.arrays[result.value.data].transient = False - self.context = prev_context + results: list[ValueExpr] = [] + # We are flattening the returned list of value expressions because the multiple outputs of a lamda + # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. + for expr in flatten_list(self.visit(node.expr)): + if isinstance(expr, ValueExpr): + result_name = unique_var_name() + self.context.body.add_scalar(result_name, expr.dtype, transient=True) + result_access = self.context.state.add_access(result_name) + self.context.state.add_edge( + expr.value, + None, + result_access, + None, + # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution + dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), + ) + result = ValueExpr(value=result_access, dtype=expr.dtype) + else: + # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors + result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + self.context.body.arrays[result.value.data].transient = False + results.append(result) + self.context = prev_context for node in context.state.nodes(): if isinstance(node, dace.nodes.AccessNode): if context.state.out_degree(node) == 0 and context.state.in_degree(node) == 0: context.state.remove_node(node) - return context, inputs, [result] + return context, inputs, results def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: if node.id not in self.context.symbol_map: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 7e6fe13ac7..1fdd022a49 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -11,7 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - +import itertools from typing import Any, Sequence import dace @@ -166,3 +166,11 @@ def unique_name(prefix): def unique_var_name(): return unique_name("__var") + + +def flatten_list(node_list: list[Any]) -> list[Any]: + return list( + itertools.chain.from_iterable( + [flatten_list(e) if e.__class__ == list else [e] for e in node_list] + ) + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 865950eeab..61b34460ef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -159,7 +159,6 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=a.ndarray[1:] + b.ndarray[2:]) -@pytest.mark.uses_tuple_returns def test_tuples(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IJKFloatField, b: cases.IJKFloatField) -> cases.IJKFloatField: @@ -400,7 +399,6 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD assert np.allclose(out, ref) -@pytest.mark.uses_tuple_returns def test_nested_tuple_return(cartesian_case): @gtx.field_operator def pack_tuple( @@ -476,7 +474,7 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField ) -@pytest.mark.uses_tuple_returns +@pytest.mark.uses_constant_fields def test_tuple_with_local_field_in_reduction_shifted(unstructured_case): @gtx.field_operator def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: @@ -840,7 +838,6 @@ def program_domain( ) -@pytest.mark.uses_tuple_returns def test_domain_tuple(cartesian_case): @gtx.field_operator def fieldop_domain_tuple( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index f489126fa7..d86bc21679 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -128,7 +128,6 @@ def fo_from_fo_program(in_field: cases.IFloatField, out: cases.IFloatField): ) -@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside(cartesian_case): @gtx.field_operator def pack_tuple( @@ -155,7 +154,6 @@ def prog( assert np.allclose((a, b), (out_a, out_b)) -@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): @gtx.field_operator def pack_tuple( @@ -183,7 +181,6 @@ def prog( assert out_a[0] == 0 and out_b[0] == 0 -@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_nested(cartesian_case): @gtx.field_operator def pack_tuple( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index bd5a717bb2..67b439507c 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -148,7 +148,6 @@ def stencil(inp1, inp2, inp3, inp4): "stencil", [tuple_output1, tuple_output2], ) -@pytest.mark.uses_tuple_returns def test_tuple_of_field_output_constructed_inside(program_processor, stencil): program_processor, validate = program_processor @@ -194,7 +193,6 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): assert np.allclose(inp2, out2) -@pytest.mark.uses_tuple_returns def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): program_processor, validate = program_processor @@ -288,7 +286,7 @@ def tuple_input(inp): return tuple_get(0, inp_deref) + tuple_get(1, inp_deref) -@pytest.mark.uses_tuple_returns +@pytest.mark.uses_tuple_args def test_tuple_field_input(program_processor): program_processor, validate = program_processor @@ -348,7 +346,7 @@ def tuple_tuple_input(inp): ) -@pytest.mark.uses_tuple_returns +@pytest.mark.uses_tuple_args def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 41d6c8f0f9..04cf8c6f9c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -149,7 +149,7 @@ def k_level_condition_upper_tuple(k_idx, k_level): ), ], ) -@pytest.mark.uses_tuple_returns +@pytest.mark.uses_tuple_args def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor From af7ff8abd2588e79116faeeb212249c673a779a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 24 Oct 2023 10:53:42 +0200 Subject: [PATCH 25/67] feature[next] GPU backend from Python (#1325) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add support for gtfn cuda backend * reconcile new code with type hints by relaxing type hints. * add ADR and todos for library x builsys matrix design * add cpu/gpu tox variants for next, update ci files --------- Co-authored-by: Rico Häuselmann --- .github/workflows/test-next.yml | 10 +- ci/cscs-ci.yml | 1 + .../ADRs/0009-Compiled-Backend-Integration.md | 2 +- ...016-Multiple-Backends-and-Build-Systems.md | 118 ++++++++++++++++++ src/gt4py/next/iterator/embedded.py | 8 +- src/gt4py/next/otf/binding/nanobind.py | 18 +-- .../otf/compilation/build_systems/cmake.py | 19 +-- .../compilation/build_systems/cmake_lists.py | 35 ++++-- .../compilation/build_systems/compiledb.py | 69 +++++----- src/gt4py/next/otf/compilation/compiler.py | 2 +- src/gt4py/next/otf/languages.py | 10 +- src/gt4py/next/otf/recipes.py | 18 +-- src/gt4py/next/otf/step_types.py | 5 +- .../codegens/gtfn/codegen.py | 15 +++ .../codegens/gtfn/gtfn_module.py | 107 +++++++++++++--- .../otf_compile_executor.py | 8 +- .../runners/{gtfn_cpu.py => gtfn.py} | 52 +++++--- .../feature_tests/ffront_tests/__init__.py | 13 ++ .../ffront_tests/ffront_test_utils.py | 10 +- .../ffront_tests/test_arg_call_interface.py | 15 +-- .../ffront_tests/test_execution.py | 25 ++-- .../ffront_tests/test_gpu_backend.py | 43 +++++++ .../ffront_tests/test_gt4py_builtins.py | 11 +- .../ffront_tests/test_math_unary_builtins.py | 8 +- .../iterator_tests/test_builtins.py | 2 +- .../ffront_tests/test_icon_like_scan.py | 14 +-- .../iterator_tests/test_anton_toy.py | 8 +- .../iterator_tests/test_fvm_nabla.py | 3 +- .../iterator_tests/test_hdiff.py | 8 +- .../iterator_tests/test_vertical_advection.py | 16 ++- .../test_with_toy_connectivity.py | 10 +- .../otf_tests/test_gtfn_workflow.py | 4 +- tests/next_tests/unit_tests/conftest.py | 12 +- .../build_systems_tests/conftest.py | 2 +- .../gtfn_tests/test_gtfn_module.py | 4 +- tox.ini | 8 +- 36 files changed, 507 insertions(+), 206 deletions(-) create mode 100644 docs/development/ADRs/0016-Multiple-Backends-and-Build-Systems.md rename src/gt4py/next/program_processors/runners/{gtfn_cpu.py => gtfn.py} (76%) create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/__init__.py create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 5baeb6acef..52f8c25386 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -57,13 +57,13 @@ jobs: run: | pyversion=${{ matrix.python-version }} pyversion_no_dot=${pyversion//./} - tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-env-factor }} - # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}.json + tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-env-factor }}-cpu + # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json # - name: Upload coverage.json artifact # uses: actions/upload-artifact@v3 # with: - # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }} - # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}.json + # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu + # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json # - name: Gather info # run: | # echo ${{ github.ref_type }} >> info.txt @@ -76,5 +76,5 @@ jobs: # - name: Upload info artifact # uses: actions/upload-artifact@v3 # with: - # name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }} + # name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu # path: info.txt diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 3dc38bcd97..971a3cfc35 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -117,3 +117,4 @@ test py310: - SUBPACKAGE: eve - SUBPACKAGE: next VARIANT: [-nomesh, -atlas] + SUBVARIANT: [-cuda11x, -cpu] diff --git a/docs/development/ADRs/0009-Compiled-Backend-Integration.md b/docs/development/ADRs/0009-Compiled-Backend-Integration.md index 273f954438..27c2f0c73c 100644 --- a/docs/development/ADRs/0009-Compiled-Backend-Integration.md +++ b/docs/development/ADRs/0009-Compiled-Backend-Integration.md @@ -159,7 +159,7 @@ Compiled backends may generate code which depends on libraries and tools written 1. can be installed with `pip` (from `PyPI` or another source) automatically. 2. can not be installed with `pip` and not commonly found on HPC machines. -3. libraries and tools which are left to the user to install and make discoverable: `pybind11`, C++ compilers +3. libraries and tools which are left to the user to install and make discoverable: `boost`, C++ compilers Category 1 are made dependencies of `GT4Py`. Examples include `pybind11`, `cmake`, `ninja`. diff --git a/docs/development/ADRs/0016-Multiple-Backends-and-Build-Systems.md b/docs/development/ADRs/0016-Multiple-Backends-and-Build-Systems.md new file mode 100644 index 0000000000..ac84903514 --- /dev/null +++ b/docs/development/ADRs/0016-Multiple-Backends-and-Build-Systems.md @@ -0,0 +1,118 @@ +--- +tags: [backend, gridtools, bindings, libraries, otf] +--- + +# Support for Multiple Backends, Build Systems and Libraries + +- **Status**: valid +- **Authors**: Rico Häuselmann (@DropD) +- **Created**: 2023-10-11 +- **Updated**: 2023-10-11 + +In the process of enabling CUDA for the GTFN backend, we encountered a potential support matrix of build systems x target language libraries. The current design requires build systems about all the libraries they can be used with. We decided that the matrix is too small for now and to not revisit the existing design yet. + +## Context + +ADRs [0009](0009-Compiled_Backend_Integration.md), [0011](0011-On_The_Fly_Compilation.md) and [0012](0012-GridTools_Cpp_OTF_Steps.md) detail the design decisions around what is loosely referred as "gt4py.next backends". In summary the goals are: + +- extensibility + - adding backends should not require changing existing code + - adding / modifying backend modules like build systems / compilers should not be blocked by assumptions in other modules. +- modularity + - increase the chance that two different backends (for example GTFN and another C++ backend) can share code. + +Therefore the concerns of generating code in the target language, generating python bindings in the target language and of building (compiling) the generated code are separated it code generator, bindings generator and compile step / build system. The compile step is written to be build system agnostic. + +There is one category that connects all these concerns: libraries written in the target language and used in generated / bindings code. + +Current design: + +```mermaid +graph LR + +gtgen("GTFN code generator (C++/Cuda)") --> |GridTools::fn_naive| Compiler +gtgen("GTFN code generator (C++/Cuda)") --> |GridTools::fn_gpu| Compiler +nb("nanobind bindings generator") --> |nanobind| Compiler +Compiler --> CMakeProject --> CMakeListsGenerator +Compiler --> CompiledbProject --> CMakeListsGenerator +``` + +The current design contains two mappings: + +- library name -> CMake `find_package()` call +- library name -> CMake target name + +and the gridtools cpu/gpu link targets are differentiated by internally separating between two fictitious "gridtools_cpu" and "gridtools_gpu" libraries. + +## concerns + +### Usage + +The "gridtools_cpu" and "gridtools_gpu" fake library names add to the learning curve for this part of the code. Reuse of the existing components might require this knowledge. + +### Scalability + +Adding a new backend using the existing build systems but relying on different libraries has to modify existing build system components (at the very least CMakeListsGenerator). + +### Separation of concerns + +It makes more sense to separate the concerns of how to generate a valid build system configuration and how to use a particular library in a particular build system than to mix the two. + +## Decision + +Currently the code overhead is in the tens of lines, and there are no concrete plans to add more compiled backends or different build systems. Therefore we decide to keep the current design for now but to redesign as soon as the matrix grows. +To this end ToDo comments are added in the relevant places + +## Consequences + +Initial GTFN gpu support will not be blocked by design work. + +## Alternatives Considered + +### Push build system support to the LibraryDependency instance + +``` +#src/gt4py/next/otf/binding/interface.py + +... +class LibraryDependency: + name: str + version: str + link_targets: list[str] + include_headers: list[str] +``` + +- Simple, choice is made at code generator level, where the knowledge should be +- Interface might not suit every build system +- Up to the implementer to make the logic for choosing reusable (or not) + +### Create additional data structures to properly separate concerns + +``` +class BuildSystemConfig: + device_type: core_defs.DeviceType + ... + + +class LibraryAdaptor: + library: LibraryDependency + build_system: CMakeProject + + def config_phase(self, config: BuildSystemConfig) -> str: + import gridtools_cpp + cmake_dir = gridtools_cpp.get_cmake_dir() + + return f"find_package(... {cmake_dir} ... )" + +def build_phase(self, config: BuildSystemConfig) -> str: + return "" # header only library + +def link_phase(self, main_target_name: str, config: BuildSystemConfig) -> str: + return f"target_link_libraries({main_target_name} ...)" +``` + +- More general and fully extensible, adaptors can be added for any required library / build system combination without touching existing code (depending on the registering mechanism). +- More likely to be reusable as choices are explicit and can be overridden separately by sub classing. +- More design work required. Open questions: + - Design the interface to work with any build system + - How to register adaptors? entry points? global dictionary? diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 0edea35cf5..3d159eaae7 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -685,7 +685,7 @@ def _single_vertical_idx( indices: NamedFieldIndices, column_axis: Tag, column_index: common.IntIndex ) -> NamedFieldIndices: transformed = { - axis: (index if axis != column_axis else index.start + column_index) # type: ignore[union-attr] # trust me, `index` is range in case of `column_axis` + axis: (index if axis != column_axis else index.start + column_index) # type: ignore[union-attr] # trust me, `index` is range in case of `column_axis` # fmt: off for axis, index in indices.items() } return transformed @@ -1050,7 +1050,7 @@ def __gt_origin__(self) -> tuple[int, ...]: return (0,) @classmethod - def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype + def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype # fmt: off raise NotImplementedError() @property @@ -1070,7 +1070,7 @@ def remap(self, index_field: common.Field) -> common.Field: raise NotImplementedError() def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32: - if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code + if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off d, r = item[0] assert d == self._dimension assert isinstance(r, int) @@ -1156,7 +1156,7 @@ def __gt_origin__(self) -> tuple[int, ...]: return tuple() @classmethod - def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype + def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype # fmt: off raise NotImplementedError() @property diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 9dccddc012..5d54512bd0 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Any, Sequence, Union +from typing import Any, Sequence, TypeVar, Union import gt4py.eve as eve from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator @@ -26,6 +26,9 @@ from gt4py.next.type_system import type_info as ti, type_specifications as ts +SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL, covariant=True) + + class Expr(eve.Node): pass @@ -191,8 +194,8 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | CompositeS def create_bindings( - program_source: stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings], -) -> stages.BindingSource[languages.Cpp, languages.Python]: + program_source: stages.ProgramSource[SrcL, languages.LanguageWithHeaderFilesSettings], +) -> stages.BindingSource[SrcL, languages.Python]: """ Generate Python bindings through which a C++ function can be called. @@ -201,7 +204,7 @@ def create_bindings( program_source The program source for which the bindings are created """ - if program_source.language is not languages.Cpp: + if program_source.language not in [languages.Cpp, languages.Cuda]: raise ValueError( f"Can only create bindings for C++ program sources, received {program_source.language}." ) @@ -221,7 +224,6 @@ def create_bindings( "gridtools/common/tuple_util.hpp", "gridtools/fn/unstructured.hpp", "gridtools/fn/cartesian.hpp", - "gridtools/fn/backend/naive.hpp", "gridtools/storage/adapter/nanobind_adapter.hpp", ], wrapper=WrapperFunction( @@ -266,8 +268,6 @@ def create_bindings( @workflow.make_step def bind_source( - inp: stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings], -) -> stages.CompilableSource[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python -]: + inp: stages.ProgramSource[SrcL, languages.LanguageWithHeaderFilesSettings], +) -> stages.CompilableSource[SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python]: return stages.CompilableSource(program_source=inp, binding_source=create_bindings(inp)) diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index b281fde7b5..3d36f5d985 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -38,7 +38,7 @@ def _generate_next_value_(name, start, count, last_values): @dataclasses.dataclass class CMakeFactory( compiler.BuildSystemProjectGenerator[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python + languages.Cpp | languages.Cuda, languages.LanguageWithHeaderFilesSettings, languages.Python ] ): """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" @@ -50,7 +50,7 @@ class CMakeFactory( def __call__( self, source: stages.CompilableSource[ - languages.Cpp, + languages.Cpp | languages.Cuda, languages.LanguageWithHeaderFilesSettings, languages.Python, ], @@ -63,16 +63,21 @@ def __call__( name = source.program_source.entry_point.name header_name = f"{name}.{source.program_source.language_settings.header_extension}" bindings_name = f"{name}_bindings.{source.program_source.language_settings.file_extension}" + cmake_languages = [cmake_lists.Language(name="CXX")] + if source.program_source.language is languages.Cuda: + cmake_languages = [*cmake_languages, cmake_lists.Language(name="CUDA")] + cmake_lists_src = cmake_lists.generate_cmakelists_source( + name, + source.library_deps, + [header_name, bindings_name], + languages=cmake_languages, + ) return CMakeProject( root_path=cache.get_cache_folder(source, cache_strategy), source_files={ header_name: source.program_source.source_code, bindings_name: source.binding_source.source_code, - "CMakeLists.txt": cmake_lists.generate_cmakelists_source( - name, - source.library_deps, - [header_name, bindings_name], - ), + "CMakeLists.txt": cmake_lists_src, }, program_name=name, generator_name=self.cmake_generator_name, diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index ef222341e3..5ea4ba0519 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -30,22 +30,31 @@ class LinkDependency(eve.Node): target: str +class Language(eve.Node): + name: str + + class CMakeListsFile(eve.Node): project_name: str find_deps: Sequence[FindDependency] link_deps: Sequence[LinkDependency] source_names: Sequence[str] bin_output_suffix: str + languages: Sequence[Language] class CMakeListsGenerator(eve.codegen.TemplatedGenerator): CMakeListsFile = as_jinja( """ - project({{project_name}}) cmake_minimum_required(VERSION 3.20.0) + project({{project_name}}) + # Languages - enable_language(CXX) + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 60) + endif() + {{"\\n".join(languages)}} # Paths list(APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}) @@ -77,18 +86,17 @@ class CMakeListsGenerator(eve.codegen.TemplatedGenerator): ) def visit_FindDependency(self, dep: FindDependency): + # TODO(ricoh): do not add more libraries here + # and do not use this design in a new build system. + # Instead, design this to be extensible (refer to ADR-0016). match dep.name: - case "pybind11": - import pybind11 - - return f"find_package(pybind11 CONFIG REQUIRED PATHS {pybind11.get_cmake_dir()} NO_DEFAULT_PATH)" case "nanobind": import nanobind py = "find_package(Python COMPONENTS Interpreter Development REQUIRED)" nb = f"find_package(nanobind CONFIG REQUIRED PATHS {nanobind.cmake_dir()} NO_DEFAULT_PATHS)" return py + "\n" + nb - case "gridtools": + case "gridtools_cpu" | "gridtools_gpu": import gridtools_cpp return f"find_package(GridTools REQUIRED PATHS {gridtools_cpp.get_cmake_dir()} NO_DEFAULT_PATH)" @@ -96,13 +104,16 @@ def visit_FindDependency(self, dep: FindDependency): raise ValueError("Library {name} is not supported".format(name=dep.name)) def visit_LinkDependency(self, dep: LinkDependency): + # TODO(ricoh): do not add more libraries here + # and do not use this design in a new build system. + # Instead, design this to be extensible (refer to ADR-0016). match dep.name: - case "pybind11": - lib_name = "pybind11::module" case "nanobind": lib_name = "nanobind-static" - case "gridtools": + case "gridtools_cpu": lib_name = "GridTools::fn_naive" + case "gridtools_gpu": + lib_name = "GridTools::fn_gpu" case _: raise ValueError("Library {name} is not supported".format(name=dep.name)) @@ -118,11 +129,14 @@ def visit_LinkDependency(self, dep: LinkDependency): lnk = f"target_link_libraries({dep.target} PUBLIC {lib_name})" return cfg + "\n" + lnk + Language = as_jinja("enable_language({{name}})") + def generate_cmakelists_source( project_name: str, dependencies: tuple[interface.LibraryDependency, ...], source_names: Sequence[str], + languages: Sequence[Language] = (Language(name="CXX"),), ) -> str: """ Generate CMakeLists file contents. @@ -135,5 +149,6 @@ def generate_cmakelists_source( link_deps=[LinkDependency(name=d.name, target=project_name) for d in dependencies], source_names=source_names, bin_output_suffix=common.python_module_suffix(), + languages=languages, ) return CMakeListsGenerator.apply(cmakelists_file) diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 34f2f85081..84a69859c0 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -20,7 +20,7 @@ import re import shutil import subprocess -from typing import Optional +from typing import Optional, TypeVar from gt4py.next.otf import languages, stages from gt4py.next.otf.binding import interface @@ -28,10 +28,13 @@ from gt4py.next.otf.compilation.build_systems import cmake, cmake_lists +SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL) + + @dataclasses.dataclass class CompiledbFactory( compiler.BuildSystemProjectGenerator[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python + SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python ] ): """ @@ -48,7 +51,7 @@ class CompiledbFactory( def __call__( self, source: stages.CompilableSource[ - languages.Cpp, + SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python, ], @@ -66,6 +69,8 @@ def __call__( deps=source.library_deps, build_type=self.cmake_build_type, cmake_flags=self.cmake_extra_flags or [], + language=source.program_source.language, + language_settings=source.program_source.language_settings, ) if self.renew_compiledb or not ( @@ -92,9 +97,7 @@ def __call__( @dataclasses.dataclass() class CompiledbProject( - stages.BuildSystemProject[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python - ] + stages.BuildSystemProject[SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python] ): """ Compiledb build system for gt4py programs. @@ -113,18 +116,21 @@ class CompiledbProject( compile_commands_cache: pathlib.Path bindings_file_name: str - def build(self): + def build(self) -> None: self._write_files() - if build_data.read_data(self.root_path).status < build_data.BuildStatus.CONFIGURED: + current_data = build_data.read_data(self.root_path) + if current_data is None or current_data.status < build_data.BuildStatus.CONFIGURED: self._run_config() + current_data = build_data.read_data(self.root_path) # update after config if ( - build_data.BuildStatus.CONFIGURED - <= build_data.read_data(self.root_path).status + current_data is not None + and build_data.BuildStatus.CONFIGURED + <= current_data.status < build_data.BuildStatus.COMPILED ): self._run_build() - def _write_files(self): + def _write_files(self) -> None: def ignore_not_libraries(folder: str, children: list[str]) -> list[str]: pattern = r"((lib.*\.a)|(.*\.lib))" libraries = [child for child in children if re.match(pattern, child)] @@ -151,7 +157,7 @@ def ignore_not_libraries(folder: str, children: list[str]) -> list[str]: path=self.root_path, ) - def _run_config(self): + def _run_config(self) -> None: compile_db = json.loads(self.compile_commands_cache.read_text()) (self.root_path / "build").mkdir(exist_ok=True) @@ -176,7 +182,7 @@ def _run_config(self): self.root_path, ) - def _run_build(self): + def _run_build(self) -> None: logfile = self.root_path / "log_build.txt" compile_db = json.loads((self.root_path / "compile_commands.json").read_text()) assert compile_db @@ -212,19 +218,16 @@ def _cc_prototype_program_source( deps: tuple[interface.LibraryDependency, ...], build_type: cmake.BuildType, cmake_flags: list[str], + language: type[SrcL], + language_settings: languages.LanguageWithHeaderFilesSettings, ) -> stages.ProgramSource: name = _cc_prototype_program_name(deps, build_type.value, cmake_flags) return stages.ProgramSource( entry_point=interface.Function(name=name, parameters=()), source_code="", library_deps=deps, - language=languages.Cpp, - language_settings=languages.LanguageWithHeaderFilesSettings( - formatter_key="", - formatter_style=None, - file_extension="", - header_extension="", - ), + language=language, + language_settings=language_settings, ) @@ -251,16 +254,26 @@ def _cc_create_compiledb( stages.CompilableSource(prototype_program_source, None), cache_strategy ) + header_ext = prototype_program_source.language_settings.header_extension + src_ext = prototype_program_source.language_settings.file_extension + prog_src_name = f"{name}.{header_ext}" + binding_src_name = f"{name}.{src_ext}" + cmake_languages = [cmake_lists.Language(name="CXX")] + if prototype_program_source.language is languages.Cuda: + cmake_languages = [*cmake_languages, cmake_lists.Language(name="CUDA")] + prototype_project = cmake.CMakeProject( generator_name="Ninja", build_type=build_type, extra_cmake_flags=cmake_flags, root_path=cache_path, source_files={ - f"{name}.hpp": "", - f"{name}.cpp": "", + **{name: "" for name in [binding_src_name, prog_src_name]}, "CMakeLists.txt": cmake_lists.generate_cmakelists_source( - name, prototype_program_source.library_deps, [f"{name}.hpp", f"{name}.cpp"] + name, + prototype_program_source.library_deps, + [binding_src_name, prog_src_name], + cmake_languages, ), }, program_name=name, @@ -290,21 +303,21 @@ def _cc_create_compiledb( entry["command"] .replace(f"CMakeFiles/{name}.dir", ".") .replace(str(cache_path), "$SRC_PATH") - .replace(f"{name}.cpp", "$BINDINGS_FILE") - .replace(f"{name}", "$NAME") + .replace(binding_src_name, "$BINDINGS_FILE") + .replace(name, "$NAME") .replace("-I$SRC_PATH/build/_deps", f"-I{cache_path}/build/_deps") ) entry["file"] = ( entry["file"] .replace(f"CMakeFiles/{name}.dir", ".") .replace(str(cache_path), "$SRC_PATH") - .replace(f"{name}.cpp", "$BINDINGS_FILE") + .replace(binding_src_name, "$BINDINGS_FILE") ) entry["output"] = ( entry["output"] .replace(f"CMakeFiles/{name}.dir", ".") - .replace(f"{name}.cpp", "$BINDINGS_FILE") - .replace(f"{name}", "$NAME") + .replace(binding_src_name, "$BINDINGS_FILE") + .replace(name, "$NAME") ) compile_db_path = cache_path / "compile_commands.json" diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 32c5469333..dacb444207 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -23,7 +23,7 @@ from gt4py.next.otf.step_types import LS, SrcL, TgtL -SourceLanguageType = TypeVar("SourceLanguageType", bound=languages.LanguageTag) +SourceLanguageType = TypeVar("SourceLanguageType", bound=languages.NanobindSrcL) LanguageSettingsType = TypeVar("LanguageSettingsType", bound=languages.LanguageSettings) T = TypeVar("T") diff --git a/src/gt4py/next/otf/languages.py b/src/gt4py/next/otf/languages.py index e2738615ac..b0d01d91ab 100644 --- a/src/gt4py/next/otf/languages.py +++ b/src/gt4py/next/otf/languages.py @@ -57,6 +57,14 @@ class Python(LanguageTag): ... -class Cpp(LanguageTag): +class NanobindSrcL(LanguageTag): + ... + + +class Cpp(NanobindSrcL): settings_class = LanguageWithHeaderFilesSettings ... + + +class Cuda(NanobindSrcL): + settings_class = LanguageWithHeaderFilesSettings diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index d144533798..4c6cdc273d 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,27 +14,21 @@ from __future__ import annotations import dataclasses -from typing import Generic, TypeVar -from gt4py.next.otf import languages, stages, step_types, workflow - - -SrcL = TypeVar("SrcL", bound=languages.LanguageTag) -TgtL = TypeVar("TgtL", bound=languages.LanguageTag) -LS = TypeVar("LS", bound=languages.LanguageSettings) +from gt4py.next.otf import stages, step_types, workflow @dataclasses.dataclass(frozen=True) -class OTFCompileWorkflow(workflow.NamedStepSequence, Generic[SrcL, LS, TgtL]): +class OTFCompileWorkflow(workflow.NamedStepSequence): """The typical compiled backend steps composed into a workflow.""" - translation: step_types.TranslationStep[SrcL, LS] + translation: step_types.TranslationStep bindings: workflow.Workflow[ - stages.ProgramSource[SrcL, LS], - stages.CompilableSource[SrcL, LS, TgtL], + stages.ProgramSource, + stages.CompilableSource, ] compilation: workflow.Workflow[ - stages.CompilableSource[SrcL, LS, TgtL], + stages.CompilableSource, stages.CompiledProgram, ] decoration: workflow.Workflow[stages.CompiledProgram, stages.CompiledProgram] diff --git a/src/gt4py/next/otf/step_types.py b/src/gt4py/next/otf/step_types.py index 54fe2e5389..5eeb5c495b 100644 --- a/src/gt4py/next/otf/step_types.py +++ b/src/gt4py/next/otf/step_types.py @@ -50,7 +50,10 @@ def __call__( ... -class CompilationStep(Protocol[SrcL, LS, TgtL]): +class CompilationStep( + workflow.Workflow[stages.CompilableSource[SrcL, LS, TgtL], stages.CompiledProgram], + Protocol[SrcL, LS, TgtL], +): """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" def __call__(self, source: stages.CompilableSource[SrcL, LS, TgtL]) -> stages.CompiledProgram: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 8cd910e40f..645d1f742f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -220,6 +220,7 @@ def visit_FencilDefinition( return self.generic_visit( node, grid_type_str=self._grid_type_str[node.grid_type], + block_sizes=self._block_sizes(node.offset_definitions), **kwargs, ) @@ -261,6 +262,8 @@ def visit_TemporaryAllocation(self, node, **kwargs): ${'\\n'.join(offset_definitions)} ${'\\n'.join(function_definitions)} + ${block_sizes} + inline auto ${id} = [](auto... connectivities__){ return [connectivities__...](auto backend, ${','.join('auto&& ' + p for p in params)}){ auto tmp_alloc__ = gtfn::backend::tmp_allocator(backend); @@ -273,6 +276,18 @@ def visit_TemporaryAllocation(self, node, **kwargs): """ ) + def _block_sizes(self, offset_definitions: list[gtfn_ir.TagDefinition]) -> str: + block_dims = [] + block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2) + for i, tag in enumerate(offset_definitions): + if tag.alias is None: + block_dims.append( + f"gridtools::meta::list<{tag.name.id}_t, " + f"gridtools::integral_constant>" + ) + sizes_str = ",\n".join(block_dims) + return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;" + @classmethod def apply(cls, root: Any, **kwargs: Any) -> str: generated_code = super().apply(root, **kwargs) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 5e24e855b5..7bf310f4e1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -16,10 +16,11 @@ import dataclasses import warnings -from typing import Any, Final, Optional, TypeVar +from typing import Any, Final, Optional import numpy as np +from gt4py._core import definitions as core_defs from gt4py.eve import trees, utils from gt4py.next import common from gt4py.next.common import Connectivity, Dimension @@ -32,8 +33,6 @@ from gt4py.next.type_system import type_specifications as ts, type_translation -T = TypeVar("T") - GENERATED_CONNECTIVITY_PARAM_PREFIX = "gt_conn_" @@ -45,14 +44,30 @@ def get_param_description(name: str, obj: Any) -> interface.Parameter: class GTFNTranslationStep( workflow.ChainableWorkflowMixin[ stages.ProgramCall, - stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings], + stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings], ], - step_types.TranslationStep[languages.Cpp, languages.LanguageWithHeaderFilesSettings], + step_types.TranslationStep[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings], ): - language_settings: languages.LanguageWithHeaderFilesSettings = cpp_interface.CPP_DEFAULT - enable_itir_transforms: bool = True # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 + language_settings: Optional[languages.LanguageWithHeaderFilesSettings] = None + # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 + enable_itir_transforms: bool = True use_imperative_backend: bool = False lift_mode: Optional[LiftMode] = None + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + + def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: + match self.device_type: + case core_defs.DeviceType.CUDA: + return languages.LanguageWithHeaderFilesSettings( + formatter_key=cpp_interface.CPP_DEFAULT.formatter_key, + formatter_style=cpp_interface.CPP_DEFAULT.formatter_style, + file_extension="cu", + header_extension="cuh", + ) + case core_defs.DeviceType.CPU: + return cpp_interface.CPP_DEFAULT + case _: + raise self._not_implemented_for_device_type() def _process_regular_arguments( self, @@ -98,7 +113,7 @@ def _process_regular_arguments( isinstance( dim, fbuiltins.FieldOffset ) # TODO(havogt): remove support for FieldOffset as Dimension - or dim.kind == common.DimensionKind.LOCAL + or dim.kind is common.DimensionKind.LOCAL ): # translate sparse dimensions to tuple dtype dim_name = dim.value @@ -159,7 +174,7 @@ def _process_connectivity_args( def __call__( self, inp: stages.ProgramCall, - ) -> stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings]: + ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" program: itir.FencilDefinition = inp.program @@ -189,7 +204,8 @@ def __call__( # combine into a format that is aligned with what the backend expects parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters - args_expr: list[str] = ["gridtools::fn::backend::naive{}", *regular_args_expr] + backend_arg = self._backend_type() + args_expr: list[str] = [backend_arg, *regular_args_expr] function = interface.Function(program.id, tuple(parameters)) decl_body = ( @@ -205,9 +221,9 @@ def __call__( **inp.kwargs, ) source_code = interface.format_source( - self.language_settings, + self._language_settings(), f""" - #include + #include <{self._backend_header()}> #include #include {stencil_src} @@ -215,16 +231,69 @@ def __call__( """.strip(), ) - module = stages.ProgramSource( + module: stages.ProgramSource[ + languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings + ] = stages.ProgramSource( entry_point=function, - library_deps=(interface.LibraryDependency("gridtools", "master"),), + library_deps=(interface.LibraryDependency(self._library_name(), "master"),), source_code=source_code, - language=languages.Cpp, - language_settings=self.language_settings, + language=self._language(), + language_settings=self._language_settings(), ) return module + def _backend_header(self) -> str: + match self.device_type: + case core_defs.DeviceType.CUDA: + return "gridtools/fn/backend/gpu.hpp" + case core_defs.DeviceType.CPU: + return "gridtools/fn/backend/naive.hpp" + case _: + raise self._not_implemented_for_device_type() + + def _backend_type(self) -> str: + match self.device_type: + case core_defs.DeviceType.CUDA: + return "gridtools::fn::backend::gpu{}" + case core_defs.DeviceType.CPU: + return "gridtools::fn::backend::naive{}" + case _: + raise self._not_implemented_for_device_type() + + def _language(self) -> type[languages.NanobindSrcL]: + match self.device_type: + case core_defs.DeviceType.CUDA: + return languages.Cuda + case core_defs.DeviceType.CPU: + return languages.Cpp + case _: + raise self._not_implemented_for_device_type() + + def _language_settings(self) -> languages.LanguageWithHeaderFilesSettings: + return ( + self.language_settings + if self.language_settings is not None + else self._default_language_settings() + ) + + def _library_name(self) -> str: + match self.device_type: + case core_defs.DeviceType.CUDA: + return "gridtools_gpu" + case core_defs.DeviceType.CPU: + return "gridtools_cpu" + case _: + raise self._not_implemented_for_device_type() + + def _not_implemented_for_device_type(self) -> NotImplementedError: + return NotImplementedError( + f"{self.__class__.__name__} is not implemented for " + f"device type {self.device_type.name}" + ) + + +translate_program_cpu: Final[step_types.TranslationStep] = GTFNTranslationStep() -translate_program: Final[ - step_types.TranslationStep[languages.Cpp, languages.LanguageWithHeaderFilesSettings] -] = GTFNTranslationStep() +translate_program_gpu: Final[step_types.TranslationStep] = GTFNTranslationStep( + device_type=core_defs.DeviceType.CUDA +) diff --git a/src/gt4py/next/program_processors/otf_compile_executor.py b/src/gt4py/next/program_processors/otf_compile_executor.py index a22028414b..cd08c16933 100644 --- a/src/gt4py/next/program_processors/otf_compile_executor.py +++ b/src/gt4py/next/program_processors/otf_compile_executor.py @@ -20,15 +20,15 @@ from gt4py.next.program_processors import processor_interface as ppi -SrcL = TypeVar("SrcL", bound=languages.LanguageTag) +SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL) TgtL = TypeVar("TgtL", bound=languages.LanguageTag) LS = TypeVar("LS", bound=languages.LanguageSettings) HashT = TypeVar("HashT") @dataclasses.dataclass(frozen=True) -class OTFCompileExecutor(ppi.ProgramExecutor, Generic[SrcL, LS, TgtL, HashT]): - otf_workflow: recipes.OTFCompileWorkflow[SrcL, LS, TgtL] +class OTFCompileExecutor(ppi.ProgramExecutor): + otf_workflow: recipes.OTFCompileWorkflow name: Optional[str] = None def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: @@ -42,7 +42,7 @@ def __name__(self) -> str: @dataclasses.dataclass(frozen=True) -class CachedOTFCompileExecutor(ppi.ProgramExecutor, Generic[SrcL, LS, TgtL, HashT]): +class CachedOTFCompileExecutor(ppi.ProgramExecutor, Generic[HashT]): otf_workflow: workflow.CachedStep[stages.ProgramCall, stages.CompiledProgram, HashT] name: Optional[str] = None diff --git a/src/gt4py/next/program_processors/runners/gtfn_cpu.py b/src/gt4py/next/program_processors/runners/gtfn.py similarity index 76% rename from src/gt4py/next/program_processors/runners/gtfn_cpu.py rename to src/gt4py/next/program_processors/runners/gtfn.py index 31b8323474..35c10fe353 100644 --- a/src/gt4py/next/program_processors/runners/gtfn_cpu.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -16,11 +16,12 @@ import numpy.typing as npt +from gt4py._core import definitions as core_defs from gt4py.eve.utils import content_hash from gt4py.next import common from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.otf import languages, recipes, stages, workflow -from gt4py.next.otf.binding import cpp_interface, nanobind +from gt4py.next.otf import languages, recipes, stages, step_types, workflow +from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import cache, compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors import otf_compile_executor @@ -91,11 +92,23 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: ) -GTFN_DEFAULT_TRANSLATION_STEP = gtfn_module.GTFNTranslationStep( - cpp_interface.CPP_DEFAULT, enable_itir_transforms=True, use_imperative_backend=False +GTFN_DEFAULT_TRANSLATION_STEP: step_types.TranslationStep[ + languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings +] = gtfn_module.GTFNTranslationStep( + enable_itir_transforms=True, + use_imperative_backend=False, + device_type=core_defs.DeviceType.CPU, ) -GTFN_DEFAULT_COMPILE_STEP = compiler.Compiler( +GTFN_GPU_TRANSLATION_STEP: step_types.TranslationStep[ + languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings +] = gtfn_module.GTFNTranslationStep( + enable_itir_transforms=True, + use_imperative_backend=False, + device_type=core_defs.DeviceType.CUDA, +) + +GTFN_DEFAULT_COMPILE_STEP: step_types.CompilationStep = compiler.Compiler( cache_strategy=cache.Strategy.SESSION, builder_factory=compiledb.CompiledbFactory() ) @@ -108,30 +121,35 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: ) -run_gtfn = otf_compile_executor.OTFCompileExecutor[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any -](name="run_gtfn", otf_workflow=GTFN_DEFAULT_WORKFLOW) +GTFN_GPU_WORKFLOW = recipes.OTFCompileWorkflow( + translation=GTFN_GPU_TRANSLATION_STEP, + bindings=nanobind.bind_source, + compilation=GTFN_DEFAULT_COMPILE_STEP, + decoration=convert_args, +) + + +run_gtfn = otf_compile_executor.OTFCompileExecutor( + name="run_gtfn", otf_workflow=GTFN_DEFAULT_WORKFLOW +) -run_gtfn_imperative = otf_compile_executor.OTFCompileExecutor[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any -]( +run_gtfn_imperative = otf_compile_executor.OTFCompileExecutor( name="run_gtfn_imperative", otf_workflow=run_gtfn.otf_workflow.replace( translation=run_gtfn.otf_workflow.translation.replace(use_imperative_backend=True), ), ) -run_gtfn_cached = otf_compile_executor.CachedOTFCompileExecutor[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any -]( +run_gtfn_cached = otf_compile_executor.CachedOTFCompileExecutor( name="run_gtfn_cached", otf_workflow=workflow.CachedStep(step=run_gtfn.otf_workflow, hash_function=compilation_hash), ) # todo(ricoh): add API for converting an executor to a cached version of itself and vice versa +run_gtfn_gpu = otf_compile_executor.OTFCompileExecutor( + name="run_gtfn_gpu", otf_workflow=GTFN_GPU_WORKFLOW +) -run_gtfn_with_temporaries = otf_compile_executor.OTFCompileExecutor[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any -]( +run_gtfn_with_temporaries = otf_compile_executor.OTFCompileExecutor( name="run_gtfn_with_temporaries", otf_workflow=run_gtfn.otf_workflow.replace( translation=run_gtfn.otf_workflow.translation.replace(lift_mode=LiftMode.FORCE_TEMPORARIES), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/__init__.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 383716484e..93296ae85f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -21,8 +21,8 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator -from gt4py.next.iterator import embedded, ir as itir -from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip +from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors.runners import gtfn, roundtrip try: @@ -49,9 +49,9 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non @pytest.fixture( params=[ roundtrip.executor, - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ] + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 1402649127..deb1382dfb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -20,22 +20,11 @@ import pytest from gt4py.next import errors -from gt4py.next.common import Field -from gt4py.next.errors.exceptions import TypeError_ from gt4py.next.ffront.decorator import field_operator, program, scan_operator -from gt4py.next.ffront.fbuiltins import broadcast, int32, int64 -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.ffront.fbuiltins import broadcast, int32 from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import ( - IDim, - IField, - IJKField, - IJKFloatField, - JDim, - KDim, - cartesian_case, -) +from next_tests.integration_tests.cases import IDim, IField, IJKFloatField, KDim, cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 61b34460ef..f974e07ad8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -27,13 +27,12 @@ float64, int32, int64, - maximum, minimum, neighbor_sum, where, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -526,12 +525,12 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions def test_solve_triag(cartesian_case): if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: pytest.xfail("Nested `scan`s requires creating temporaries.") - if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: + if cartesian_case.backend == gtfn.run_gtfn_with_temporaries: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) @@ -630,7 +629,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: def test_ternary_scan(cartesian_case): - if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]: + if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=0.0) @@ -653,7 +652,7 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): - if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]: + if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") init = (1, (2, 3)) @@ -690,7 +689,9 @@ def test_scan_nested_tuple_input(cartesian_case): inp2 = gtx.np_as_located_field(KDim)(np.arange(0.0, k_size, 1)) out = gtx.np_as_located_field(KDim)(np.zeros((k_size,))) - prev_levels_iterator = lambda i: range(i + 1) + def prev_levels_iterator(i): + return range(i + 1) + expected = np.asarray( [ reduce(lambda prev, i: prev + inp1[i] + inp2[i], prev_levels_iterator(i), init) @@ -758,9 +759,9 @@ def program_domain(a: cases.IField, out: cases.IField): def test_domain_input_bounds(cartesian_case): if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: pytest.xfail("FloorDiv not fully supported in gtfn.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py new file mode 100644 index 0000000000..290cece3fa --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py @@ -0,0 +1,43 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest + +import gt4py.next as gtx +from gt4py.next.iterator import embedded +from gt4py.next.program_processors.runners import gtfn + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import cartesian_case # noqa: F401 +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 + fieldview_backend, +) + + +@pytest.mark.requires_gpu +@pytest.mark.parametrize("fieldview_backend", [gtfn.run_gtfn_gpu]) +def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures + import cupy as cp # TODO(ricoh): replace with storages solution when available + + @gtx.field_operator(backend=fieldview_backend) + def testee(a: cases.IJKField) -> cases.IJKField: + return a + + inp_arr = cp.full(shape=(3, 4, 5), fill_value=3, dtype=cp.int32) + outp_arr = cp.zeros_like(inp_arr) + inp = embedded.np_as_located_field(cases.IDim, cases.JDim, cases.KDim)(inp_arr) + outp = embedded.np_as_located_field(cases.IDim, cases.JDim, cases.KDim)(outp_arr) + + testee(inp, out=outp, offset_provider={}) + assert cp.allclose(inp_arr, outp_arr) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 0ae874f3a6..56d5e35b3a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -17,8 +17,8 @@ import pytest import gt4py.next as gtx -from gt4py.next import broadcast, float64, int32, int64, max_over, min_over, neighbor_sum, where -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next import broadcast, float64, int32, max_over, min_over, neighbor_sum, where +from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -30,7 +30,6 @@ Joff, KDim, V2EDim, - Vertex, cartesian_case, unstructured_case, ) @@ -47,9 +46,9 @@ ) def test_maxover_execution_(unstructured_case, strategy): if unstructured_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: pytest.xfail("`maxover` broken in gtfn, see #1289.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 85826c1ac0..034ce56fee 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -37,7 +37,7 @@ tanh, trunc, ) -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case @@ -69,9 +69,9 @@ def pow(inp1: cases.IField) -> cases.IField: def test_floordiv(cartesian_case): if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: pytest.xfail( "FloorDiv not yet supported." diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index ca29c5b18b..e2bbbaa553 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -52,7 +52,7 @@ xor_, ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn +from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data from next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 8db9a4c36e..64fb238470 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -18,7 +18,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip +from gt4py.next.program_processors.runners import gtfn, roundtrip from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, @@ -214,9 +214,9 @@ class setup: @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): if fieldview_backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: pytest.xfail("Needs implementation of scan projector.") @@ -234,7 +234,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): - if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: + if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." @@ -256,7 +256,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): - if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: + if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( test_setup.z_alpha, @@ -273,7 +273,7 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): - if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: + if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 16d839a8ab..4e295e92af 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -79,9 +79,9 @@ def test_anton_toy(program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: from gt4py.next.iterator import transforms diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 42de13ef44..445b73548b 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -16,7 +16,7 @@ import pytest -pytest.importorskip("atlas4py") +pytest.importorskip("atlas4py") # isort: skip import gt4py.next as gtx from gt4py.next.iterator import library @@ -37,7 +37,6 @@ ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.iterator.transforms.pass_manager import LiftMode -from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.multi_feature_tests.iterator_tests.fvm_nabla_setup import ( assert_close, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 7bd028b7c3..af70dd590f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import ( @@ -75,9 +75,9 @@ def hdiff(inp, coeff, out, x, y): def test_hdiff(hdiff_reference, program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: # TODO(tehrengruber): check if still true from gt4py.next.iterator import transforms diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index f11046cb5d..a0471e8baa 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -19,10 +19,8 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.program_processors.formatters.gtfn import ( - format_sourcecode as gtfn_format_sourcecode, -) -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters +from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -121,16 +119,16 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): if ( program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - gtfn_format_sourcecode, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, + gtfn_formatters.format_sourcecode, ] and lift_mode == LiftMode.FORCE_INLINE ): pytest.skip("gtfn does only support lifted scans when using temporaries") if ( - program_processor == gtfn_cpu.run_gtfn_with_temporaries + program_processor == gtfn.run_gtfn_with_temporaries or lift_mode == LiftMode.FORCE_TEMPORARIES ): pytest.xfail("tuple_get on columns not supported.") diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 92b93ddb63..d475fab3a8 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -30,15 +30,13 @@ shift, ) from gt4py.next.iterator.runtime import fundef -from gt4py.next.program_processors.formatters import gtfn -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.program_processors.runners import gtfn from next_tests.toy_connectivity import ( C2E, E2V, V2E, V2V, - C2EDim, Cell, E2VDim, Edge, @@ -409,9 +407,9 @@ def shift_sparse_stencil2(inp): def test_shift_sparse_input_field2(program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, + gtfn.run_gtfn, + gtfn.run_gtfn_imperative, + gtfn.run_gtfn_with_temporaries, ]: pytest.xfail( "Bug in bindings/compilation/caching: only the first program seems to be compiled." diff --git a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py index 4e456637cf..c60079eaf1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py @@ -14,7 +14,7 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -37,7 +37,7 @@ def test_different_buffer_sizes(): ) out = gtx.np_as_located_field(IDim, JDim)(np.zeros((out_nx, out_ny), dtype=np.int32)) - @gtx.field_operator(backend=gtfn_cpu.run_gtfn) + @gtx.field_operator(backend=gtfn.run_gtfn) def copy(inp: gtx.Field[[IDim, JDim], gtx.int32]) -> gtx.Field[[IDim, JDim], gtx.int32]: return inp diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 7a62778be1..747431599a 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -22,8 +22,8 @@ from gt4py import eve from gt4py.next.iterator import ir as itir, pretty_parser, pretty_printer, runtime, transforms from gt4py.next.program_processors import processor_interface as ppi -from gt4py.next.program_processors.formatters import gtfn, lisp, type_check -from gt4py.next.program_processors.runners import double_roundtrip, gtfn_cpu, roundtrip +from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters, lisp, type_check +from gt4py.next.program_processors.runners import double_roundtrip, gtfn, roundtrip try: @@ -78,10 +78,10 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str (roundtrip.executor, True), (type_check.check, False), (double_roundtrip.executor, True), - (gtfn_cpu.run_gtfn, True), - (gtfn_cpu.run_gtfn_imperative, True), - (gtfn_cpu.run_gtfn_with_temporaries, True), - (gtfn.format_sourcecode, False), + (gtfn.run_gtfn, True), + (gtfn.run_gtfn_imperative, True), + (gtfn.run_gtfn_with_temporaries, True), + (gtfn_formatters.format_sourcecode, False), ] + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p[0]), diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py index 1fab2643b5..45ef85e37c 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py @@ -78,7 +78,7 @@ def make_program_source(name: str) -> stages.ProgramSource: entry_point=entry_point, source_code=src, library_deps=[ - interface.LibraryDependency("gridtools", "master"), + interface.LibraryDependency("gridtools_cpu", "master"), ], language=languages.Cpp, language_settings=cpp_interface.CPP_DEFAULT, diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 93be884687..ae5f582e47 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -65,9 +65,9 @@ def fencil_example(): def test_codegen(fencil_example): fencil, parameters = fencil_example - module = gtfn_module.translate_program( + module = gtfn_module.translate_program_cpu( stages.ProgramCall(fencil, parameters, {"offset_provider": {}}) ) assert module.entry_point.name == fencil.id - assert any(d.name == "gridtools" for d in module.library_deps) + assert any(d.name == "gridtools_cpu" for d in module.library_deps) assert module.language is languages.Cpp diff --git a/tox.ini b/tox.ini index e16aaff27f..18a6ff8e84 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests python -m pytest --doctest-modules src{/}gt4py{/}eve -[testenv:next-py{310}-{nomesh,atlas}] +[testenv:next-py{310}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.next' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH deps = @@ -81,8 +81,10 @@ set_env = {[testenv]set_env} PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} commands = - nomesh: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas" {posargs} tests{/}next_tests - atlas: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas" {posargs} tests{/}next_tests + nomesh-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and not requires_gpu" {posargs} tests{/}next_tests + nomesh-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests + atlas-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and not requires_gpu" {posargs} tests{/}next_tests + atlas-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests pytest --doctest-modules src{/}gt4py{/}next [testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] From 0650d77b594e89ea092784593d3eaa7559e4fa51 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 30 Oct 2023 14:44:50 +0100 Subject: [PATCH 26/67] feat[next]: Extend DaCe support for offset providers (#1353) Extend support in DaCe backend for offset providers, in order to generate the tasklet code in case of shift expressions with both direct and indirect addressing. Visitors for different types of addressing are merged in one unified visit_shift method. --- .../runners/dace_iterator/itir_to_sdfg.py | 1 - .../runners/dace_iterator/itir_to_tasklet.py | 114 ++++++++---------- 2 files changed, 47 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 580486aa4a..1f9692356e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -196,7 +196,6 @@ def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] ) -> tuple[dace.SDFG, list[str], list[str]]: assert ItirToSDFG._check_no_lifts(node) - assert ItirToSDFG._check_shift_offsets_are_literals(node) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index b28703feef..1634596afa 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -478,17 +478,7 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: return self._visit_deref(node) if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): if node.fun.fun.id == "shift": - offset = node.fun.args[0] - assert isinstance(offset, itir.OffsetLiteral) - offset_name = offset.value - assert isinstance(offset_name, str) - if offset_name not in self.offset_provider: - raise ValueError(f"offset provider for `{offset_name}` is missing") - offset_provider = self.offset_provider[offset_name] - if isinstance(offset_provider, Dimension): - return self._visit_direct_addressing(node) - else: - return self._visit_indirect_addressing(node) + return self._visit_shift(node) elif node.fun.fun.id == "reduce": return self._visit_reduce(node) @@ -653,39 +643,7 @@ def _make_shift_for_rest(self, rest, iterator): fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] ) - def _visit_direct_addressing(self, node: itir.FunCall) -> IteratorExpr: - assert isinstance(node.fun, itir.FunCall) - shift = node.fun - assert isinstance(shift, itir.FunCall) - - tail, rest = self._split_shift_args(shift.args) - if rest: - iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) - else: - iterator = self.visit(node.args[0]) - - assert isinstance(tail[0], itir.OffsetLiteral) - offset = tail[0].value - assert isinstance(offset, str) - shifted_dim = self.offset_provider[offset].value - - assert isinstance(tail[1], itir.OffsetLiteral) - shift_amount = tail[1].value - assert isinstance(shift_amount, int) - - args = [ValueExpr(iterator.indices[shifted_dim], dace.int64)] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]} + {shift_amount}" - shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "dir_addr" - )[0].value - - shifted_index = {dim: value for dim, value in iterator.indices.items()} - shifted_index[shifted_dim] = shifted_value - - return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) - - def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -695,40 +653,48 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: iterator = self.visit(node.args[0]) assert isinstance(tail[0], itir.OffsetLiteral) - offset = tail[0].value - assert isinstance(offset, str) + offset_dim = tail[0].value + assert isinstance(offset_dim, str) + offset_node = self.visit(tail[1])[0] - assert isinstance(tail[1], itir.OffsetLiteral) - element = tail[1].value - assert isinstance(element, int) - - if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): - table = self.offset_provider[offset] - shifted_dim = table.origin_axis.value - target_dim = table.neighbor_axis.value - - conn = self.context.state.add_access(connectivity_identifier(offset)) + if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): + offset_provider = self.offset_provider[offset_dim] + connectivity = self.context.state.add_access(connectivity_identifier(offset_dim)) + shifted_dim = offset_provider.origin_axis.value + target_dim = offset_provider.neighbor_axis.value args = [ - ValueExpr(conn, table.table.dtype), + ValueExpr(connectivity, offset_provider.table.dtype), ValueExpr(iterator.indices[shifted_dim], dace.int64), + offset_node, ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {element}]" - else: - offset_provider = self.offset_provider[offset] - assert isinstance(offset_provider, StridedNeighborOffsetProvider) + expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" + elif isinstance(self.offset_provider[offset_dim], StridedNeighborOffsetProvider): + offset_provider = self.offset_provider[offset_dim] shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value - offset_value = iterator.indices[shifted_dim] - args = [ValueExpr(offset_value, dace.int64)] - internals = [f"{offset_value.data}_v"] - expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}" + args = [ + ValueExpr(iterator.indices[shifted_dim], dace.int64), + offset_node, + ] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}" + else: + assert isinstance(self.offset_provider[offset_dim], Dimension) + + shifted_dim = self.offset_provider[offset_dim].value + target_dim = shifted_dim + args = [ + ValueExpr(iterator.indices[shifted_dim], dace.int64), + offset_node, + ] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" + list(zip(args, internals)), expr, dace.dtypes.int64, "shift" )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -737,6 +703,20 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) + def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: + offset = node.value + assert isinstance(offset, int) + offset_var = unique_var_name() + self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True) + offset_node = self.context.state.add_access(offset_var) + tasklet_node = self.context.state.add_tasklet( + "get_offset", {}, {"__out"}, f"__out = {offset}" + ) + self.context.state.add_edge( + tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") + ) + return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] + def _visit_reduce(self, node: itir.FunCall): result_name = unique_var_name() result_access = self.context.state.add_access(result_name) From 3c463a62a98ef44cd47b52d9752fcb06f2066c49 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 31 Oct 2023 10:53:59 +0100 Subject: [PATCH 27/67] fix[next]: Improvements in DaCe backend (#1354) This PR contains some fixes and code refactoring in DaCe backend: * (refactoring) Use memlet API for full array subset * Fix for gpu execution: import cupy for sorting of field dimensions. * Fix for symbolic analysis of memlet volume: define symbols before visiting the closure domain in order to allow symbolic analysis of memlet volume --- .../runners/dace_iterator/__init__.py | 37 +++++++++++-------- .../runners/dace_iterator/itir_to_sdfg.py | 6 +-- .../runners/dace_iterator/itir_to_tasklet.py | 33 ++++++++--------- .../runners/dace_iterator/utility.py | 6 +-- .../ffront_tests/test_gpu_backend.py | 4 +- 5 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 1c1bed9c5e..be63d6809d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -31,6 +31,12 @@ from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims +try: + import cupy as cp +except ImportError: + cp = None + + def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: sorted_dims = get_sorted_dims(domain.dims) return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] @@ -49,8 +55,11 @@ def convert_arg(arg: Any): sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) dim_indices = [dim_index for dim_index, _ in sorted_dims] - assert isinstance(arg.ndarray, np.ndarray) - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) + if isinstance(arg.ndarray, np.ndarray): + return np.moveaxis(arg.ndarray, range(ndim), dim_indices) + else: + assert cp is not None and isinstance(arg.ndarray, cp.ndarray) + return cp.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg @@ -226,24 +235,22 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: @program_executor -def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: - run_dace_iterator( - program, - *args, - **kwargs, - build_cache=_build_cache_cpu, - build_type=_build_type, - run_on_gpu=False, - ) +def run_dace(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_on_gpu = any(not isinstance(arg.ndarray, np.ndarray) for arg in args if is_field(arg)) + if run_on_gpu: + if cp is None: + raise RuntimeError( + f"Non-numpy field argument passed to program {program.id} but module cupy not installed" + ) + if not all(isinstance(arg.ndarray, cp.ndarray) for arg in args if is_field(arg)): + raise RuntimeError("Execution on GPU requires all fields to be stored as cupy arrays") -@program_executor -def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_iterator( program, *args, **kwargs, - build_cache=_build_cache_gpu, + build_cache=_build_cache_gpu if run_on_gpu else _build_cache_cpu, build_type=_build_type, - run_on_gpu=True, + run_on_gpu=run_on_gpu, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 1f9692356e..9e9cc4bf29 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -258,9 +258,9 @@ def visit_StencilClosure( # Update symbol table and get output domain of the closure for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): + dtype = as_dace_type(type_) + closure_sdfg.add_symbol(name, dtype) if name in input_names: - dtype = as_dace_type(type_) - closure_sdfg.add_symbol(name, dtype) out_name = unique_var_name() closure_sdfg.add_scalar(out_name, dtype, transient=True) out_tasklet = closure_init_state.add_tasklet( @@ -272,7 +272,7 @@ def visit_StencilClosure( closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: - program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_)) + program_arg_syms[name] = SymbolExpr(name, dtype) closure_domain = self._visit_domain(node.domain, closure_ctx) # Map SDFG tasklet arguments to parameters diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 1634596afa..5d47cad909 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -34,6 +34,7 @@ add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, + create_memlet_at, create_memlet_full, filter_neighbor_tables, flatten_list, @@ -199,7 +200,6 @@ def builtin_neighbors( result_access = state.add_access(result_name) table_name = connectivity_identifier(offset_dim) - table_array = sdfg.arrays[table_name] # generate unique map index name to avoid conflict with other maps inside same state index_name = unique_name("__neigh_idx") @@ -225,14 +225,14 @@ def builtin_neighbors( state.add_access(table_name), me, shift_tasklet, - memlet=dace.Memlet(data=table_name, subset=",".join(f"0:{s}" for s in table_array.shape)), + memlet=create_memlet_full(table_name, sdfg.arrays[table_name]), dst_conn="__table", ) state.add_memlet_path( iterator.indices[shifted_dim], me, shift_tasklet, - memlet=dace.Memlet(data=iterator.indices[shifted_dim].data, subset="0"), + memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), dst_conn="__idx", ) state.add_edge( @@ -240,28 +240,25 @@ def builtin_neighbors( "__result", data_access_tasklet, "__idx", - dace.Memlet(data=idx_name, subset="0"), + dace.Memlet.simple(idx_name, "0"), ) # select full shape only in the neighbor-axis dimension - field_subset = [ + field_subset = tuple( f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) - ] + ) state.add_memlet_path( iterator.field, me, data_access_tasklet, - memlet=dace.Memlet( - data=iterator.field.data, - subset=",".join(field_subset), - ), + memlet=create_memlet_at(iterator.field.data, field_subset), dst_conn="__field", ) state.add_memlet_path( data_access_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset=index_name), + memlet=dace.Memlet.simple(result_name, index_name), src_conn="__result", ) @@ -438,7 +435,7 @@ def visit_Lambda( result_access, None, # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), + dace.Memlet.simple(result_access.data, "0", wcr_str=context.reduce_wcr), ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: @@ -616,7 +613,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: deref_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset=index_name), + memlet=dace.Memlet.simple(result_name, index_name), src_conn="__result", ) @@ -738,13 +735,13 @@ def _visit_reduce(self, node: itir.FunCall): assert isinstance(op_name, itir.SymRef) init = node.fun.args[1] - nreduce = self.context.body.arrays[neighbors_expr.value.data].shape[0] + reduce_array_desc = neighbors_expr.value.desc(self.context.body) self.context.body.add_scalar(result_name, result_dtype, transient=True) op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") reduce_tasklet = self.context.state.add_tasklet( "reduce", - code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", + code=f"__result = {init}\nfor __idx in range({reduce_array_desc.shape[0]}):\n __result = {op_str}", inputs={"__values"}, outputs={"__result"}, ) @@ -753,14 +750,14 @@ def _visit_reduce(self, node: itir.FunCall): None, reduce_tasklet, "__values", - dace.Memlet(data=neighbors_expr.value.data, subset=f"0:{nreduce}"), + create_memlet_full(neighbors_expr.value.data, reduce_array_desc), ) self.context.state.add_edge( reduce_tasklet, "__result", result_access, None, - dace.Memlet(data=result_name, subset="0"), + dace.Memlet.simple(result_name, "0"), ) else: assert isinstance(node.fun, itir.FunCall) @@ -973,7 +970,7 @@ def closure_to_tasklet_sdfg( tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") access = state.add_access(name) idx_accesses[dim] = access - state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0")) + state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) for name, ty in inputs: if isinstance(ty, ts.FieldType): ndim = len(ty.dims) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 1fdd022a49..c17a39ef2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -48,14 +48,12 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): - bounds = [(0, size) for size in source_array.shape] - subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet.simple(source_identifier, subset) + return dace.Memlet.from_array(source_identifier, source_array) def create_memlet_at(source_identifier: str, index: tuple[str, ...]): subset = ", ".join(index) - return dace.Memlet(data=source_identifier, subset=subset) + return dace.Memlet.simple(source_identifier, subset) def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py index 290cece3fa..381cc740c5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py @@ -16,7 +16,7 @@ import gt4py.next as gtx from gt4py.next.iterator import embedded -from gt4py.next.program_processors.runners import gtfn +from gt4py.next.program_processors.runners import dace_iterator, gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case # noqa: F401 @@ -26,7 +26,7 @@ @pytest.mark.requires_gpu -@pytest.mark.parametrize("fieldview_backend", [gtfn.run_gtfn_gpu]) +@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace, gtfn.run_gtfn_gpu]) def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures import cupy as cp # TODO(ricoh): replace with storages solution when available From 4d8df6938c04839f427f642e7dcce9ee53f6b149 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Mon, 6 Nov 2023 15:27:55 +0100 Subject: [PATCH 28/67] feat[next]: Limit use of global type inference in CollapseTuple pass (#1355) CollapseTuple configurable such that whether the ITIR type inference is used to get the tuple size or the simple heuristics can be configured using a boolean flag to the pass Execute with ITIR type inference once in loop in pass manager and in all subsequent runs use the simple heuristics --- .../iterator/transforms/collapse_tuple.py | 46 +++++++++++++++---- .../next/iterator/transforms/pass_manager.py | 6 ++- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 86f21072e5..7d710fc919 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -11,18 +11,36 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - from dataclasses import dataclass +from typing import Optional from gt4py import eve from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference -def _get_tuple_size(type_: type_inference.Type) -> int: - assert isinstance(type_, it_type_inference.Val) and isinstance( - type_.dtype, it_type_inference.Tuple - ) +class UnknownLength: + pass + + +def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | type[UnknownLength]: + if node_types: + type_ = node_types[id(elem)] + # global inference should always give a length, function should fail otherwise + assert isinstance(type_, it_type_inference.Val) and isinstance( + type_.dtype, it_type_inference.Tuple + ) + else: + # use local type inference if no global information is available + assert isinstance(elem, ir.Node) + type_ = it_type_inference.infer(elem) + + if not ( + isinstance(type_, it_type_inference.Val) + and isinstance(type_.dtype, it_type_inference.Tuple) + ): + return UnknownLength + return len(type_.dtype) @@ -38,8 +56,8 @@ class CollapseTuple(eve.NodeTranslator): ignore_tuple_size: bool collapse_make_tuple_tuple_get: bool collapse_tuple_get_make_tuple: bool - - _node_types: dict[int, type_inference.Type] + use_global_type_inference: bool + _node_types: Optional[dict[int, type_inference.Type]] = None @classmethod def apply( @@ -50,6 +68,7 @@ def apply( # the following options are mostly for allowing separate testing of the modes collapse_make_tuple_tuple_get: bool = True, collapse_tuple_get_make_tuple: bool = True, + use_global_type_inference: bool = False, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -57,15 +76,22 @@ def apply( If `ignore_tuple_size`, apply the transformation even if length of the inner tuple is greater than the length of the outer tuple. """ - node_types = it_type_inference.infer_all(node) - + node_types = it_type_inference.infer_all(node) if use_global_type_inference else None return cls( ignore_tuple_size, collapse_make_tuple_tuple_get, collapse_tuple_get_make_tuple, + use_global_type_inference, node_types, ).visit(node) + return cls( + ignore_tuple_size, + collapse_make_tuple_tuple_get, + collapse_tuple_get_make_tuple, + use_global_type_inference, + ).visit(node) + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if ( self.collapse_make_tuple_tuple_get @@ -86,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # tuple argument differs, just continue with the rest of the tree return self.generic_visit(node) - if self.ignore_tuple_size or _get_tuple_size(self._node_types[id(first_expr)]) == len( + if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len( node.args ): return first_expr diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 0ff3ec25c7..b0db04eb5f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -106,7 +106,11 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined) + inlined = CollapseTuple.apply( + inlined, + # to limit number of times global type inference is executed, only in the last iterations. + use_global_type_inference=inlined == ir, + ) if inlined == ir: break From 0df592d194bc80c7a06e3b7a916cd00d43498af9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 15 Nov 2023 13:30:12 +0100 Subject: [PATCH 29/67] feat[next] high-level field storage API (#1319) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce user API to allocate fields in `gt4py.next`. Summary of main changes: - Introduce FieldBuffer allocator protocols and implementations - Introduce the concept of Backend as ProgramExecutor & Allocator - Replace np_as_located_field with as_field - Make NdArrayField public - Fixes for _core.definitions typings - Fixes and extensions of eve.extended_typing - Refactor the handling of backends/program processors in the testing infrastructure with string enumerations representing the qualified name of the Python symbol, which can be loaded on demand - Rename some executor symbols and modules - Minor style changes to imports and imported symbols to follow coding guidelines. Open"To Do"s for future PRs: - Add support for `GTFieldInterface` protocol in cartesian and use it instead of `NextGTDimsInterface` protocol in next. - Add support for `aligned_index != None` in `FieldBufferAllocator` implementations - Add support for zero-copy construction of fields in `constructors.as_field()` --------- Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> Co-authored-by: Enrique Gonzalez Paredes Co-authored-by: Rico Häuselmann Co-authored-by: nfarabullini --- .../0008-Mapping_Domain_to_Cpp-Backend.md | 5 +- docs/development/ADRs/Index.md | 1 + docs/user/next/QuickstartGuide.md | 50 ++- src/gt4py/_core/definitions.py | 66 ++-- src/gt4py/eve/codegen.py | 2 +- src/gt4py/eve/extended_typing.py | 62 +++- src/gt4py/eve/utils.py | 56 ++- src/gt4py/next/__init__.py | 7 + src/gt4py/next/allocators.py | 349 ++++++++++++++++++ src/gt4py/next/common.py | 63 +++- src/gt4py/next/constructors.py | 297 +++++++++++++++ src/gt4py/next/embedded/nd_array_field.py | 42 +-- src/gt4py/next/ffront/decorator.py | 18 +- src/gt4py/next/iterator/embedded.py | 7 +- .../program_processors/formatters/gtfn.py | 2 +- .../formatters/pretty_print.py | 26 +- .../formatters/type_check.py | 2 +- .../otf_compile_executor.py | 31 +- .../program_processors/processor_interface.py | 163 ++++++-- .../runners/dace_iterator/__init__.py | 49 ++- .../runners/double_roundtrip.py | 24 +- .../next/program_processors/runners/gtfn.py | 64 +++- .../program_processors/runners/roundtrip.py | 45 ++- src/gt4py/storage/allocators.py | 231 ++++++------ src/gt4py/storage/cartesian/interface.py | 17 +- src/gt4py/storage/cartesian/utils.py | 43 ++- .../unit_tests/test_extended_typing.py | 63 ++++ tests/eve_tests/unit_tests/test_utils.py | 15 + tests/next_tests/__init__.py | 4 + tests/next_tests/exclusion_matrices.py | 78 +++- tests/next_tests/integration_tests/cases.py | 23 +- .../ffront_tests/ffront_test_utils.py | 37 +- .../ffront_tests/test_execution.py | 12 +- .../ffront_tests/test_external_local_field.py | 4 +- .../ffront_tests/test_gpu_backend.py | 18 +- .../ffront_tests/test_gt4py_builtins.py | 2 +- .../test_math_builtin_execution.py | 6 +- .../ffront_tests/test_math_unary_builtins.py | 2 +- .../ffront_tests/test_program.py | 2 +- .../ffront_tests/test_scalar_if.py | 3 +- .../iterator_tests/test_builtins.py | 41 +- .../test_cartesian_offset_provider.py | 4 +- .../iterator_tests/test_conditional.py | 4 +- .../iterator_tests/test_constant.py | 4 +- .../test_horizontal_indirection.py | 18 +- .../iterator_tests/test_implicit_fencil.py | 4 +- .../feature_tests/iterator_tests/test_scan.py | 7 +- .../test_strided_offset_provider.py | 7 +- .../iterator_tests/test_trivial.py | 16 +- .../iterator_tests/test_tuple.py | 99 ++--- .../feature_tests/test_util_cases.py | 8 +- .../ffront_tests/test_icon_like_scan.py | 24 +- .../iterator_tests/test_anton_toy.py | 6 +- .../iterator_tests/test_column_stencil.py | 52 +-- .../iterator_tests/test_fvm_nabla.py | 50 +-- .../iterator_tests/test_hdiff.py | 6 +- .../iterator_tests/test_vertical_advection.py | 4 +- .../test_with_toy_connectivity.py | 56 +-- .../otf_tests/test_gtfn_workflow.py | 6 +- tests/next_tests/unit_tests/conftest.py | 64 ++-- .../embedded_tests/test_nd_array_field.py | 4 +- .../iterator_tests/test_runtime_domain.py | 7 +- .../gtfn_tests/test_gtfn_module.py | 2 +- .../test_processor_interface.py | 53 +++ .../next_tests/unit_tests/test_allocators.py | 193 ++++++++++ .../unit_tests/test_constructors.py | 175 +++++++++ 66 files changed, 2277 insertions(+), 628 deletions(-) create mode 100644 src/gt4py/next/allocators.py create mode 100644 src/gt4py/next/constructors.py create mode 100644 tests/next_tests/unit_tests/test_allocators.py create mode 100644 tests/next_tests/unit_tests/test_constructors.py diff --git a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md index 23b75c6df5..a1ee8575d2 100644 --- a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md +++ b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md @@ -7,10 +7,13 @@ tags: [] - **Status**: valid - **Authors**: Hannes Vogt (@havogt) - **Created**: 2022-06-29 -- **Updated**: 2022-06-29 +- **Updated**: 2023-11-08 This document proposes a (temporary) solution for mapping domain dimensions to field dimensions. +> [!NOTE] +> This ADR was written before the integration of `gt4py.storage` into `gt4py.next`, so the example is using `np_as_located_field` (now deprecated) instead of `gtx.as_field.partial`. The idea conveyed by the example remains unchanged. + ## Context The Python embedded execution for Iterator IR keeps track of the current location type of an iterator (allows safety checks) while the C++ backend does not. diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/Index.md index 09d2273ee9..24272d9cee 100644 --- a/docs/development/ADRs/Index.md +++ b/docs/development/ADRs/Index.md @@ -45,6 +45,7 @@ _None_ - [0006 - C++ Backend](0006-Cpp-Backend.md) - [0007 - Fencil Processors](0007-Fencil-Processors.md) - [0008 - Mapping Domain to Cpp Backend](0008-Mapping_Domain_to_Cpp-Backend.md) +- [0016 - Multiple Backends and Build Systems](0016-Multiple-Backends-and-Build-Systems.md) ### Python Integration diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index bf6466ade6..1ae1db4d92 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -51,7 +51,7 @@ from gt4py.next import float64, neighbor_sum, where #### Fields -Fields store data as a multi-dimensional array, and are defined over a set of named dimensions. The code snippet below defines two named dimensions, _cell_ and _K_, and creates the fields `a` and `b` over their cartesian product using the `np_as_located_field` helper function. The fields contain the values 2 for `a` and 3 for `b` for all entries. +Fields store data as a multi-dimensional array, and are defined over a set of named dimensions. The code snippet below defines two named dimensions, _Cell_ and _K_, and creates the fields `a` and `b` over their cartesian product using the `gtx.as_field` helper function. The fields contain the values 2 for `a` and 3 for `b` for all entries. ```{code-cell} ipython3 CellDim = gtx.Dimension("Cell") @@ -63,8 +63,20 @@ grid_shape = (num_cells, num_layers) a_value = 2.0 b_value = 3.0 -a = gtx.np_as_located_field(CellDim, KDim)(np.full(shape=grid_shape, fill_value=a_value, dtype=np.float64)) -b = gtx.np_as_located_field(CellDim, KDim)(np.full(shape=grid_shape, fill_value=b_value, dtype=np.float64)) +a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=a_value, dtype=np.float64)) +b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=b_value, dtype=np.float64)) +``` + +Additional numpy-equivalent constructors are available, namely `ones`, `zeros`, `empty`, `full`. These require domain, dtype, and allocator (e.g. a backend) specifications. + +```{code-cell} ipython3 +from gt4py._core import definitions as core_defs +array_of_ones_numpy = np.ones((grid_shape[0], grid_shape[1])) +field_of_ones = gtx.constructors.ones( + domain={I: range(grid_shape[0]), J: range(grid_shape[0])}, + dtype=core_defs.dtype(np.float64), + allocator=gtx.program_processors.runners.roundtrip.backend +) ``` _Note: The interface to construct fields is provisional only and will change soon._ @@ -87,7 +99,7 @@ def add(a: gtx.Field[[CellDim, KDim], float64], You can call field operators from [programs](#Programs), other field operators, or directly. The code snippet below shows a direct call, in which case you have to supply two additional arguments: `out`, which is a field to write the return value to, and `offset_provider`, which is left empty for now. The result of the field operator is a field with all entries equal to 5, but for brevity, only the average and the standard deviation of the entries are printed: ```{code-cell} ipython3 -result = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape)) +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) add(a, b, out=result, offset_provider={}) print("{} + {} = {} ± {}".format(a_value, b_value, np.average(np.asarray(result)), np.std(np.asarray(result)))) @@ -113,7 +125,7 @@ def run_add(a : gtx.Field[[CellDim, KDim], float64], You can execute the program by simply calling it: ```{code-cell} ipython3 -result = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape)) +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) run_add(a, b, result, offset_provider={}) print("{} + {} = {} ± {}".format(b_value, (a_value + b_value), np.average(np.asarray(result)), np.std(np.asarray(result)))) @@ -200,8 +212,8 @@ cell_to_edge_table = np.array([ Let's start by defining two fields: one over the cells and another one over the edges. The field over cells serves input for subsequent calculations and is therefore filled up with values, whereas the field over the edges stores the output of the calculations and is therefore left blank. ```{code-cell} ipython3 -cell_values = gtx.np_as_located_field(CellDim)(np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) -edge_values = gtx.np_as_located_field(EdgeDim)(np.zeros((12,))) +cell_values = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) +edge_values = gtx.as_field([EdgeDim], np.zeros((12,))) ``` | ![cell_values](connectivity_cell_field.svg) | @@ -295,8 +307,8 @@ This function takes 3 input arguments: In the case where the true and false branches are either fields or scalars, the resulting output will be a field including all dimensions from all inputs. For example: ```{code-cell} ipython3 -mask = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape, dtype=bool)) -result_where = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape)) +mask = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape, dtype=bool)) +result_where = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) b = 6.0 @gtx.field_operator @@ -313,8 +325,8 @@ print("where return: {}".format(np.asarray(result_where))) The `where` supports the return of tuples of fields. To perform promotion of dimensions and dtype of the output, all arguments are analyzed and promoted as in the above section. ```{code-cell} ipython3 -result_1 = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape)) -result_2 = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape)) +result_1 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +result_2 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) @gtx.field_operator def _conditional_tuple(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: float @@ -338,13 +350,13 @@ The `where` builtin also allows for nesting of tuples. In this scenario, it will and then combine results to match the return type: ```{code-cell} ipython3 -a = gtx.np_as_located_field(CellDim, KDim)(np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) -b = gtx.np_as_located_field(CellDim, KDim)(np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) -c = gtx.np_as_located_field(CellDim, KDim)(np.full(shape=grid_shape, fill_value=4.0, dtype=np.float64)) -d = gtx.np_as_located_field(CellDim, KDim)(np.full(shape=grid_shape, fill_value=5.0, dtype=np.float64)) +a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) +b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) +c = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=4.0, dtype=np.float64)) +d = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=5.0, dtype=np.float64)) -result_1 = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape)) -result_2 = gtx.np_as_located_field(CellDim, KDim)(np.zeros(shape=grid_shape)) +result_1 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +result_2 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) @gtx.field_operator def _conditional_tuple_nested( @@ -402,7 +414,7 @@ edge_weights = np.array([ [0, -1, -1], # cell 5 ], dtype=np.float64) -edge_weight_field = gtx.np_as_located_field(CellDim, C2EDim)(edge_weights) +edge_weight_field = gtx.as_field([CellDim, C2EDim], edge_weights) ``` Now you have everything to implement the pseudo-laplacian. Its field operator requires the cell field and the edge weights as inputs, and outputs a cell field of the same shape as the input. @@ -428,7 +440,7 @@ def run_pseudo_laplacian(cells : gtx.Field[[CellDim], float64], out : gtx.Field[[CellDim], float64]): pseudo_lap(cells, edge_weights, out=out) -result_pseudo_lap = gtx.np_as_located_field(CellDim)(np.zeros(shape=(6,))) +result_pseudo_lap = gtx.as_field([CellDim], np.zeros(shape=(6,))) run_pseudo_laplacian(cell_values, edge_weight_field, diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 059ba6c24c..7b318bc2de 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -25,6 +25,7 @@ import numpy as np import numpy.typing as npt +import gt4py.eve as eve from gt4py.eve.extended_typing import ( TYPE_CHECKING, Any, @@ -71,33 +72,33 @@ float64 = np.float64 BoolScalar: TypeAlias = Union[bool_, bool] -BoolT = TypeVar("BoolT", bound=Union[bool_, bool]) +BoolT = TypeVar("BoolT", bound=BoolScalar) BOOL_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], BoolScalar.__args__) # type: ignore[attr-defined] IntScalar: TypeAlias = Union[int8, int16, int32, int64, int] -IntT = TypeVar("IntT", bound=Union[int8, int16, int32, int64, int]) +IntT = TypeVar("IntT", bound=IntScalar) INT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], IntScalar.__args__) # type: ignore[attr-defined] UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64] -UnsignedIntT = TypeVar("UnsignedIntT", bound=Union[uint8, uint16, uint32, uint64]) +UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar) UINT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], UnsignedIntScalar.__args__) # type: ignore[attr-defined] IntegralScalar: TypeAlias = Union[IntScalar, UnsignedIntScalar] -IntegralT = TypeVar("IntegralT", bound=Union[IntScalar, UnsignedIntScalar]) +IntegralT = TypeVar("IntegralT", bound=IntegralScalar) INTEGRAL_TYPES: Final[Tuple[type, ...]] = (*INT_TYPES, *UINT_TYPES) FloatingScalar: TypeAlias = Union[float32, float64, float] -FloatingT = TypeVar("FloatingT", bound=Union[float32, float64, float]) +FloatingT = TypeVar("FloatingT", bound=FloatingScalar) FLOAT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], FloatingScalar.__args__) # type: ignore[attr-defined] #: Type alias for all scalar types supported by GT4Py Scalar: TypeAlias = Union[BoolScalar, IntegralScalar, FloatingScalar] -ScalarT = TypeVar("ScalarT", bound=Union[BoolScalar, IntegralScalar, FloatingScalar]) +ScalarT = TypeVar("ScalarT", bound=Scalar) SCALAR_TYPES: Final[tuple[type, ...]] = (*BOOL_TYPES, *INTEGRAL_TYPES, *FLOAT_TYPES) @@ -139,7 +140,7 @@ def is_valid_tensor_shape( # -- Data type descriptors -- -class DTypeKind(enum.Enum): +class DTypeKind(eve.StrEnum): """ Kind of a specific data type. @@ -368,7 +369,7 @@ def __gt_origin__(self) -> Tuple[int, ...]: # -- Device representation -- -class DeviceType(enum.Enum): +class DeviceType(enum.IntEnum): """The type of the device where a memory buffer is allocated. Enum values taken from DLPack reference implementation at: @@ -385,8 +386,31 @@ class DeviceType(enum.Enum): ROCM = 10 +CPUDeviceTyping: TypeAlias = Literal[DeviceType.CPU] +CUDADeviceTyping: TypeAlias = Literal[DeviceType.CUDA] +CPUPinnedDeviceTyping: TypeAlias = Literal[DeviceType.CPU_PINNED] +OpenCLDeviceTyping: TypeAlias = Literal[DeviceType.OPENCL] +VulkanDeviceTyping: TypeAlias = Literal[DeviceType.VULKAN] +MetalDeviceTyping: TypeAlias = Literal[DeviceType.METAL] +VPIDeviceTyping: TypeAlias = Literal[DeviceType.VPI] +ROCMDeviceTyping: TypeAlias = Literal[DeviceType.ROCM] + + +DeviceTypeT = TypeVar( + "DeviceTypeT", + CPUDeviceTyping, + CUDADeviceTyping, + CPUPinnedDeviceTyping, + OpenCLDeviceTyping, + VulkanDeviceTyping, + MetalDeviceTyping, + VPIDeviceTyping, + ROCMDeviceTyping, +) + + @dataclasses.dataclass(frozen=True) -class Device: +class Device(Generic[DeviceTypeT]): """ Representation of a computing device. @@ -397,10 +421,10 @@ class Device: core number, for `DeviceType.CUDA` it could be the CUDA device number, etc. """ - device_type: DeviceType + device_type: DeviceTypeT device_id: int - def __iter__(self) -> Iterator[DeviceType | int]: + def __iter__(self) -> Iterator[DeviceTypeT | int]: yield self.device_type yield self.device_id @@ -409,7 +433,7 @@ def __iter__(self) -> Iterator[DeviceType | int]: SliceLike = Union[int, Tuple[int, ...], None, slice, "NDArrayObject"] -class NDArrayObjectProto(Protocol): +class NDArrayObject(Protocol): @property def ndim(self) -> int: ... @@ -422,7 +446,7 @@ def shape(self) -> tuple[int, ...]: def dtype(self) -> Any: ... - def __getitem__(self, item: SliceLike) -> NDArrayObject: + def __getitem__(self, item: Any) -> NDArrayObject: ... def __abs__(self) -> NDArrayObject: @@ -434,38 +458,32 @@ def __neg__(self) -> NDArrayObject: def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __radd__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + def __radd__(self, other: Any) -> NDArrayObject: ... def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rsub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + def __rsub__(self, other: Any) -> NDArrayObject: ... def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rmul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + def __rmul__(self, other: Any) -> NDArrayObject: ... def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rfloordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + def __rfloordiv__(self, other: Any) -> NDArrayObject: ... def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rtruediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + def __rtruediv__(self, other: Any) -> NDArrayObject: ... def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - - -NDArrayObject = Union[npt.NDArray, "CuPyNDArray", "JaxNDArray", NDArrayObjectProto] -NDArrayObjectT = TypeVar( - "NDArrayObjectT", npt.NDArray, "CuPyNDArray", "JaxNDArray", NDArrayObjectProto, covariant=True -) diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 76fea347f0..3a964c92a9 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -155,7 +155,7 @@ def format_cpp_source( ) -> str: """Format C++ source code using clang-format.""" assert isinstance(_CLANG_FORMAT_EXECUTABLE, str) - args = [_CLANG_FORMAT_EXECUTABLE] + args = [_CLANG_FORMAT_EXECUTABLE, "--assume-filename=_gt4py_generated_file.cpp"] if style: args.append(f"--style={style}") if fallback_style: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 3b8373ade1..17462a37ff 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -36,6 +36,7 @@ from typing import * # noqa: F403 from typing import overload # Only needed to avoid false flake8 errors +import numpy.typing as npt import typing_extensions as _typing_extensions from typing_extensions import * # type: ignore[assignment,no-redef] # noqa: F403 @@ -236,6 +237,21 @@ def hexdigest(self) -> str: # -- Third party protocols -- +class SupportsArray(Protocol): + def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> npt.NDArray[Any]: + ... + + +def supports_array(value: Any) -> TypeGuard[SupportsArray]: + return hasattr(value, "__array__") + + +class ArrayInterface(Protocol): + @property + def __array_interface__(self) -> Dict[str, Any]: + ... + + class ArrayInterfaceTypedDict(TypedDict): shape: Tuple[int, ...] typestr: str @@ -248,11 +264,19 @@ class ArrayInterfaceTypedDict(TypedDict): class StrictArrayInterface(Protocol): - __array_interface__: ArrayInterfaceTypedDict + @property + def __array_interface__(self) -> ArrayInterfaceTypedDict: + ... -class ArrayInterface(Protocol): - __array_interface__: Dict[str, Any] +def supports_array_interface(value: Any) -> TypeGuard[ArrayInterface]: + return hasattr(value, "__array_interface__") + + +class CUDAArrayInterface(Protocol): + @property + def __cuda_array_interface__(self) -> Dict[str, Any]: + ... class CUDAArrayInterfaceTypedDict(TypedDict): @@ -267,25 +291,45 @@ class CUDAArrayInterfaceTypedDict(TypedDict): class StrictCUDAArrayInterface(Protocol): - __cuda_array_interface__: CUDAArrayInterfaceTypedDict + @property + def __cuda_array_interface__(self) -> CUDAArrayInterfaceTypedDict: + ... -class CUDAArrayInterface(Protocol): - __cuda_array_interface__: Dict[str, Any] +def supports_cuda_array_interface(value: Any) -> TypeGuard[CUDAArrayInterface]: + """Check if the given value supports the CUDA Array Interface.""" + return hasattr(value, "__cuda_array_interface__") -PyCapsule = NewType("PyCapsule", object) DLPackDevice = Tuple[int, int] -class DLPackBuffer(Protocol): - def __dlpack__(self, stream: Optional[int] = None) -> PyCapsule: +class MultiStreamDLPackBuffer(Protocol): + def __dlpack__(self, *, stream: Optional[int] = None) -> Any: + ... + + def __dlpack_device__(self) -> DLPackDevice: + ... + + +class SingleStreamDLPackBuffer(Protocol): + def __dlpack__(self, *, stream: None = None) -> Any: ... def __dlpack_device__(self) -> DLPackDevice: ... +DLPackBuffer: TypeAlias = Union[MultiStreamDLPackBuffer, SingleStreamDLPackBuffer] + + +def supports_dlpack(value: Any) -> TypeGuard[DLPackBuffer]: + """Check if a given object supports the DLPack protocol.""" + return callable(getattr(value, "__dlpack__", None)) and callable( + getattr(value, "__dlpack_device__", None) + ) + + class DevToolsPrettyPrintable(Protocol): """Used by python-devtools (https://python-devtools.helpmanual.io/).""" diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index bdbc34f445..7104f7658f 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -228,6 +228,59 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]: _P = ParamSpec("_P") +_T = TypeVar("_T") + + +class fluid_partial(functools.partial): + """Create a `functools.partial` with support for multiple applications calling `.partial()`.""" + + def partial(self, *args: Any, **kwargs: Any) -> fluid_partial: + return fluid_partial(self, *args, **kwargs) + + +@overload +def with_fluid_partial( + func: Literal[None] = None, *args: Any, **kwargs: Any +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + ... + + +@overload +def with_fluid_partial( # noqa: F811 # redefinition of unused function + func: Callable[_P, _T], *args: Any, **kwargs: Any +) -> Callable[_P, _T]: + ... + + +def with_fluid_partial( # noqa: F811 # redefinition of unused function + func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any +) -> Union[Callable[..., Any], Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Add a `partial` attribute to the decorated function. + + The `partial` attribute is a function that behaves like `functools.partial`, + but also supports partial application of the decorated function. It can be + used both as a bare or a parameterized decorator. + + Arguments: + func: The function to decorate. + + Returns: + Returns the decorated function with an extra `.partial()` attribute. + + Example: + >>> @with_fluid_partial + ... def add(a, b): + ... return a + b + ... + >>> add.partial(1)(2) + 3 + """ + + def _decorator(func: Callable[..., Any]) -> Callable[..., Any]: + func.partial = fluid_partial(functools.partial, func, *args, **kwargs) # type: ignore[attr-defined] # add attribute + return func + + return _decorator(func) if func is not None else _decorator @overload @@ -318,9 +371,6 @@ def _decorator(base_cls: Type) -> Type: return _decorator -_T = TypeVar("_T") - - def noninstantiable(cls: Type[_T]) -> Type[_T]: """Make a class without abstract method non-instantiable (subclasses should be instantiable).""" if not isinstance(cls, type): diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index cc35899668..696c4f174c 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,6 +25,7 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType +from .constructors import as_field, empty, full, ones, zeros from .embedded import ( # Just for registering field implementations nd_array_field as _nd_array_field, ) @@ -52,6 +53,12 @@ "DimensionKind", "Field", "GridType", + # from constructors + "empty", + "zeros", + "ones", + "full", + "as_field", # from iterator "NeighborTableOffsetProvider", "StridedNeighborOffsetProvider", diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py new file mode 100644 index 0000000000..58600d8cda --- /dev/null +++ b/src/gt4py/next/allocators.py @@ -0,0 +1,349 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import abc +import dataclasses + +import numpy as np + +import gt4py._core.definitions as core_defs +import gt4py.next.common as common +import gt4py.storage.allocators as core_allocators +from gt4py.eve.extended_typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Literal, + Optional, + Protocol, + Sequence, + TypeAlias, + TypeGuard, + cast, +) + + +try: + import cupy as cp +except ImportError: + cp = None + + +CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( + None + if not cp + else (core_defs.DeviceType.ROCM if cp.cuda.get_hipcc_path() else core_defs.DeviceType.CUDA) +) + + +FieldLayoutMapper: TypeAlias = Callable[ + [Sequence[common.Dimension]], core_allocators.BufferLayoutMap +] + + +class FieldBufferAllocatorProtocol(Protocol[core_defs.DeviceTypeT]): + """Protocol for buffer allocators used to allocate memory for fields with a given domain.""" + + @property + @abc.abstractmethod + def __gt_device_type__(self) -> core_defs.DeviceTypeT: + ... + + @abc.abstractmethod + def __gt_allocate__( + self, + domain: common.Domain, + dtype: core_defs.DType[core_defs.ScalarT], + device_id: int = 0, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position + ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ... + + +def is_field_allocator(obj: Any) -> TypeGuard[FieldBufferAllocatorProtocol]: + return hasattr(obj, "__gt_device_type__") and hasattr(obj, "__gt_allocate__") + + +def is_field_allocator_for( + obj: Any, device: core_defs.DeviceTypeT +) -> TypeGuard[FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]]: + return is_field_allocator(obj) and obj.__gt_device_type__ is device + + +class FieldBufferAllocatorFactoryProtocol(Protocol[core_defs.DeviceTypeT]): + """Protocol for device-specific buffer allocator factories for fields.""" + + @property + @abc.abstractmethod + def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: + ... + + +def is_field_allocator_factory(obj: Any) -> TypeGuard[FieldBufferAllocatorFactoryProtocol]: + return hasattr(obj, "__gt_allocator__") + + +def is_field_allocator_factory_for( + obj: Any, device: core_defs.DeviceTypeT +) -> TypeGuard[FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]]: + return is_field_allocator_factory(obj) and obj.__gt_allocator__.__gt_device_type__ is device + + +FieldBufferAllocationUtil = ( + FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] + | FieldBufferAllocatorFactoryProtocol[core_defs.DeviceTypeT] +) + + +def is_field_allocation_tool(obj: Any) -> TypeGuard[FieldBufferAllocationUtil]: + return is_field_allocator(obj) or is_field_allocator_factory(obj) + + +def get_allocator( + obj: Any, + *, + default: Optional[FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]] = None, + strict: bool = False, +) -> Optional[FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]]: + """ + Return a field-buffer-allocator from an object assumed to be an allocator or an allocator factory. + + A default allocator can be provided as fallback in case `obj` is neither an allocator nor a factory. + + Arguments: + obj: The allocator or allocator factory. + default: Fallback allocator. + strict: If `True`, raise an exception if there is no way to get a valid allocator + from `obj` or `default`. + + Returns: + A field buffer allocator. + + Raises: + TypeError: If `obj` is neither a field allocator nor a field allocator factory and no default + is provided in `strict` mode. + """ + if is_field_allocator(obj): + return obj + elif is_field_allocator_factory(obj): + return obj.__gt_allocator__ + elif not strict or is_field_allocator(default): + return default + else: + raise TypeError(f"Object {obj} is neither a field allocator nor a field allocator factory") + + +@dataclasses.dataclass(frozen=True) +class BaseFieldBufferAllocator(FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]): + """Parametrizable field buffer allocator base class.""" + + device_type: core_defs.DeviceTypeT + array_ns: core_allocators.ValidNumPyLikeAllocationNS + layout_mapper: FieldLayoutMapper + byte_alignment: int + + @property + def __gt_device_type__(self) -> core_defs.DeviceTypeT: + return self.device_type + + def __gt_allocate__( + self, + domain: common.Domain, + dtype: core_defs.DType[core_defs.ScalarT], + device_id: int = 0, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position + ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + shape = domain.shape + layout_map = self.layout_mapper(domain.dims) + # TODO(egparedes): add support for non-empty aligned index values + assert aligned_index is None + + return core_allocators.NDArrayBufferAllocator(self.device_type, self.array_ns).allocate( + shape, dtype, device_id, layout_map, self.byte_alignment, aligned_index + ) + + +if TYPE_CHECKING: + __TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[ + FieldBufferAllocatorProtocol + ] = BaseFieldBufferAllocator + + +def horizontal_first_layout_mapper( + dims: Sequence[common.Dimension], +) -> core_allocators.BufferLayoutMap: + """Map dimensions to a buffer layout making horizonal dims change the slowest (i.e. larger strides).""" + + def pos_of_kind(kind: common.DimensionKind) -> list[int]: + return [i for i, dim in enumerate(dims) if dim.kind == kind] + + horizontals = pos_of_kind(common.DimensionKind.HORIZONTAL) + verticals = pos_of_kind(common.DimensionKind.VERTICAL) + locals_ = pos_of_kind(common.DimensionKind.LOCAL) + + layout_map = [0] * len(dims) + for i, pos in enumerate(horizontals + verticals + locals_): + layout_map[pos] = len(dims) - 1 - i + + valid_layout_map = tuple(layout_map) + assert core_allocators.is_valid_layout_map(valid_layout_map) + + return valid_layout_map + + +if TYPE_CHECKING: + __horizontal_first_layout_mapper: FieldLayoutMapper = horizontal_first_layout_mapper + + +#: Registry of default allocators for each device type. +device_allocators: dict[core_defs.DeviceType, FieldBufferAllocatorProtocol] = {} + + +assert core_allocators.is_valid_nplike_allocation_ns(np) +np_alloc_ns: core_allocators.ValidNumPyLikeAllocationNS = np # Just for static type checking + + +class StandardCPUFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CPUDeviceTyping]): + """A field buffer allocator for CPU devices that uses a horizontal-first layout mapper and 64-byte alignment.""" + + def __init__(self) -> None: + super().__init__( + device_type=core_defs.DeviceType.CPU, + array_ns=np_alloc_ns, + layout_mapper=horizontal_first_layout_mapper, + byte_alignment=64, + ) + + +device_allocators[core_defs.DeviceType.CPU] = StandardCPUFieldBufferAllocator() + +assert is_field_allocator(device_allocators[core_defs.DeviceType.CPU]) + + +@dataclasses.dataclass(frozen=True) +class InvalidFieldBufferAllocator(FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]): + """A field buffer allocator that always raises an exception.""" + + device_type: core_defs.DeviceTypeT + exception: Exception + + @property + def __gt_device_type__(self) -> core_defs.DeviceTypeT: + return self.device_type + + def __gt_allocate__( + self, + domain: common.Domain, + dtype: core_defs.DType[core_defs.ScalarT], + device_id: int = 0, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position + ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + raise self.exception + + +if CUPY_DEVICE is not None: + cp_alloc_ns: core_allocators.ValidNumPyLikeAllocationNS = cp # Just for static type checking + assert core_allocators.is_valid_nplike_allocation_ns(cp_alloc_ns) + + if CUPY_DEVICE is core_defs.DeviceType.CUDA: + + class CUDAFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CUDADeviceTyping]): + def __init__(self) -> None: + super().__init__( + device_type=core_defs.DeviceType.CUDA, + array_ns=cp_alloc_ns, + layout_mapper=horizontal_first_layout_mapper, + byte_alignment=128, + ) + + device_allocators[core_defs.DeviceType.CUDA] = CUDAFieldBufferAllocator() + + else: + + class ROCMFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.ROCMDeviceTyping]): + def __init__(self) -> None: + super().__init__( + device_type=core_defs.DeviceType.ROCM, + array_ns=cp_alloc_ns, + layout_mapper=horizontal_first_layout_mapper, + byte_alignment=128, + ) + + device_allocators[core_defs.DeviceType.ROCM] = ROCMFieldBufferAllocator() + +else: + + class InvalidGPUFielBufferAllocator(InvalidFieldBufferAllocator[core_defs.CUDADeviceTyping]): + def __init__(self) -> None: + super().__init__( + device_type=core_defs.DeviceType.CUDA, + exception=RuntimeError("Missing `cupy` dependency for GPU allocation"), + ) + + +StandardGPUFieldBufferAllocator: Final[type[FieldBufferAllocatorProtocol]] = cast( + type[FieldBufferAllocatorProtocol], + type(device_allocators[CUPY_DEVICE]) if CUPY_DEVICE else InvalidGPUFielBufferAllocator, +) + + +def allocate( + domain: common.DomainLike, + dtype: core_defs.DType[core_defs.ScalarT], + *, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + allocator: Optional[FieldBufferAllocationUtil] = None, + device: Optional[core_defs.Device] = None, +) -> core_allocators.TensorBuffer: + """ + Allocate a TensorBuffer for the given domain and device or allocator. + + The arguments `device` and `allocator` are mutually exclusive. + If `device` is specified, the corresponding default allocator + (defined in :data:`device_allocators`) is used. + + Arguments: + domain: The domain which should be backed by the allocated tensor buffer. + dtype: Data type. + aligned_index: N-dimensional index of the first aligned element + allocator: The allocator to use for the allocation. + device: The device to allocate the tensor buffer on (using the default + allocator for this kind of device from :data:`device_allocators`). + + Returns: + The allocated tensor buffer. + + Raises: + ValueError + If illegal or inconsistent arguments are specified. + + """ + if device is None and allocator is None: + raise ValueError("No 'device' or 'allocator' specified") + actual_allocator = get_allocator(allocator) + if actual_allocator is None: + assert device is not None # for mypy + actual_allocator = device_allocators[device.device_type] + elif device is None: + device = core_defs.Device(actual_allocator.__gt_device_type__, 0) + elif device.device_type != actual_allocator.__gt_device_type__: + raise ValueError(f"Device {device} and allocator {actual_allocator} are incompatible") + + return actual_allocator.__gt_allocate__( + domain=common.domain(domain), + dtype=dtype, + device_id=device.device_id, + aligned_index=aligned_index, + ) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index b85239cd0a..ffaa410563 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -80,15 +80,18 @@ def __str__(self): return f"{self.value}[{self.kind}]" -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, init=False) class UnitRange(Sequence[int], Set[int]): """Range from `start` to `stop` with step size one.""" start: int stop: int - def __post_init__(self): - if self.stop <= self.start: + def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScalar) -> None: + if start < stop: + object.__setattr__(self, "start", int(start)) + object.__setattr__(self, "stop", int(stop)) + else: # make UnitRange(0,0) the single empty UnitRange object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) @@ -142,7 +145,12 @@ def __str__(self) -> str: return f"({self.start}:{self.stop})" -RangeLike: TypeAlias = UnitRange | range | tuple[int, int] +RangeLike: TypeAlias = ( + UnitRange + | range + | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] + | core_defs.IntegralScalar +) def unit_range(r: RangeLike) -> UnitRange: @@ -152,9 +160,17 @@ def unit_range(r: RangeLike) -> UnitRange: if r.step != 1: raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") return UnitRange(r.start, r.stop) - if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): + # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) + # once the related mypy bug (#16358) gets fixed + if ( + isinstance(r, tuple) + and isinstance(r[0], core_defs.INTEGRAL_TYPES) + and isinstance(r[1], core_defs.INTEGRAL_TYPES) + ): return UnitRange(r[0], r[1]) - raise ValueError(f"`{r}` cannot be interpreted as `UnitRange`.") + if isinstance(r, core_defs.INTEGRAL_TYPES): + return UnitRange(0, cast(core_defs.IntegralScalar, r)) + raise ValueError(f"`{r!r}` cannot be interpreted as `UnitRange`.") IntIndex: TypeAlias = int | core_defs.IntegralScalar @@ -274,6 +290,10 @@ def __init__( def __len__(self) -> int: return len(self.ranges) + @property + def shape(self) -> tuple[int, ...]: + return tuple(len(r) for r in self.ranges) + @overload def __getitem__(self, index: int) -> NamedRange: ... @@ -350,12 +370,23 @@ def domain(domain_like: DomainLike) -> Domain: >>> domain({I: (2, 4), J: (3, 5)}) Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + + >>> domain(((I, 2), (J, 4))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(0, 2), UnitRange(0, 4))) + + >>> domain({I: 2, J: 4}) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(0, 2), UnitRange(0, 4))) """ if isinstance(domain_like, Domain): return domain_like if isinstance(domain_like, Sequence): return Domain(*tuple(named_range(d) for d in domain_like)) if isinstance(domain_like, Mapping): + if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()): + return Domain( + dims=tuple(domain_like.keys()), + ranges=tuple(UnitRange(0, s) for s in domain_like.values()), # type: ignore[arg-type] # type of `s` is checked in condition + ) return Domain( dims=tuple(domain_like.keys()), ranges=tuple(unit_range(r) for r in domain_like.values()), @@ -383,20 +414,30 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... +# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol class NextGTDimsInterface(Protocol): """ - A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions. + Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions. - The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`, - where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` . + The dimension names are objects of type :class:`Dimension`, in contrast to + :mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics, + see :class:`~gt4py._core.definitions.GTDimsInterface` . """ - # TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way @property def __gt_dims__(self) -> tuple[Dimension, ...]: ... +# TODO(egparedes): add support for this new protocol in the cartesian module +class GTFieldInterface(Protocol): + """Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.""" + + @property + def __gt_domain__(self) -> Domain: + ... + + @extended_runtime_checkable class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @@ -671,7 +712,7 @@ class FieldBuiltinFuncRegistry: def __init_subclass__(cls, **kwargs): cls._builtin_func_map = collections.ChainMap( - {}, # New empty `dict`` for new registrations on this class + {}, # New empty `dict` for new registrations on this class *[ c.__dict__["_builtin_func_map"].maps[0] # adding parent `dict`s in mro order for c in cls.__mro__ diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py new file mode 100644 index 0000000000..30ef8452aa --- /dev/null +++ b/src/gt4py/next/constructors.py @@ -0,0 +1,297 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Optional, cast + +import gt4py._core.definitions as core_defs +import gt4py.eve as eve +import gt4py.eve.extended_typing as xtyping +import gt4py.next.allocators as next_allocators +import gt4py.next.common as common +import gt4py.next.embedded.nd_array_field as nd_array_field +import gt4py.storage.cartesian.utils as storage_utils + + +@eve.utils.with_fluid_partial +def empty( + domain: common.DomainLike, + dtype: core_defs.DTypeLike = core_defs.Float64DType(()), + *, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + allocator: Optional[next_allocators.FieldBufferAllocationUtil] = None, + device: Optional[core_defs.Device] = None, +) -> nd_array_field.NdArrayField: + """Create a `Field` of uninitialized (undefined) values using the given (or device-default) allocator. + + This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. + + Arguments: + domain: Definition of the domain of the field (which fix the shape of the allocated field buffer). + See :class:`gt4py.next.common.Domain` for details. + dtype: Definition of the data type of the field. Defaults to `float64`. + + Keyword Arguments: + aligned_index: Index in the definition domain which should be used as reference + point for memory aligment computations. It can be set to the most common origin + of computations in this domain (if known) for performance reasons. + allocator: The allocator or allocator factory (e.g. backend) used for memory buffer + allocation, which knows how to optimize the memory layout for a given device. + Required if `device` is `None`. If both are valid, `allocator` will be chosen over + the default device allocator. + device: The device (CPU, type of accelerator) to optimize the memory layout for. + Required if `allocator` is `None` and will cause the default device allocator + to be used in that case. + + Returns: + A field, backed by a buffer with memory layout as specified by allocator and alignment requirements. + + Raises: + ValueError + If illegal or inconsistent arguments are specified. + + Examples: + Initialize a field in one dimension with a backend and a range domain: + + >>> from gt4py import next as gtx + >>> from gt4py.next.program_processors.runners import roundtrip + >>> IDim = gtx.Dimension("I") + >>> a = gtx.empty({IDim: range(3, 10)}, allocator=roundtrip.backend) + >>> a.shape + (7,) + + Initialize with a device and an integer domain. It works like a shape with named dimensions: + + >>> from gt4py._core import definitions as core_defs + >>> JDim = gtx.Dimension("J") + >>> b = gtx.empty({IDim: 3, JDim: 3}, int, device=core_defs.Device(core_defs.DeviceType.CPU, 0)) + >>> b.shape + (3, 3) + """ + dtype = core_defs.dtype(dtype) + buffer = next_allocators.allocate( + domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device + ) + res = common.field(buffer.ndarray, domain=domain) + assert common.is_mutable_field(res) + assert isinstance(res, nd_array_field.NdArrayField) + return res + + +@eve.utils.with_fluid_partial +def zeros( + domain: common.DomainLike, + dtype: core_defs.DTypeLike = core_defs.Float64DType(()), + *, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, + device: Optional[core_defs.Device] = None, +) -> nd_array_field.NdArrayField: + """Create a Field containing all zeros using the given (or device-default) allocator. + + This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. + See :func:`empty` for further details about the meaning of the arguments. + + Examples: + >>> from gt4py import next as gtx + >>> from gt4py.next.program_processors.runners import roundtrip + >>> IDim = gtx.Dimension("I") + >>> gtx.zeros({IDim: range(3, 10)}, allocator=roundtrip.backend).ndarray + array([0., 0., 0., 0., 0., 0., 0.]) + """ + field = empty( + domain=domain, + dtype=dtype, + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + field[...] = field.dtype.scalar_type(0) + return field + + +@eve.utils.with_fluid_partial +def ones( + domain: common.DomainLike, + dtype: core_defs.DTypeLike = core_defs.Float64DType(()), + *, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, + device: Optional[core_defs.Device] = None, +) -> nd_array_field.NdArrayField: + """Create a Field containing all ones using the given (or device-default) allocator. + + This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. + See :func:`empty` for further details about the meaning of the arguments. + + Examples: + >>> from gt4py import next as gtx + >>> from gt4py.next.program_processors.runners import roundtrip + >>> IDim = gtx.Dimension("I") + >>> gtx.ones({IDim: range(3, 10)}, allocator=roundtrip.backend).ndarray + array([1., 1., 1., 1., 1., 1., 1.]) + """ + field = empty( + domain=domain, + dtype=dtype, + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + field[...] = field.dtype.scalar_type(1) + return field + + +@eve.utils.with_fluid_partial +def full( + domain: common.DomainLike, + fill_value: core_defs.Scalar, + dtype: Optional[core_defs.DTypeLike] = None, + *, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, + device: Optional[core_defs.Device] = None, +) -> nd_array_field.NdArrayField: + """Create a Field where all values are set to `fill_value` using the given (or device-default) allocator. + + This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. + See :func:`empty` for further details about the meaning of the arguments. + + Arguments: + domain: Definition of the domain of the field (and consequently of the shape of the allocated field buffer). + fill_value: Each point in the field will be initialized to this value. + dtype: Definition of the data type of the field. Defaults to the dtype of `fill_value`. + + Examples: + >>> from gt4py import next as gtx + >>> from gt4py.next.program_processors.runners import roundtrip + >>> IDim = gtx.Dimension("I") + >>> gtx.full({IDim: 3}, 5, allocator=roundtrip.backend).ndarray + array([5, 5, 5]) + """ + field = empty( + domain=domain, + dtype=dtype if dtype is not None else core_defs.dtype(type(fill_value)), + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + field[...] = field.dtype.scalar_type(fill_value) + return field + + +@eve.utils.with_fluid_partial +def as_field( + domain: common.DomainLike | Sequence[common.Dimension], + data: core_defs.NDArrayObject, + dtype: Optional[core_defs.DTypeLike] = None, + *, + origin: Optional[Mapping[common.Dimension, int]] = None, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, + device: Optional[core_defs.Device] = None, + # copy=False, TODO +) -> nd_array_field.NdArrayField: + """Create a Field from an array-like object using the given (or device-default) allocator. + + This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. + See :func:`empty` for further details about the meaning of the extra keyword arguments. + + Parameters: + domain: Definition of the domain of the field (and consequently of the shape of the allocated field buffer). + In addition to the values allowed in `empty`, it can also just be a sequence of dimensions, + in which case the sizes of each dimension will then be taken from the shape of `data`. + data: Array like data object to initialize the field with + dtype: Definition of the data type of the field. Defaults to the same as `data`. + + Keyword Arguments: + origin: Only allowed if `domain` is a sequence of dimensions. The indicated index in `data` + will be the zero point of the resulting field. + allocator: Fully optional, in contrast to `empty`. + device: Fully optional, in contrast to `empty`, defaults to the same device as `data`. + + Examples: + >>> import numpy as np + >>> from gt4py import next as gtx + >>> IDim = gtx.Dimension("I") + >>> xdata = np.array([1, 2, 3]) + + Automatic domain from just dimensions: + + >>> a = gtx.as_field([IDim], xdata) + >>> a.ndarray + array([1, 2, 3]) + >>> a.domain.ranges[0] + UnitRange(0, 3) + + Shifted domain using origin: + + >>> b = gtx.as_field([IDim], xdata, origin={IDim: 1}) + >>> b.domain.ranges[0] + UnitRange(-1, 2) + + Equivalent domain fully specified: + + >>> gtx.as_field({IDim: range(-1, 2)}, xdata).domain.ranges[0] + UnitRange(-1, 2) + """ + if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): + domain = cast(Sequence[common.Dimension], domain) + if len(domain) != data.ndim: + raise ValueError( + f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + ) + if origin: + domain_dims = set(domain) + if unknown_dims := set(origin.keys()) - domain_dims: + raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}") + else: + origin = {} + actual_domain = common.domain( + [ + (d, (-(start_offset := origin.get(d, 0)), s - start_offset)) + for d, s in zip(domain, data.shape) + ] + ) + else: + if origin: + raise ValueError(f"Cannot specify origin for domain {domain}") + actual_domain = common.domain(cast(common.DomainLike, domain)) + + # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has + # already the correct layout and device. + shape = storage_utils.asarray(data).shape + if shape != actual_domain.shape: + raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + if dtype is None: + dtype = storage_utils.asarray(data).dtype + dtype = core_defs.dtype(dtype) + assert dtype.tensor_shape == () # TODO + + if allocator is device is None and xtyping.supports_dlpack(data): + device = core_defs.Device(*data.__dlpack_device__()) + + field = empty( + domain=actual_domain, + dtype=dtype, + aligned_index=aligned_index, + allocator=allocator, + device=device, + ) + + field[...] = field.array_ns.asarray(data) + + return field diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fcaa09e7eb..527197e0bc 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -40,7 +40,7 @@ def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_unary_op(a: _BaseNdArrayField) -> common.Field: + def _builtin_unary_op(a: NdArrayField) -> common.Field: xp = a.__class__.array_ns op = getattr(xp, array_builtin_name) new_data = op(a.ndarray) @@ -52,7 +52,7 @@ def _builtin_unary_op(a: _BaseNdArrayField) -> common.Field: def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: + def _builtin_binary_op(a: NdArrayField, b: common.Field) -> common.Field: xp = a.__class__.array_ns op = getattr(xp, array_builtin_name) if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): @@ -81,7 +81,7 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField( +class NdArrayField( common.MutableField[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry ): """ @@ -136,7 +136,7 @@ def from_array( *, domain: common.DomainLike, dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike - ) -> _BaseNdArrayField: + ) -> NdArrayField: domain = common.domain(domain) xp = cls.array_ns @@ -157,7 +157,7 @@ def from_array( return cls(domain, array) - def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: + def remap(self: NdArrayField, connectivity) -> NdArrayField: raise NotImplementedError() def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: @@ -165,7 +165,7 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala new_buffer = self.ndarray[buffer_slice] if len(new_domain) == 0: - assert core_defs.is_scalar_type(new_buffer) + # TODO: assert core_defs.is_scalar_type(new_buffer), new_buffer return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here else: return self.__class__.from_array(new_buffer, domain=new_domain) @@ -196,7 +196,7 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") - def __and__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( self, other @@ -205,14 +205,14 @@ def __and__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: __rand__ = __and__ - def __or__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") __ror__ = __or__ - def __xor__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( self, other @@ -221,7 +221,7 @@ def __xor__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: __rxor__ = __xor__ - def __invert__(self) -> _BaseNdArrayField: + def __invert__(self) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_unary_array_field_intrinsic_func("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") @@ -243,8 +243,8 @@ def _slice( # -- Specialized implementations for intrinsic operations on array fields -- -_BaseNdArrayField.register_builtin_func(fbuiltins.abs, _BaseNdArrayField.__abs__) # type: ignore[attr-defined] -_BaseNdArrayField.register_builtin_func(fbuiltins.power, _BaseNdArrayField.__pow__) # type: ignore[attr-defined] +NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] +NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined] # TODO gamma for name in ( @@ -254,23 +254,23 @@ def _slice( ): if name in ["abs", "power", "gamma"]: continue - _BaseNdArrayField.register_builtin_func( + NdArrayField.register_builtin_func( getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name) ) -_BaseNdArrayField.register_builtin_func( +NdArrayField.register_builtin_func( fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined] ) -_BaseNdArrayField.register_builtin_func( +NdArrayField.register_builtin_func( fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined] ) -_BaseNdArrayField.register_builtin_func( +NdArrayField.register_builtin_func( fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] ) def _np_cp_setitem( - self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], + self: NdArrayField[common.DimsT, core_defs.ScalarT], index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: @@ -293,7 +293,7 @@ def _np_cp_setitem( @dataclasses.dataclass(frozen=True) -class NumPyArrayField(_BaseNdArrayField): +class NumPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = np __setitem__ = _np_cp_setitem @@ -306,7 +306,7 @@ class NumPyArrayField(_BaseNdArrayField): _nd_array_implementations.append(cp) @dataclasses.dataclass(frozen=True) - class CuPyArrayField(_BaseNdArrayField): + class CuPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = cp __setitem__ = _np_cp_setitem @@ -318,7 +318,7 @@ class CuPyArrayField(_BaseNdArrayField): _nd_array_implementations.append(jnp) @dataclasses.dataclass(frozen=True) - class JaxArrayField(_BaseNdArrayField): + class JaxArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = jnp def __setitem__( @@ -355,7 +355,7 @@ def _builtins_broadcast( raise AssertionError("Scalar case not reachable from `fbuiltins.broadcast`.") -_BaseNdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) +NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) def _get_slices_from_domain_slice( diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 12ab3955ab..2d12331513 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -30,8 +30,9 @@ from devtools import debug from gt4py._core import definitions as core_defs +from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.eve.utils import UIDGenerator +from gt4py.next import allocators as next_allocators from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import ( dialect_ast_enums, @@ -214,6 +215,15 @@ def __post_init__(self): f"The following closure variables are undefined: {', '.join(undefined_symbols)}" ) + @functools.cached_property + def __gt_allocator__( + self, + ) -> next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: + if self.backend: + return self.backend.__gt_allocator__ + else: + raise RuntimeError(f"Program {self} does not have a backend set.") + def with_backend(self, backend: ppi.ProgramExecutor) -> Program: return dataclasses.replace(self, backend=backend) @@ -609,7 +619,7 @@ def as_program( # with the out argument of the program we generate here. loc = self.foast_node.location - param_sym_uids = UIDGenerator() # use a new UID generator to allow caching + param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching type_ = self.__gt_type__() params_decl: list[past.Symbol] = [ @@ -790,8 +800,8 @@ def scan_operator( >>> from gt4py.next.iterator import embedded >>> embedded._column_range = 1 # implementation detail >>> KDim = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL) - >>> inp = gtx.np_as_located_field(KDim)(np.ones((10,))) - >>> out = gtx.np_as_located_field(KDim)(np.zeros((10,))) + >>> inp = gtx.as_field([KDim], np.ones((10,))) + >>> out = gtx.as_field([KDim], np.zeros((10,))) >>> @gtx.scan_operator(axis=KDim, forward=True, init=0.) ... def scan_operator(carry: float, val: float) -> float: ... return carry+val diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3d159eaae7..674f99f61c 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -23,6 +23,7 @@ import itertools import math import sys +import warnings from typing import ( Any, Callable, @@ -1015,6 +1016,8 @@ def _shift_field_indices( def np_as_located_field( *axes: common.Dimension, origin: Optional[dict[common.Dimension, int]] = None ) -> Callable[[np.ndarray], common.Field]: + warnings.warn("`np_as_located_field()` is deprecated, use `gtx.as_field()`", DeprecationWarning) + origin = origin or {} def _maker(a) -> common.Field: @@ -1063,7 +1066,7 @@ def dtype(self) -> core_defs.Int32DType: @property def ndarray(self) -> core_defs.NDArrayObject: - return AttributeError("Cannot get `ndarray` of an infinite Field.") + raise AttributeError("Cannot get `ndarray` of an infinite Field.") def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) @@ -1169,7 +1172,7 @@ def dtype(self) -> core_defs.DType[core_defs.ScalarT]: @property def ndarray(self) -> core_defs.NDArrayObject: - return AttributeError("Cannot get `ndarray` of an infinite Field.") + raise AttributeError("Cannot get `ndarray` of an infinite Field.") def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index 2952bf3465..f9fa154641 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -20,5 +20,5 @@ @program_formatter -def format_sourcecode(program: itir.FencilDefinition, *arg: Any, **kwargs: Any) -> str: +def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: return generate(program, **kwargs) diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index b6afb88759..4f4a15f908 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -14,15 +14,23 @@ from typing import Any -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.pretty_parser import pparse -from gt4py.next.iterator.pretty_printer import pformat -from gt4py.next.program_processors.processor_interface import program_formatter +import gt4py.eve as eve +import gt4py.next.iterator.ir as itir +import gt4py.next.iterator.pretty_parser as pretty_parser +import gt4py.next.iterator.pretty_printer as pretty_printer +import gt4py.next.program_processors.processor_interface as ppi -@program_formatter -def pretty_format_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - pretty = pformat(program) - parsed = pparse(pretty) - assert parsed == program +class _RemoveITIRSymTypes(eve.NodeTranslator): + def visit_Sym(self, node: itir.Sym) -> itir.Sym: + return itir.Sym(id=node.id, dtype=None, kind=None) + + +@ppi.program_formatter +def format_itir_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: + # remove types from ITIR as they are not supported for the roundtrip + root = _RemoveITIRSymTypes().visit(program) + pretty = pretty_printer.pformat(root) + parsed = pretty_parser.pparse(pretty) + assert parsed == root return pretty diff --git a/src/gt4py/next/program_processors/formatters/type_check.py b/src/gt4py/next/program_processors/formatters/type_check.py index 07cbc89ebd..8f17b8cf98 100644 --- a/src/gt4py/next/program_processors/formatters/type_check.py +++ b/src/gt4py/next/program_processors/formatters/type_check.py @@ -18,7 +18,7 @@ @program_formatter -def check(program: itir.FencilDefinition, *args, **kwargs) -> str: +def check_type_inference(program: itir.FencilDefinition, *args, **kwargs) -> str: type_inference.pprint(type_inference.infer(program, offset_provider=kwargs["offset_provider"])) transformed = apply_common_transforms( program, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] diff --git a/src/gt4py/next/program_processors/otf_compile_executor.py b/src/gt4py/next/program_processors/otf_compile_executor.py index cd08c16933..8dff34a35d 100644 --- a/src/gt4py/next/program_processors/otf_compile_executor.py +++ b/src/gt4py/next/program_processors/otf_compile_executor.py @@ -12,12 +12,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + import dataclasses from typing import Any, Generic, Optional, TypeVar -from gt4py.next.iterator import ir as itir +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators +import gt4py.next.iterator.ir as itir +import gt4py.next.program_processors.processor_interface as ppi from gt4py.next.otf import languages, recipes, stages, workflow -from gt4py.next.program_processors import processor_interface as ppi SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL) @@ -54,3 +58,26 @@ def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None @property def __name__(self) -> str: return self.name or repr(self) + + +@dataclasses.dataclass(frozen=True) +class OTFBackend(Generic[core_defs.DeviceTypeT]): + executor: ppi.ProgramExecutor + allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] + + def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: + self.executor.__call__(program, *args, **kwargs) + + @property + def __name__(self) -> str: + return getattr(self.executor, "__name__", None) or repr(self) + + @property + def kind(self) -> type[ppi.ProgramExecutor]: + return self.executor.kind + + @property + def __gt_allocator__( + self, + ) -> next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: + return self.allocator diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index b39438937e..d9f8b36301 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -26,21 +26,25 @@ """ from __future__ import annotations -from typing import Callable, Protocol, TypeGuard, TypeVar, cast +import functools +from collections.abc import Sequence +from typing import Any, Callable, Literal, Optional, Protocol, TypeGuard, TypeVar, cast -from gt4py.next.iterator import ir as itir +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators +import gt4py.next.iterator.ir as itir OutputT = TypeVar("OutputT", covariant=True) ProcessorKindT = TypeVar("ProcessorKindT", bound="ProgramProcessor", covariant=True) -class ProgramProcessorFunction(Protocol[OutputT]): +class ProgramProcessorCallable(Protocol[OutputT]): def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT: ... -class ProgramProcessor(ProgramProcessorFunction[OutputT], Protocol[OutputT, ProcessorKindT]): +class ProgramProcessor(ProgramProcessorCallable[OutputT], Protocol[OutputT, ProcessorKindT]): @property def kind(self) -> type[ProcessorKindT]: ... @@ -52,46 +56,133 @@ def kind(self) -> type[ProgramFormatter]: return ProgramFormatter -def program_formatter(func: ProgramProcessorFunction[str]) -> ProgramFormatter: +def make_program_processor( + func: ProgramProcessorCallable[OutputT], + kind: type[ProcessorKindT], + *, + name: Optional[str] = None, + accept_args: None | int | Literal["all"] = "all", + accept_kwargs: None | Sequence[str] | Literal["all"] = "all", +) -> ProgramProcessor[OutputT, ProcessorKindT]: + """ + Create a program processor from a callable function. + + Args: + func: The callable function to be wrapped as a program processor. + kind: The type of the processor. + name: The name of the processor. + accept_args: The number of positional arguments to accept, or "all" to accept all. + accept_kwargs: The names of the keyword arguments to accept, or "all" to accept all. + + Returns: + A program processor that wraps the given function. + + Raises: + ValueError: If the value of `accept_args` or `accept_kwargs` is invalid. + """ + args_filter: Callable[[Sequence], Sequence] + if accept_args is None: + args_filter = lambda args: () # noqa: E731 # use def instead of named lambdas + elif accept_args == "all": + args_filter = lambda args: args # noqa: E731 + elif isinstance(accept_args, int): + if accept_args < 0: + raise ValueError( + f"Number of accepted arguments cannot be a negative number ({accept_args})" + ) + args_filter = lambda args: args[:accept_args] # type: ignore[misc] # noqa: E731 + else: + raise ValueError(f"Invalid ({accept_args}) accept_args value") + + filtered_kwargs: Callable[[dict[str, Any]], dict[str, Any]] + if accept_kwargs is None: + filtered_kwargs = lambda kwargs: {} # noqa: E731 # use def instead of named lambdas + elif accept_kwargs == "all": # don't swap with 'isinstance(..., Sequence)' + filtered_kwargs = lambda kwargs: kwargs # noqa: E731 + elif isinstance(accept_kwargs, Sequence): + if not all(isinstance(a, str) for a in accept_kwargs): + raise ValueError(f"Provided invalid list of keyword argument names ({accept_args})") + filtered_kwargs = lambda kwargs: { # noqa: E731 + key: value for key, value in kwargs.items() if key in accept_kwargs # type: ignore[operator] # key in accept_kwargs + } + else: + raise ValueError(f"Invalid ({accept_kwargs}) 'accept_kwargs' value") + + @functools.wraps(func) + def _wrapper(program: itir.FencilDefinition, *args, **kwargs) -> OutputT: + return func(program, *args_filter(args), **filtered_kwargs(kwargs)) + + if name is not None: + _wrapper.__name__ = name + + # this operation effectively changes the type of the returned object, + # which is the intention here + _wrapper.kind = kind # type: ignore[attr-defined] + + return cast(ProgramProcessor[OutputT, ProcessorKindT], _wrapper) + + +def program_formatter( + func: ProgramProcessorCallable[str], + *, + name: Optional[str] = None, + accept_args: None | int | Literal["all"] = "all", + accept_kwargs: Sequence[str] | None | Literal["all"] = "all", +) -> ProgramFormatter: """ Turn a function that formats a program as a string into a ProgramFormatter. Examples: - --------- - >>> @program_formatter - ... def format_foo(fencil: itir.FencilDefinition, *args, **kwargs) -> str: - ... '''A very useless fencil formatter.''' - ... return "foo" + >>> @program_formatter + ... def format_foo(fencil: itir.FencilDefinition, *args, **kwargs) -> str: + ... '''A very useless fencil formatter.''' + ... return "foo" - >>> ensure_processor_kind(format_foo, ProgramFormatter) + >>> ensure_processor_kind(format_foo, ProgramFormatter) """ - # this operation effectively changes the type of func and that is the intention here - func.kind = ProgramFormatter # type: ignore[attr-defined] - return cast(ProgramProcessor[str, ProgramFormatter], func) + return make_program_processor( + func, + ProgramFormatter, # type: ignore[type-abstract] # ProgramFormatter is abstract + name=name, + accept_args=accept_args, + accept_kwargs=accept_kwargs, + ) -class ProgramExecutor(ProgramProcessor[None, "ProgramExecutor"], Protocol): +class ProgramExecutor(ProgramProcessor[None, "ProgramExecutor"]): @property def kind(self) -> type[ProgramExecutor]: return ProgramExecutor -def program_executor(func: ProgramProcessorFunction[None]) -> ProgramExecutor: +def program_executor( + func: ProgramProcessorCallable[None], + *, + name: Optional[str] = None, + accept_args: None | int | Literal["all"] = "all", + accept_kwargs: Sequence[str] | None | Literal["all"] = "all", +) -> ProgramExecutor: """ Turn a function that executes a program into a ``ProgramExecutor``. Examples: - --------- - >>> @program_executor - ... def badly_execute(fencil: itir.FencilDefinition, *args, **kwargs) -> None: - ... '''A useless and incorrect fencil executor.''' - ... pass + >>> @program_executor + ... def badly_execute(fencil: itir.FencilDefinition, *args, **kwargs) -> None: + ... '''A useless and incorrect fencil executor.''' + ... pass - >>> ensure_processor_kind(badly_execute, ProgramExecutor) + >>> ensure_processor_kind(badly_execute, ProgramExecutor) """ - # this operation effectively changes the type of func and that is the intention here - func.kind = ProgramExecutor # type: ignore[attr-defined] - return cast(ProgramExecutor, func) + return cast( + ProgramExecutor, + make_program_processor( + func, + ProgramExecutor, + name=name, + accept_args=accept_args, + accept_kwargs=accept_kwargs, + ), + ) def is_processor_kind( @@ -105,3 +196,25 @@ def ensure_processor_kind( ) -> None: if not is_processor_kind(obj, kind): raise TypeError(f"{obj} is not a {kind.__name__}!") + + +class ProgramBackend( + ProgramProcessor[None, "ProgramExecutor"], + next_allocators.FieldBufferAllocatorFactoryProtocol[core_defs.DeviceTypeT], + Protocol[core_defs.DeviceTypeT], +): + ... + + +def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]: + return is_processor_kind( + obj, ProgramExecutor # type: ignore[type-abstract] # ProgramExecutor is abstract + ) and next_allocators.is_field_allocator_factory(obj) + + +def is_program_backend_for( + obj: Callable, device: core_defs.DeviceTypeT +) -> TypeGuard[ProgramBackend[core_defs.DeviceTypeT]]: + return is_processor_kind( + obj, ProgramExecutor # type: ignore[type-abstract] # ProgramExecutor is abstract + ) and next_allocators.is_field_allocator_factory_for(obj, device) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index be63d6809d..9f67cb26da 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -19,7 +19,9 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.transformation.auto import auto_optimize as autoopt +import gt4py.next.allocators as next_allocators import gt4py.next.iterator.ir as itir +import gt4py.next.program_processors.otf_compile_executor as otf_exec from gt4py.next.common import Dimension, Domain, UnitRange, is_field from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms @@ -235,22 +237,43 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: @program_executor -def run_dace(program: itir.FencilDefinition, *args, **kwargs) -> None: - run_on_gpu = any(not isinstance(arg.ndarray, np.ndarray) for arg in args if is_field(arg)) - if run_on_gpu: - if cp is None: - raise RuntimeError( - f"Non-numpy field argument passed to program {program.id} but module cupy not installed" - ) - - if not all(isinstance(arg.ndarray, cp.ndarray) for arg in args if is_field(arg)): - raise RuntimeError("Execution on GPU requires all fields to be stored as cupy arrays") - +def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_iterator( program, *args, **kwargs, - build_cache=_build_cache_gpu if run_on_gpu else _build_cache_cpu, + build_cache=_build_cache_cpu, build_type=_build_type, - run_on_gpu=run_on_gpu, + run_on_gpu=False, ) + + +run_dace_cpu = otf_exec.OTFBackend( + executor=_run_dace_cpu, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), +) + +if cp: + + @program_executor + def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_gpu, + build_type=_build_type, + run_on_gpu=True, + ) + +else: + + @program_executor + def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + raise RuntimeError("Missing `cupy` dependency for GPU execution.") + + +run_dace_gpu = otf_exec.OTFBackend( + executor=_run_dace_gpu, + allocator=next_allocators.StandardGPUFieldBufferAllocator(), +) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 651fb43fa7..2f06d17c7f 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -12,13 +12,25 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any +from __future__ import annotations -from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.processor_interface import program_executor -from gt4py.next.program_processors.runners import roundtrip +from typing import TYPE_CHECKING, Any +import gt4py.next.program_processors.otf_compile_executor as otf_compile_executor +import gt4py.next.program_processors.processor_interface as ppi +import gt4py.next.program_processors.runners.roundtrip as roundtrip -@program_executor + +if TYPE_CHECKING: + import gt4py.next.iterator.ir as itir + + +@ppi.program_executor def executor(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: - roundtrip.executor(program, *args, dispatch_backend=roundtrip.executor, **kwargs) + roundtrip.execute_roundtrip(program, *args, dispatch_backend=roundtrip.executor, **kwargs) + + +backend = otf_compile_executor.OTFBackend( + executor=executor, + allocator=roundtrip.backend.allocator, +) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 35c10fe353..7233e7a893 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -16,7 +16,8 @@ import numpy.typing as npt -from gt4py._core import definitions as core_defs +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash from gt4py.next import common from gt4py.next.iterator.transforms import LiftMode @@ -129,29 +130,66 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: ) -run_gtfn = otf_compile_executor.OTFCompileExecutor( +gtfn_executor = otf_compile_executor.OTFCompileExecutor( name="run_gtfn", otf_workflow=GTFN_DEFAULT_WORKFLOW ) +run_gtfn = otf_compile_executor.OTFBackend( + executor=gtfn_executor, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), +) -run_gtfn_imperative = otf_compile_executor.OTFCompileExecutor( +gtfn_imperative_executor = otf_compile_executor.OTFCompileExecutor( name="run_gtfn_imperative", - otf_workflow=run_gtfn.otf_workflow.replace( - translation=run_gtfn.otf_workflow.translation.replace(use_imperative_backend=True), + otf_workflow=gtfn_executor.otf_workflow.replace( + translation=gtfn_executor.otf_workflow.translation.replace(use_imperative_backend=True), ), ) +run_gtfn_imperative = otf_compile_executor.OTFBackend( + executor=gtfn_imperative_executor, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), +) -run_gtfn_cached = otf_compile_executor.CachedOTFCompileExecutor( +# TODO(ricoh): add API for converting an executor to a cached version of itself and vice versa +gtfn_cached_executor = otf_compile_executor.CachedOTFCompileExecutor( name="run_gtfn_cached", - otf_workflow=workflow.CachedStep(step=run_gtfn.otf_workflow, hash_function=compilation_hash), -) # todo(ricoh): add API for converting an executor to a cached version of itself and vice versa + otf_workflow=workflow.CachedStep( + step=gtfn_executor.otf_workflow, hash_function=compilation_hash + ), +) +run_gtfn_cached = otf_compile_executor.OTFBackend( + executor=gtfn_cached_executor, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), +) + -run_gtfn_gpu = otf_compile_executor.OTFCompileExecutor( +run_gtfn_with_temporaries = otf_compile_executor.OTFBackend( + executor=otf_compile_executor.OTFCompileExecutor( + name="run_gtfn_with_temporaries", + otf_workflow=gtfn_executor.otf_workflow.replace( + translation=gtfn_executor.otf_workflow.translation.replace( + lift_mode=LiftMode.FORCE_TEMPORARIES + ), + ), + ), + allocator=next_allocators.StandardCPUFieldBufferAllocator(), +) + +gtfn_gpu_executor = otf_compile_executor.OTFCompileExecutor( name="run_gtfn_gpu", otf_workflow=GTFN_GPU_WORKFLOW ) +run_gtfn_gpu = otf_compile_executor.OTFBackend( + executor=gtfn_gpu_executor, + allocator=next_allocators.StandardGPUFieldBufferAllocator(), +) + -run_gtfn_with_temporaries = otf_compile_executor.OTFCompileExecutor( - name="run_gtfn_with_temporaries", - otf_workflow=run_gtfn.otf_workflow.replace( - translation=run_gtfn.otf_workflow.translation.replace(lift_mode=LiftMode.FORCE_TEMPORARIES), +gtfn_gpu_cached_executor = otf_compile_executor.CachedOTFCompileExecutor( + name="run_gtfn_gpu_cached", + otf_workflow=workflow.CachedStep( + step=gtfn_gpu_executor.otf_workflow, hash_function=compilation_hash ), ) +run_gtfn_gpu_cached = otf_compile_executor.OTFBackend( + executor=gtfn_gpu_cached_executor, + allocator=next_allocators.StandardGPUFieldBufferAllocator(), +) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 3560384eb4..f81606eec0 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -21,20 +21,25 @@ from collections.abc import Callable, Iterable from typing import Any, Optional -from gt4py.eve import codegen +import gt4py.eve.codegen as codegen +import gt4py.next.allocators as next_allocators +import gt4py.next.common as common +import gt4py.next.iterator.embedded as embedded +import gt4py.next.iterator.ir as itir +import gt4py.next.iterator.transforms as itir_transforms +import gt4py.next.iterator.transforms.global_tmps as gtmps_transform +import gt4py.next.program_processors.otf_compile_executor as otf_compile_executor +import gt4py.next.program_processors.processor_interface as ppi from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from gt4py.next import common -from gt4py.next.iterator import embedded, ir as itir -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries -from gt4py.next.program_processors.processor_interface import program_executor def _create_tmp(axes, origin, shape, dtype): if isinstance(dtype, tuple): return f"({','.join(_create_tmp(axes, origin, shape, dt) for dt in dtype)},)" else: - return f"gtx.np_as_located_field({axes}, origin={origin})(np.empty({shape}, dtype=np.dtype('{dtype}')))" + return ( + f"gtx.as_field([{axes}], np.empty({shape}, dtype=np.dtype('{dtype}')), origin={origin})" + ) class EmbeddedDSL(codegen.TemplatedGenerator): @@ -103,7 +108,7 @@ def visit_Temporary(self, node, **kwargs): def fencil_generator( ir: itir.Node, debug: bool, - lift_mode: LiftMode, + lift_mode: itir_transforms.LiftMode, use_embedded: bool, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], ) -> Callable: @@ -125,7 +130,9 @@ def fencil_generator( if cache_key in _FENCIL_CACHE: return _FENCIL_CACHE[cache_key] - ir = apply_common_transforms(ir, lift_mode=lift_mode, offset_provider=offset_provider) + ir = itir_transforms.apply_common_transforms( + ir, lift_mode=lift_mode, offset_provider=offset_provider + ) program = EmbeddedDSL.apply(ir) @@ -180,8 +187,12 @@ def fencil_generator( if not debug: pathlib.Path(source_file_name).unlink(missing_ok=True) - assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)) - fencil_name = ir.fencil.id + "_wrapper" if isinstance(ir, FencilWithTemporaries) else ir.id + assert isinstance(ir, (itir.FencilDefinition, gtmps_transform.FencilWithTemporaries)) + fencil_name = ( + ir.fencil.id + "_wrapper" + if isinstance(ir, gtmps_transform.FencilWithTemporaries) + else ir.id + ) fencil = getattr(mod, fencil_name) _FENCIL_CACHE[cache_key] = fencil @@ -195,8 +206,8 @@ def execute_roundtrip( column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], debug: bool = False, - lift_mode: LiftMode = LiftMode.FORCE_INLINE, - dispatch_backend: Optional[str] = None, + lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, + dispatch_backend: Optional[ppi.ProgramExecutor] = None, ) -> None: fencil = fencil_generator( ir, @@ -216,6 +227,8 @@ def execute_roundtrip( return fencil(*args, **new_kwargs) -@program_executor -def executor(program: itir.FencilDefinition, *args, **kwargs) -> None: - execute_roundtrip(program, *args, **kwargs) +executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type] + +backend = otf_compile_executor.OTFBackend( + executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() +) diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index adc45efaff..061f79f146 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -22,6 +22,7 @@ import operator import numpy as np +import numpy.typing as npt from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping @@ -34,11 +35,10 @@ Protocol, Sequence, Tuple, + Type, TypeAlias, TypeGuard, - TypeVar, Union, - cast, ) @@ -48,17 +48,16 @@ cp = None -_ScalarT = TypeVar("_ScalarT", bound=core_defs.Scalar) - - _NDBuffer: TypeAlias = Union[ + # xtyping.Buffer, # TODO: add once we update typing_extensions xtyping.ArrayInterface, xtyping.CUDAArrayInterface, xtyping.DLPackBuffer, ] - -#: Tuple of positive integers encoding a permutation of the dimensions. +#: Tuple of positive integers encoding a permutation of the dimensions, such that +#: layout_map[i] = j means that the i-th dimension of the tensor corresponds +#: to the j-th dimension in the (C-layout) buffer. BufferLayoutMap = NewType("BufferLayoutMap", Sequence[core_defs.PositiveIntegral]) @@ -72,7 +71,7 @@ def is_valid_layout_map(value: Sequence[Any]) -> TypeGuard[BufferLayoutMap]: @dataclasses.dataclass(frozen=True) -class TensorBuffer(Generic[core_defs.NDArrayObjectT, _ScalarT]): +class TensorBuffer(Generic[core_defs.DeviceTypeT, core_defs.ScalarT]): """ N-dimensional (tensor-like) memory buffer. @@ -88,9 +87,9 @@ class TensorBuffer(Generic[core_defs.NDArrayObjectT, _ScalarT]): dtype: Data type descriptor. shape: Tuple with lengths of the corresponding tensor dimensions. strides: Tuple with sizes (in bytes) of the steps in each dimension. - layout_map: Tuple with the order of the dimensions in the buffer. + layout_map: Tuple with the order of the dimensions in the buffer layout_map[i] = j means that the i-th dimension of the tensor - corresponds to the j-th dimension in the buffer. + corresponds to the j-th dimension in the (C-layout) buffer. byte_offset: Offset (in bytes) from the beginning of the buffer to the first valid element. byte_alignment: Alignment (in bytes) of the first valid element. @@ -100,37 +99,45 @@ class TensorBuffer(Generic[core_defs.NDArrayObjectT, _ScalarT]): buffer: _NDBuffer = dataclasses.field(hash=False) memory_address: int - device: core_defs.Device - dtype: core_defs.DType[_ScalarT] + device: core_defs.Device[core_defs.DeviceTypeT] + dtype: core_defs.DType[core_defs.ScalarT] shape: core_defs.TensorShape strides: Tuple[int, ...] layout_map: BufferLayoutMap byte_offset: int byte_alignment: int aligned_index: Tuple[int, ...] - ndarray: core_defs.NDArrayObjectT = dataclasses.field(hash=False) + ndarray: core_defs.NDArrayObject = dataclasses.field(hash=False) @property def ndim(self): """Order of the tensor (`len(tensor_buffer.shape)`).""" return len(self.shape) - def __array__(self, dtype: Optional[np.dtype] = None) -> np.ndarray: - if not hasattr(self.ndarray, "__array__"): + def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> np.ndarray: + if not xtyping.supports_array(self.ndarray): raise TypeError("Cannot export tensor buffer as NumPy array.") - return self.ndarray.__array__(dtype=dtype) # type: ignore[call-overload] # TODO(egparades): figure out the mypy fix + return self.ndarray.__array__(dtype) + + @property + def __array_interface__(self) -> dict[str, Any]: + if not xtyping.supports_array_interface(self.ndarray): + raise TypeError("Cannot export tensor buffer to NumPy array interface.") + + return self.ndarray.__array_interface__ @property - def __cuda_array_interface__(self) -> xtyping.CUDAArrayInterfaceTypedDict: - if not hasattr(self.ndarray, "__cuda_array_interface__"): + def __cuda_array_interface__(self) -> dict[str, Any]: + if not xtyping.supports_cuda_array_interface(self.ndarray): raise TypeError("Cannot export tensor buffer to CUDA array interface.") + return self.ndarray.__cuda_array_interface__ - def __dlpack__(self) -> xtyping.PyCapsule: + def __dlpack__(self, *, stream: Optional[int] = None) -> Any: if not hasattr(self.ndarray, "__dlpack__"): raise TypeError("Cannot export tensor buffer to DLPack.") - return self.ndarray.__dlpack__() + return self.ndarray.__dlpack__(stream=stream) # type: ignore[call-arg,arg-type] # stream is not always supported def __dlpack_device__(self) -> xtyping.DLPackDevice: if not hasattr(self.ndarray, "__dlpack_device__"): @@ -138,32 +145,39 @@ def __dlpack_device__(self) -> xtyping.DLPackDevice: return self.ndarray.__dlpack_device__() -class BufferAllocator(Protocol[core_defs.NDArrayObjectT]): +if TYPE_CHECKING: + # TensorBuffer should be compatible with all the expected buffer interfaces + __TensorBufferAsArrayInterfaceT: Type[xtyping.ArrayInterface] = TensorBuffer + __TensorBufferAsCUDAArrayInterfaceT: Type[xtyping.CUDAArrayInterface] = TensorBuffer + __TensorBufferAsDLPackBufferT: Type[xtyping.DLPackBuffer] = TensorBuffer + + +class BufferAllocator(Protocol[core_defs.DeviceTypeT]): """Protocol for buffer allocators.""" @property - def device_type(self) -> core_defs.DeviceType: + def device_type(self) -> core_defs.DeviceTypeT: ... def allocate( self, shape: Sequence[core_defs.IntegralScalar], - dtype: core_defs.DType[_ScalarT], + dtype: core_defs.DType[core_defs.ScalarT], + device_id: int, layout_map: BufferLayoutMap, - device: core_defs.Device, byte_alignment: int, aligned_index: Optional[Sequence[int]] = None, - ) -> TensorBuffer[core_defs.NDArrayObjectT, _ScalarT]: + ) -> TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: """ Allocate a TensorBuffer with the given shape, layout and alignment settings. Args: - device: Device where the buffer is allocated. - dtype: Data type descriptor. shape: Tensor dimensions. - layout_map: layout of the dimensions in the buffer. - layout_map[i] = j means that the i-th dimension of the tensor - corresponds to the j-th dimension of the buffer. + dtype: Data type descriptor. + layout_map: layout of the dimensions in a buffer with C-layout (contiguous dimension is last). + layout_map[i] = j means that the i-th dimension of the tensor + corresponds to the j-th dimension of the buffer. + device_id: Id of the device of `device_type` where the buffer is allocated. byte_alignment: Alignment (in bytes) of the first valid element. aligned_index: N-dimensional index of the first aligned element. """ @@ -171,18 +185,23 @@ def allocate( @dataclasses.dataclass(frozen=True, init=False) -class _BaseNDArrayBufferAllocator(abc.ABC, Generic[core_defs.NDArrayObjectT]): +class _BaseNDArrayBufferAllocator(abc.ABC, Generic[core_defs.DeviceTypeT]): """Base class for buffer allocators using NumPy-like modules.""" + @property + @abc.abstractmethod + def device_type(self) -> core_defs.DeviceTypeT: + pass + def allocate( self, shape: Sequence[core_defs.IntegralScalar], - dtype: core_defs.DType[_ScalarT], + dtype: core_defs.DType[core_defs.ScalarT], + device_id: int, layout_map: BufferLayoutMap, - device: core_defs.Device, byte_alignment: int, aligned_index: Optional[Sequence[int]] = None, - ) -> TensorBuffer[core_defs.NDArrayObjectT, _ScalarT]: + ) -> TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: if not core_defs.is_valid_tensor_shape(shape): raise ValueError(f"Invalid shape {shape}") ndim = len(shape) @@ -221,7 +240,7 @@ def allocate( strides = tuple(strides_lst) # Allocate total size - buffer = self.raw_alloc(total_length, device) + buffer = self.malloc(total_length, device_id) memory_address = self.array_ns.byte_bounds(buffer)[0] # Compute final byte offset to align the requested buffer index @@ -247,7 +266,7 @@ def allocate( buffer, dtype, shape, padded_shape, item_size, strides, byte_offset ) - if device.device_type == core_defs.DeviceType.ROCM: + if self.device_type == core_defs.DeviceType.ROCM: # until we can rely on dlpack ndarray.__hip_array_interface__ = { # type: ignore[attr-defined] "shape": ndarray.shape, # type: ignore[union-attr] @@ -262,7 +281,7 @@ def allocate( return TensorBuffer( buffer=buffer, memory_address=memory_address, - device=device, + device=core_defs.Device(self.device_type, device_id), dtype=dtype, shape=shape, strides=strides, @@ -275,77 +294,101 @@ def allocate( @property @abc.abstractmethod - def array_ns(self) -> _NumPyLikeNamespace[core_defs.NDArrayObjectT]: + def array_ns(self) -> ValidNumPyLikeAllocationNS: pass @abc.abstractmethod - def raw_alloc(self, length: int, device: core_defs.Device) -> _NDBuffer: + def malloc(self, length: int, device_id: int) -> _NDBuffer: pass @abc.abstractmethod def tensorize( self, buffer: _NDBuffer, - dtype: core_defs.DType[_ScalarT], + dtype: core_defs.DType[core_defs.ScalarT], shape: core_defs.TensorShape, allocated_shape: core_defs.TensorShape, item_size: int, strides: Sequence[int], byte_offset: int, - ) -> core_defs.NDArrayObjectT: + ) -> core_defs.NDArrayObject: pass -if TYPE_CHECKING: +class ValidNumPyLikeAllocationNS(Protocol): + class _NumPyLibModule(Protocol): + class _NumPyLibStridesModule(Protocol): + @staticmethod + def as_strided( + ndarray: core_defs.NDArrayObject, **kwargs: Any + ) -> core_defs.NDArrayObject: + ... + + stride_tricks: _NumPyLibStridesModule + + lib: _NumPyLibModule + + @staticmethod + def empty(shape: Tuple[int, ...], dtype: Any) -> _NDBuffer: + ... - class _NumPyLikeNamespace(Protocol[core_defs.NDArrayObjectT]): - class _NumPyLibModule(Protocol): - class _NumPyLibStridesModule(Protocol): - def as_strided( - self, ndarray: core_defs.NDArrayObjectT, **kwargs: Any - ) -> core_defs.NDArrayObjectT: - ... + @staticmethod + def byte_bounds(ndarray: _NDBuffer) -> Tuple[int, int]: + ... - stride_tricks: _NumPyLibStridesModule - lib: _NumPyLibModule +def is_valid_nplike_allocation_ns(obj: Any) -> TypeGuard[ValidNumPyLikeAllocationNS]: + return ( + len(required_keys := {"empty", "byte_bounds", "lib"} & set(dir(np))) == len(required_keys) + and "stride_tricks" in dir(np.lib) + and "as_strided" in dir(np.lib.stride_tricks) + ) - def empty(self, shape: core_defs.TensorShape, dtype: np.dtype) -> core_defs.NDArrayObjectT: - ... - def byte_bounds(self, ndarray: _NDBuffer) -> tuple[int, int]: - ... +if not TYPE_CHECKING: + is_valid_nplike_allocation_ns = functools.lru_cache(maxsize=None)(is_valid_nplike_allocation_ns) -@dataclasses.dataclass(frozen=True) -class NumPyLikeArrayBufferAllocator(_BaseNDArrayBufferAllocator[core_defs.NDArrayObjectT]): - device_type: core_defs.DeviceType - array_ns_ref: _NumPyLikeNamespace[core_defs.NDArrayObjectT] +@dataclasses.dataclass(frozen=True, init=False) +class NDArrayBufferAllocator(_BaseNDArrayBufferAllocator[core_defs.DeviceTypeT]): + _device_type: core_defs.DeviceTypeT + _array_ns: ValidNumPyLikeAllocationNS + + def __init__( + self, + device_type: core_defs.DeviceTypeT, + array_ns: ValidNumPyLikeAllocationNS, + ): + object.__setattr__(self, "_device_type", device_type) + object.__setattr__(self, "_array_ns", array_ns) + + @property + def device_type(self) -> core_defs.DeviceTypeT: + return self._device_type @property - def array_ns(self) -> _NumPyLikeNamespace[core_defs.NDArrayObjectT]: - return self.array_ns_ref + def array_ns(self) -> ValidNumPyLikeAllocationNS: + return self._array_ns - def raw_alloc(self, length: int, device: core_defs.Device) -> _NDBuffer: - if device.device_type != core_defs.DeviceType.CPU and device.device_id != 0: - raise ValueError(f"Unsupported device {device} for memory allocation") + def malloc(self, length: int, device_id: int) -> _NDBuffer: + if self.device_type == core_defs.DeviceType.CPU and device_id != 0: + raise ValueError(f"Unsupported device ID {device_id} for CPU memory allocation") shape = (length,) assert core_defs.is_valid_tensor_shape(shape) # for mypy - return cast( - _NDBuffer, self.array_ns.empty(shape=shape, dtype=np.dtype(np.uint8)) - ) # TODO(havogt): figure out how we type this properly + out = self.array_ns.empty(shape=tuple(shape), dtype=np.dtype(np.uint8)) + return out def tensorize( self, buffer: _NDBuffer, - dtype: core_defs.DType[_ScalarT], + dtype: core_defs.DType[core_defs.ScalarT], shape: core_defs.TensorShape, allocated_shape: core_defs.TensorShape, item_size: int, strides: Sequence[int], byte_offset: int, - ) -> core_defs.NDArrayObjectT: + ) -> core_defs.NDArrayObject: aligned_buffer = buffer[byte_offset : byte_offset + math.prod(allocated_shape) * item_size] # type: ignore[index] # TODO(egparedes): should we extend `_NDBuffer`s to cover __getitem__? flat_ndarray = aligned_buffer.view(dtype=np.dtype(dtype)) tensor_view = self.array_ns.lib.stride_tricks.as_strided( @@ -356,53 +399,3 @@ def tensorize( tensor_view = tensor_view[shape_slices] return tensor_view - - -#: Registry of allocators for each device type. -device_allocators: dict[core_defs.DeviceType, BufferAllocator] = {} - -device_allocators[core_defs.DeviceType.CPU] = NumPyLikeArrayBufferAllocator( - device_type=core_defs.DeviceType.CPU, - array_ns_ref=cast(_NumPyLikeNamespace, np) if TYPE_CHECKING else np, -) - -if cp: - device_allocators[core_defs.DeviceType.CUDA] = NumPyLikeArrayBufferAllocator( - device_type=core_defs.DeviceType.CUDA, - array_ns_ref=cp, - ) - device_allocators[core_defs.DeviceType.ROCM] = NumPyLikeArrayBufferAllocator( - device_type=core_defs.DeviceType.ROCM, - array_ns_ref=cp, - ) - - -def allocate( - shape: Sequence[core_defs.IntegralScalar], - dtype: core_defs.DType[_ScalarT], - layout_map: BufferLayoutMap, - *, - byte_alignment: int, - aligned_index: Optional[Sequence[int]] = None, - device: Optional[core_defs.Device] = None, - allocator: Optional[BufferAllocator] = None, -) -> TensorBuffer: - """Allocate a TensorBuffer with the given settings on the given device.""" - if device is None and allocator is None: - raise ValueError("No 'device' or 'allocator' specified") - if device is None: - assert allocator is not None # for mypy - device = core_defs.Device(allocator.device_type, 0) - assert device is not None # for mypy - allocator = allocator or device_allocators[device.device_type] - if device.device_type != allocator.device_type: - raise ValueError(f"Device {device} and allocator {allocator} are incompatible") - - return allocator.allocate( - shape=shape, - dtype=dtype, - layout_map=layout_map, - byte_alignment=byte_alignment, - aligned_index=aligned_index, - device=device, - ) diff --git a/src/gt4py/storage/cartesian/interface.py b/src/gt4py/storage/cartesian/interface.py index 6e19b9d771..517593dd38 100644 --- a/src/gt4py/storage/cartesian/interface.py +++ b/src/gt4py/storage/cartesian/interface.py @@ -15,7 +15,7 @@ from __future__ import annotations import numbers -from typing import Any, Optional, Protocol, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import numpy as np @@ -33,20 +33,7 @@ except ImportError: dace = None -if np.lib.NumpyVersion(np.__version__) >= "1.20.0": - from numpy.typing import ArrayLike, DTypeLike -else: - ArrayLike = Any # type: ignore[misc] # assign multiple types in both branches - DTypeLike = Any # type: ignore[misc] # assign multiple types in both branches - - -# Protocols -class GTDimsInterface(Protocol): - __gt_dims__: Tuple[str, ...] - - -class GTOriginInterface(Protocol): - __gt_origin__: Tuple[int, ...] +from numpy.typing import ArrayLike, DTypeLike # Helper functions diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index e6060328ff..0f7cf5d0ab 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -17,7 +17,7 @@ import collections.abc import math import numbers -from typing import Any, Literal, Optional, Sequence, Tuple, Union, cast +from typing import Any, Final, Literal, Optional, Sequence, Tuple, Union, cast import numpy as np import numpy.typing as npt @@ -39,8 +39,34 @@ cp = None +CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( + None + if not cp + else (core_defs.DeviceType.ROCM if cp.cuda.get_hipcc_path() else core_defs.DeviceType.CUDA) +) + + FieldLike = Union["cp.ndarray", np.ndarray, ArrayInterface, CUDAArrayInterface] +assert allocators.is_valid_nplike_allocation_ns(np) + +_CPUBufferAllocator = allocators.NDArrayBufferAllocator( + device_type=core_defs.DeviceType.CPU, + array_ns=np, +) + +_GPUBufferAllocator: Optional[allocators.NDArrayBufferAllocator] = None +if cp: + assert allocators.is_valid_nplike_allocation_ns(cp) + if CUPY_DEVICE == core_defs.DeviceType.CUDA: + _GPUBufferAllocator = allocators.NDArrayBufferAllocator( + device_type=core_defs.DeviceType.CUDA, array_ns=cp + ) + else: + _GPUBufferAllocator = allocators.NDArrayBufferAllocator( + device_type=core_defs.DeviceType.ROCM, array_ns=cp + ) + def _idx_from_order(order): return list(np.argsort(order)) @@ -201,15 +227,15 @@ def allocate_cpu( aligned_index: Optional[Sequence[int]], ) -> Tuple[allocators._NDBuffer, np.ndarray]: device = core_defs.Device(core_defs.DeviceType.CPU, 0) - buffer = allocators.allocate( + buffer = _CPUBufferAllocator.allocate( shape, core_defs.dtype(dtype), + device_id=device.device_id, layout_map=layout_map, - device=device, byte_alignment=alignment_bytes, aligned_index=aligned_index, ) - return buffer.buffer, buffer.ndarray + return buffer.buffer, cast(np.ndarray, buffer.ndarray) def allocate_gpu( @@ -219,15 +245,16 @@ def allocate_gpu( alignment_bytes: int, aligned_index: Optional[Sequence[int]], ) -> Tuple["cp.ndarray", "cp.ndarray"]: - device = core_defs.Device( + assert _GPUBufferAllocator is not None, "GPU allocation library or device not found" + device = core_defs.Device( # type: ignore[type-var] core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA, 0 ) - buffer = allocators.allocate( + buffer = _GPUBufferAllocator.allocate( shape, core_defs.dtype(dtype), + device_id=device.device_id, layout_map=layout_map, - device=device, byte_alignment=alignment_bytes, aligned_index=aligned_index, ) - return buffer.buffer, buffer.ndarray + return buffer.buffer, cast("cp.ndarray", buffer.ndarray) diff --git a/tests/eve_tests/unit_tests/test_extended_typing.py b/tests/eve_tests/unit_tests/test_extended_typing.py index da3cbbaeda..733e12577c 100644 --- a/tests/eve_tests/unit_tests/test_extended_typing.py +++ b/tests/eve_tests/unit_tests/test_extended_typing.py @@ -232,6 +232,69 @@ def test_subclass_check_with_data_members(self, sample_class_defs): assert issubclass(ConcreteClass, NoDataProto) +def test_supports_array_interface(): + from gt4py.eve.extended_typing import supports_array_interface + + class ArrayInterface: + __array_interface__ = "interface" + + class NoArrayInterface: + pass + + assert supports_array_interface(ArrayInterface()) + assert not supports_array_interface(NoArrayInterface()) + assert not supports_array_interface("array") + assert not supports_array_interface(None) + + +def test_supports_cuda_array_interface(): + from gt4py.eve.extended_typing import supports_cuda_array_interface + + class CudaArray: + def __cuda_array_interface__(self): + return {} + + class NoCudaArray: + pass + + assert supports_cuda_array_interface(CudaArray()) + assert not supports_cuda_array_interface(NoCudaArray()) + assert not supports_cuda_array_interface("cuda") + assert not supports_cuda_array_interface(None) + + +def test_supports_dlpack(): + from gt4py.eve.extended_typing import supports_dlpack + + class DummyDLPackBuffer: + def __dlpack__(self): + pass + + def __dlpack_device__(self): + pass + + class DLPackBufferWithWrongBufferMethod: + __dlpack__ = "buffer" + + def __dlpack_device__(self): + pass + + class DLPackBufferWithoutDevice: + def __dlpack__(self): + pass + + class DLPackBufferWithWrongDevice: + def __dlpack__(self): + pass + + __dlpack_device__ = "device" + + assert supports_dlpack(DummyDLPackBuffer()) + assert not supports_dlpack(DLPackBufferWithWrongBufferMethod()) + assert not supports_dlpack(DLPackBufferWithoutDevice()) + assert not supports_dlpack(DLPackBufferWithWrongDevice()) + + @pytest.mark.parametrize("t", (int, float, dict, tuple, frozenset, collections.abc.Mapping)) def test_is_actual_valid_type(t): assert xtyping.is_actual_type(t) diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index fda69d75d9..99513ba175 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -137,6 +137,21 @@ def unique_data_items(request): ] +def test_fluid_partial(): + from gt4py.eve.utils import fluid_partial + + def func(a, b, c): + return a + b + c + + fp1 = fluid_partial(func, 1) + fp2 = fp1.partial(2) + fp3 = fp2.partial(3) + + assert fp1(2, 3) == 6 + assert fp2(3) == 6 + assert fp3() == 6 + + def test_noninstantiable_class(): @eve.utils.noninstantiable class NonInstantiableClass(eve.datamodels.DataModel): diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index 54bc4d9c69..e2905ab49a 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -23,4 +23,8 @@ def get_processor_id(processor): module_path = processor.__module__.split(".")[-1] name = processor.__name__ return f"{module_path}.{name}" + elif hasattr(processor, "__module__") and hasattr(processor, "__class__"): + module_path = processor.__module__.split(".")[-1] + name = processor.__class__.__name__ + return f"{module_path}.{name}" return repr(processor) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 98ac9352c3..ddea04649f 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -11,21 +11,73 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import pytest - """Contains definition of test-exclusion matrices, see ADR 15.""" +import enum +import importlib + +import pytest + + # Skip definitions XFAIL = pytest.xfail SKIP = pytest.skip -# Processor ids as returned by next_tests.get_processor_id() -DACE = "dace_iterator.run_dace_iterator" -GTFN_CPU = "otf_compile_executor.run_gtfn" -GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" -GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" -GTFN_FORMAT_SOURCECODE = "gtfn.format_sourcecode" + +# Program processors +class _PythonObjectIdMixin: + # Only useful for classes inheriting from (str, enum.Enum) + def __str__(self) -> str: + assert isinstance(self.value, str) + return self.value + + def load(self) -> object: + *mods, obj = self.value.split(".") + globs = {"_m": importlib.import_module(".".join(mods))} + obj = eval(f"_m.{obj}", globs) + return obj + + __invert__ = load + + def short_id(self, num_components: int = 2) -> str: + return ".".join(self.value.split(".")[-num_components:]) + + +class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): + GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn" + GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative" + GTFN_CPU_WITH_TEMPORARIES = ( + "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" + ) + ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.backend" + DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" + + +class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): + DACE_CPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_cpu" + + +class ProgramExecutorId(_PythonObjectIdMixin, str, enum.Enum): + GTFN_CPU_EXECUTOR = f"{ProgramBackendId.GTFN_CPU}.executor" + GTFN_CPU_IMPERATIVE_EXECUTOR = f"{ProgramBackendId.GTFN_CPU_IMPERATIVE}.executor" + GTFN_CPU_WITH_TEMPORARIES = f"{ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES}.executor" + ROUNDTRIP = f"{ProgramBackendId.ROUNDTRIP}.executor" + DOUBLE_ROUNDTRIP = f"{ProgramBackendId.DOUBLE_ROUNDTRIP}.executor" + + +class OptionalProgramExecutorId(_PythonObjectIdMixin, str, enum.Enum): + DACE_CPU_EXECUTOR = f"{OptionalProgramBackendId.DACE_CPU}.executor" + + +class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + GTFN_CPP_FORMATTER = "gt4py.next.program_processors.formatters.gtfn.format_cpp" + ITIR_PRETTY_PRINTER = ( + "gt4py.next.program_processors.formatters.pretty_print.format_itir_and_check" + ) + ITIR_TYPE_CHECKER = "gt4py.next.program_processors.formatters.type_check.check_type_inference" + LISP_FORMATTER = "gt4py.next.program_processors.formatters.lisp.format_lisp" + # Test markers REQUIRES_ATLAS = "requires_atlas" @@ -66,7 +118,7 @@ #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) BACKEND_SKIP_TEST_MATRIX = { - DACE: GTFN_SKIP_TEST_LIST + OptionalProgramBackendId.DACE_CPU: GTFN_SKIP_TEST_LIST + [ (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), @@ -80,20 +132,20 @@ (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ], - GTFN_CPU: GTFN_SKIP_TEST_LIST + ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], - GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], - GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], - GTFN_FORMAT_SOURCECODE: [ + ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], } diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index ee0074e65f..634d85e64c 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -27,7 +27,7 @@ import gt4py.next as gtx from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common +from gt4py.next import common, constructors from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -129,12 +129,15 @@ def scalar_value(self) -> ScalarValue: def field( self, - backend: ppi.ProgramProcessor, + backend: ppi.ProgramExecutor, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, ) -> FieldValue: - return gtx.np_as_located_field(*sizes.keys())( - np.full(tuple(sizes.values()), self.value, dtype=dtype) + return constructors.full( + domain=common.domain(sizes), + fill_value=self.value, + dtype=dtype, + allocator=backend, ) @@ -155,7 +158,7 @@ def scalar_value(self) -> ScalarValue: def field( self, - backend: ppi.ProgramProcessor, + backend: ppi.ProgramExecutor, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, ) -> FieldValue: @@ -164,7 +167,9 @@ def field( f"`IndexInitializer` only supports fields with a single `Dimension`, got {sizes}." ) n_data = list(sizes.values())[0] - return gtx.np_as_located_field(*sizes.keys())(np.arange(0, n_data, dtype=dtype)) + return constructors.as_field( + domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=backend + ) def from_case( self: Self, @@ -202,8 +207,10 @@ def field( svals = tuple(sizes.values()) n_data = int(np.prod(svals)) self.start += n_data - return gtx.np_as_located_field(*sizes.keys())( - np.arange(start, start + n_data, dtype=dtype).reshape(svals) + return constructors.as_field( + common.domain(sizes), + np.arange(start, start + n_data, dtype=dtype).reshape(svals), + allocator=backend, ) def from_case( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 93296ae85f..386e64451d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -34,6 +34,7 @@ raise e import next_tests +import next_tests.exclusion_matrices as definitions def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: @@ -43,18 +44,18 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non OPTIONAL_PROCESSORS = [] if dace_iterator: - OPTIONAL_PROCESSORS.append(dace_iterator.run_dace_iterator) + OPTIONAL_PROCESSORS.append(definitions.OptionalProgramBackendId.DACE_CPU) @pytest.fixture( params=[ - roundtrip.executor, - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, + definitions.ProgramBackendId.ROUNDTRIP, + definitions.ProgramBackendId.GTFN_CPU, + definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, + definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, ] + OPTIONAL_PROCESSORS, - ids=lambda p: next_tests.get_processor_id(p), + ids=lambda p: p.short_id() if p is not None else "None", ) def fieldview_backend(request): """ @@ -63,16 +64,20 @@ def fieldview_backend(request): Notes: Check ADR 15 for details on the test-exclusion matrices. """ - backend = request.param - backend_id = next_tests.get_processor_id(backend) + backend_id = request.param + if backend_id is None: + backend = None + else: + backend = backend_id.load() + + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) - for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( - backend_id, [] - ): - if request.node.get_closest_marker(marker): - skip_mark(msg.format(marker=marker, backend=backend_id)) + backup_backend = decorator.DEFAULT_BACKEND - backup_backend = decorator.DEFAULT_BACKEND decorator.DEFAULT_BACKEND = no_backend yield backend decorator.DEFAULT_BACKEND = backup_backend @@ -203,8 +208,8 @@ def reduction_setup(): C2V=gtx.FieldOffset("C2V", source=Vertex, target=(Cell, c2vdim)), C2E=gtx.FieldOffset("C2E", source=Edge, target=(Cell, c2edim)), # inp=gtx.index_field(edge, dtype=np.int64), # TODO enable once we support gtx.index_fields in bindings - inp=gtx.np_as_located_field(Edge)(np.arange(num_edges, dtype=np.int32)), - out=gtx.np_as_located_field(Vertex)(np.zeros([num_vertices], dtype=np.int32)), + inp=gtx.as_field([Edge], np.arange(num_edges, dtype=np.int32)), + out=gtx.as_field([Vertex], np.zeros([num_vertices], dtype=np.int32)), offset_provider={ "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f974e07ad8..d381a2242a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -506,7 +506,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[IDim], 1) - out = gtx.np_as_located_field(KDim)(np.zeros((cartesian_case.default_sizes[KDim],))) + out = gtx.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) if not forward: expected = np.flip(expected) @@ -637,8 +637,8 @@ def simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 k_size = cartesian_case.default_sizes[KDim] - a = gtx.np_as_located_field(KDim)(4.0 * np.ones((k_size,))) - out = gtx.np_as_located_field(KDim)(np.zeros((k_size,))) + a = gtx.as_field([KDim], 4.0 * np.ones((k_size,))) + out = gtx.as_field([KDim], np.zeros((k_size,))) cases.verify( cartesian_case, @@ -685,9 +685,9 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] - inp1 = gtx.np_as_located_field(KDim)(np.ones((k_size,))) - inp2 = gtx.np_as_located_field(KDim)(np.arange(0.0, k_size, 1)) - out = gtx.np_as_located_field(KDim)(np.zeros((k_size,))) + inp1 = gtx.as_field([KDim], np.ones((k_size,))) + inp2 = gtx.as_field([KDim], np.arange(0.0, k_size, 1)) + out = gtx.as_field([KDim], np.zeros((k_size,))) def prev_levels_iterator(i): return range(i + 1) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index dbc35ddfdf..04b27c6c17 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -35,7 +35,7 @@ def testee( inp * ones(V2E), axis=V2EDim ) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported - inp = gtx.np_as_located_field(Vertex, V2EDim)(unstructured_case.offset_provider["V2E"].table) + inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() cases.verify( @@ -56,7 +56,7 @@ def test_external_local_field_only(unstructured_case): def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32]: return neighbor_sum(inp, axis=V2EDim) - inp = gtx.np_as_located_field(Vertex, V2EDim)(unstructured_case.offset_provider["V2E"].table) + inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) cases.verify( unstructured_case, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py index 381cc740c5..80e9a8e07a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py @@ -15,7 +15,6 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator import embedded from gt4py.next.program_processors.runners import dace_iterator, gtfn from next_tests.integration_tests import cases @@ -26,9 +25,9 @@ @pytest.mark.requires_gpu -@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace, gtfn.run_gtfn_gpu]) +@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace_gpu, gtfn.run_gtfn_gpu]) def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures - import cupy as cp # TODO(ricoh): replace with storages solution when available + import cupy as cp @gtx.field_operator(backend=fieldview_backend) def testee(a: cases.IJKField) -> cases.IJKField: @@ -36,8 +35,17 @@ def testee(a: cases.IJKField) -> cases.IJKField: inp_arr = cp.full(shape=(3, 4, 5), fill_value=3, dtype=cp.int32) outp_arr = cp.zeros_like(inp_arr) - inp = embedded.np_as_located_field(cases.IDim, cases.JDim, cases.KDim)(inp_arr) - outp = embedded.np_as_located_field(cases.IDim, cases.JDim, cases.KDim)(outp_arr) + inp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], inp_arr) + outp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], outp_arr) testee(inp, out=outp, offset_provider={}) assert cp.allclose(inp_arr, outp_arr) + + inp_field = gtx.full( + [cases.IDim, cases.JDim, cases.KDim], fill_value=3, allocator=fieldview_backend + ) + out_field = gtx.zeros( + [cases.IDim, cases.JDim, cases.KDim], outp_arr, allocator=fieldview_backend + ) + testee(inp_field, out=out_field, offset_provider={}) + assert cp.allclose(inp_field.ndarray, out_field.ndarray) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 56d5e35b3a..8213f54a45 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -267,7 +267,7 @@ def conditional_program( conditional_shifted(mask, a, b, out=out) size = cartesian_case.default_sizes[IDim] + 1 - mask = gtx.np_as_located_field(IDim)(np.random.choice(a=[False, True], size=(size))) + mask = gtx.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_program, "a").extend({IDim: (0, 1)})() b = cases.allocate(cartesian_case, conditional_program, "b").extend({IDim: (0, 1)})() out = cases.allocate(cartesian_case, conditional_shifted, cases.RETURN)() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index f7121dc82f..a5d2b92719 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from gt4py.next import np_as_located_field +import gt4py.next as gtx from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.decorator import FieldOperator from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction @@ -122,9 +122,9 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp else: ref_impl: Callable = getattr(np, builtin_name) - inps = [np_as_located_field(IDim)(np.asarray(input)) for input in inputs] + inps = [gtx.as_field([IDim], np.asarray(input)) for input in inputs] expected = ref_impl(*inputs) - out = np_as_located_field(IDim)(np.zeros_like(expected)) + out = gtx.as_field([IDim], np.zeros_like(expected)) builtin_field_op = make_builtin_field_operator(builtin_name).with_backend( cartesian_case.backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 034ce56fee..5a277f9440 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -90,7 +90,7 @@ def test_mod(cartesian_case): def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 - inp1 = gtx.np_as_located_field(IDim)(np.asarray(range(10), dtype=int32) - 5) + inp1 = gtx.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)() cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index d86bc21679..7a1c827a0d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -214,7 +214,7 @@ def prog( def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) - inp = gtx.np_as_located_field(JDim)(np.ones((cartesian_case.default_sizes[JDim],))) + inp = gtx.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() with pytest.raises(TypeError) as exc_info: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index f9fd2c1353..e9c3ac8d19 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -18,7 +18,8 @@ import numpy as np import pytest -from gt4py.next import Field, errors, field_operator, float64, index_field, np_as_located_field +import gt4py.next as gtx +from gt4py.next import Field, errors, field_operator, float64, index_field from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index e2bbbaa553..d5d57c9024 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -20,6 +20,7 @@ import pytest import gt4py.next as gtx +import gt4py.next.program_processors.processor_interface as ppi from gt4py.next.iterator import builtins as it_builtins from gt4py.next.iterator.builtins import ( and_, @@ -58,7 +59,7 @@ from next_tests.unit_tests.conftest import program_processor, run_processor -def asarray(*lists): +def array_maker(*lists): def _listify(val): if isinstance(val, Iterable): return val @@ -72,8 +73,8 @@ def _listify(val): IDim = gtx.Dimension("IDim") -def asfield(*arrays): - res = list(map(gtx.np_as_located_field(IDim), arrays)) +def field_maker(*arrays): + res = list(map(gtx.as_field.partial([IDim]), arrays)) return res @@ -171,8 +172,8 @@ def arithmetic_and_logical_test_data(): def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, expected, as_column): program_processor, validate = program_processor - inps = asfield(*asarray(*inputs)) - out = asfield((np.zeros_like(*asarray(expected))))[0] + inps = field_maker(*array_maker(*inputs)) + out = field_maker((np.zeros_like(*array_maker(expected))))[0] fencil(builtin, out, *inps, processor=program_processor, as_column=as_column) @@ -184,13 +185,16 @@ def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, exp def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): if builtin == if_: pytest.skip("If cannot be used unapplied") - inps = asfield(*asarray(*inputs)) - out = asfield((np.zeros_like(*asarray(expected))))[0] + inps = field_maker(*array_maker(*inputs)) + out = field_maker((np.zeros_like(*array_maker(expected))))[0] + gtfn_executor = run_gtfn.executor gtfn_without_transforms = dataclasses.replace( - run_gtfn, - otf_workflow=run_gtfn.otf_workflow.replace( - translation=run_gtfn.otf_workflow.translation.replace(enable_itir_transforms=False), + gtfn_executor, + otf_workflow=gtfn_executor.otf_workflow.replace( + translation=gtfn_executor.otf_workflow.translation.replace( + enable_itir_transforms=False + ), ), ) # avoid inlining the function fencil(builtin, out, *inps, processor=gtfn_without_transforms) @@ -202,6 +206,7 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(program_processor, builtin_name, inputs, as_column): program_processor, validate = program_processor + # validate = ppi.is_program_backend(program_processor) if builtin_name == "gamma": # numpy has no gamma function @@ -209,10 +214,10 @@ def test_math_function_builtins(program_processor, builtin_name, inputs, as_colu else: ref_impl: Callable = getattr(np, builtin_name) - inps = asfield(*asarray(*inputs)) + inps = field_maker(*array_maker(*inputs)) expected = ref_impl(*inputs) - out = asfield((np.zeros_like(*asarray(expected))))[0] + out = field_maker((np.zeros_like(*array_maker(expected))))[0] fencil( getattr(it_builtins, builtin_name), @@ -251,8 +256,8 @@ def test_can_deref(program_processor, stencil): Node = gtx.Dimension("Node") - inp = gtx.np_as_located_field(Node)(np.ones((1,), dtype=np.int32)) - out = gtx.np_as_located_field(Node)(np.asarray([0], dtype=inp.dtype)) + inp = gtx.as_field([Node], np.ones((1,), dtype=np.int32)) + out = gtx.as_field([Node], np.asarray([0], dtype=inp.dtype)) no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[-1]]), Node, Node, 1) run_processor( @@ -290,8 +295,8 @@ def test_can_deref(program_processor, stencil): # shifted = shift(Neighbor, 0)(inp) # return if_(can_deref(shifted), 1, -1) -# inp = gtx.np_as_located_field(Node)(np.zeros((1,))) -# out = gtx.np_as_located_field(Node)(np.asarray([0])) +# inp = gtx.as_field([Node], np.zeros((1,))) +# out = gtx.as_field([Node], np.asarray([0])) # no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[None]]), Node, Node, 1) # _can_deref[{Node: range(1)}]( @@ -324,7 +329,7 @@ def test_cast(program_processor, as_column, input_value, dtype, np_dtype): program_processor, validate = program_processor column_axis = IDim if as_column else None - inp = asfield(np.array([input_value]))[0] + inp = field_maker(np.array([input_value]))[0] casted_valued = np_dtype(input_value) @@ -332,7 +337,7 @@ def test_cast(program_processor, as_column, input_value, dtype, np_dtype): def sten_cast(it, casted_valued): return eq(cast_(deref(it), dtype), deref(casted_valued)) - out = asfield(np.zeros_like(inp, dtype=builtins.bool))[0] + out = field_maker(np.zeros_like(inp, dtype=builtins.bool))[0] run_processor( sten_cast[{IDim: range(1)}], program_processor, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 05a7d4d9df..5c80d9e415 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -52,8 +52,8 @@ def fencil_swapped(output, input): def test_cartesian_offset_provider(): - inp = gtx.np_as_located_field(I_loc, J_loc)(np.asarray([[0, 42], [1, 43]])) - out = gtx.np_as_located_field(I_loc, J_loc)(np.asarray([[-1]])) + inp = gtx.as_field([I_loc, J_loc], np.asarray([[0, 42], [1, 43]])) + out = gtx.as_field([I_loc, J_loc], np.asarray([[-1]])) fencil(out, inp) assert out[0][0] == 42 diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index c2517f1a07..de7ebf2869 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -37,8 +37,8 @@ def test_conditional_w_tuple(program_processor): shape = [5] - inp = gtx.np_as_located_field(IDim)(np.random.randint(0, 2, shape, dtype=np.int32)) - out = gtx.np_as_located_field(IDim)(np.zeros(shape)) + inp = gtx.as_field([IDim], np.random.randint(0, 2, shape, dtype=np.int32)) + out = gtx.as_field([IDim], np.zeros(shape)) dom = { IDim: range(0, shape[0]), diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py index c2d7ed5e59..83a86319b4 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py @@ -30,8 +30,8 @@ def constant_stencil(): # this is traced as a lambda, TODO directly feed iterat return deref(inp) + deref(lift(constant_stencil)()) - inp = gtx.np_as_located_field(IDim)(np.asarray([0, 42], dtype=np.int32)) - res = gtx.np_as_located_field(IDim)(np.zeros_like(inp)) + inp = gtx.as_field([IDim], np.asarray([0, 42], dtype=np.int32)) + res = gtx.as_field([IDim], np.zeros_like(inp)) add_constant[{IDim: range(2)}](inp, out=res, offset_provider={}, backend=roundtrip.executor) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index 75b935677b..f9bd2cc33b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -31,9 +31,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import fundef, offset from gt4py.next.program_processors.formatters import type_check -from gt4py.next.program_processors.formatters.gtfn import ( - format_sourcecode as gtfn_format_sourcecode, -) +from gt4py.next.program_processors.formatters.gtfn import format_cpp as gtfn_format_sourcecode from next_tests.integration_tests.cases import IDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -57,7 +55,7 @@ def test_simple_indirection(program_processor): program_processor, validate = program_processor if program_processor in [ - type_check.check, + type_check.check_type_inference, gtfn_format_sourcecode, ]: pytest.xfail( @@ -65,10 +63,10 @@ def test_simple_indirection(program_processor): ) # TODO fix test or generalize itir? shape = [8] - inp = gtx.np_as_located_field(IDim, origin={IDim: 1})(np.arange(0, shape[0] + 2)) + inp = gtx.as_field([IDim], np.arange(0, shape[0] + 2), origin={IDim: 1}) rng = np.random.default_rng() - cond = gtx.np_as_located_field(IDim)(rng.normal(size=shape)) - out = gtx.np_as_located_field(IDim)(np.zeros(shape, dtype=inp.dtype)) + cond = gtx.as_field([IDim], rng.normal(size=shape)) + out = gtx.as_field([IDim], np.zeros(shape, dtype=inp.dtype)) ref = np.zeros(shape, dtype=inp.dtype) for i in range(shape[0]): @@ -97,9 +95,9 @@ def test_direct_offset_for_indirection(program_processor): program_processor, validate = program_processor shape = [4] - inp = gtx.np_as_located_field(IDim)(np.asarray(range(shape[0]), dtype=np.float64)) - cond = gtx.np_as_located_field(IDim)(np.asarray([2, 1, -1, -2], dtype=np.int32)) - out = gtx.np_as_located_field(IDim)(np.zeros(shape, dtype=np.float64)) + inp = gtx.as_field([IDim], np.asarray(range(shape[0]), dtype=np.float64)) + cond = gtx.as_field([IDim], np.asarray([2, 1, -1, -2], dtype=np.int32)) + out = gtx.as_field([IDim], np.zeros(shape, dtype=np.float64)) ref = np.zeros(shape) for i in range(shape[0]): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index d0dc8ec475..2df7691f9e 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -33,11 +33,11 @@ def dom(): def a_field(): - return gtx.np_as_located_field(I)(np.arange(0, _isize, dtype=np.float64)) + return gtx.as_field([I], np.arange(0, _isize, dtype=np.float64)) def out_field(): - return gtx.np_as_located_field(I)(np.zeros(shape=(_isize,))) + return gtx.as_field([I], np.zeros(shape=(_isize,))) @fundef diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index e02dab0a72..3af0440c27 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -30,10 +30,11 @@ def test_scan_in_stencil(program_processor, lift_mode): isize = 1 ksize = 3 Koff = offset("Koff") - inp = gtx.np_as_located_field(IDim, KDim)( - np.copy(np.broadcast_to(np.arange(0, ksize, dtype=np.float64), (isize, ksize))) + inp = gtx.as_field( + [IDim, KDim], + np.copy(np.broadcast_to(np.arange(0, ksize, dtype=np.float64), (isize, ksize))), ) - out = gtx.np_as_located_field(IDim, KDim)(np.zeros((isize, ksize))) + out = gtx.as_field([IDim, KDim], np.zeros((isize, ksize))) reference = np.zeros((isize, ksize - 1)) reference[:, 0] = inp.ndarray[:, 0] + inp.ndarray[:, 1] diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 0ac38e9b9f..abdfffd74e 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -56,12 +56,13 @@ def test_strided_offset_provider(program_processor): LocAB_size = LocA_size * max_neighbors rng = np.random.default_rng() - inp = gtx.np_as_located_field(LocAB)( + inp = gtx.as_field( + [LocAB], rng.normal( size=(LocAB_size,), - ) + ), ) - out = gtx.np_as_located_field(LocA)(np.zeros((LocA_size,))) + out = gtx.as_field([LocA], np.zeros((LocA_size,))) ref = np.sum(np.asarray(inp).reshape(LocA_size, max_neighbors), axis=-1) run_processor(fencil, program_processor, LocA_size, out, inp) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index cc12183a24..8c59f994ee 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -52,8 +52,8 @@ def test_trivial(program_processor, lift_mode): out = np.copy(inp) shape = (out.shape[0], out.shape[1]) - inp_s = gtx.np_as_located_field(IDim, JDim, origin={IDim: 0, JDim: 0})(inp[:, :, 0]) - out_s = gtx.np_as_located_field(IDim, JDim)(np.zeros_like(inp[:, :, 0])) + inp_s = gtx.as_field([IDim, JDim], inp[:, :, 0], origin={IDim: 0, JDim: 0}) + out_s = gtx.as_field([IDim, JDim], np.zeros_like(inp[:, :, 0])) run_processor( baz[cartesian_domain(named_range(IDim, 0, shape[0]), named_range(JDim, 0, shape[1]))], @@ -85,8 +85,8 @@ def test_shifted_arg_to_lift(program_processor, lift_mode): out[1:, :] = inp[:-1, :] shape = (out.shape[0], out.shape[1]) - inp_s = gtx.np_as_located_field(IDim, JDim, origin={IDim: 0, JDim: 0})(inp[:, :]) - out_s = gtx.np_as_located_field(IDim, JDim)(np.zeros_like(inp[:, :])) + inp_s = gtx.as_field([IDim, JDim], inp[:, :], origin={IDim: 0, JDim: 0}) + out_s = gtx.as_field([IDim, JDim], np.zeros_like(inp[:, :])) run_processor( stencil_shifted_arg_to_lift[ @@ -123,8 +123,8 @@ def test_direct_deref(program_processor, lift_mode): inp = rng.uniform(size=(5, 7)) out = np.copy(inp) - inp_s = gtx.np_as_located_field(IDim, JDim)(inp) - out_s = gtx.np_as_located_field(IDim, JDim)(np.zeros_like(inp)) + inp_s = gtx.as_field([IDim, JDim], inp) + out_s = gtx.as_field([IDim, JDim], np.zeros_like(inp)) run_processor( fen_direct_deref, @@ -153,8 +153,8 @@ def test_vertical_shift_unstructured(program_processor): rng = np.random.default_rng() inp = rng.uniform(size=(1, k_size)) - inp_s = gtx.np_as_located_field(IDim, KDim)(inp) - out_s = gtx.np_as_located_field(IDim, KDim)(np.zeros_like(inp)) + inp_s = gtx.as_field([IDim, KDim], inp) + out_s = gtx.as_field([IDim, KDim], np.zeros_like(inp)) run_processor( vertical_shift[ diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 67b439507c..97a51508f5 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -55,16 +55,18 @@ def test_tuple_output(program_processor, stencil): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp1 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp2 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) out = ( - gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)), - gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)), + gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), + gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), ) dom = { @@ -98,27 +100,31 @@ def stencil(inp1, inp2, inp3, inp4): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp1 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp2 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp3 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp3 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp4 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp4 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) out = ( ( - gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)), - gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)), + gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), + gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), ), ( - gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)), - gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)), + gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), + gtx.as_field([IDim, JDim, KDim], np.zeros(shape)), ), ) @@ -166,15 +172,17 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp1 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp2 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - out1 = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) - out2 = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + out1 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) + out2 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) run_processor( fencil, @@ -215,19 +223,22 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp1 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp2 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp3 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp3 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - out1 = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) - out2 = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) - out3 = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + out1 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) + out2 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) + out3 = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) run_processor( fencil, @@ -259,15 +270,17 @@ def test_field_of_extra_dim_output(program_processor, stencil): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp1 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp2 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) out_np = np.zeros(shape + [2]) - out = gtx.np_as_located_field(IDim, JDim, KDim, None)(out_np) + out = gtx.as_field([IDim, JDim, KDim, None], out_np) dom = { IDim: range(0, shape[0]), @@ -292,14 +305,16 @@ def test_tuple_field_input(program_processor): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp1 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( + inp2 = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2])), ) - out = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) dom = { IDim: range(0, shape[0]), @@ -322,8 +337,8 @@ def test_field_of_extra_dim_input(program_processor): inp2 = rng.normal(size=(shape[0], shape[1], shape[2])) inp = np.stack((inp1, inp2), axis=-1) - inp = gtx.np_as_located_field(IDim, JDim, KDim, None)(inp) - out = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + inp = gtx.as_field([IDim, JDim, KDim, None], inp) + out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) dom = { IDim: range(0, shape[0]), @@ -353,20 +368,12 @@ def test_tuple_of_tuple_of_field_input(program_processor): shape = [5, 7, 9] rng = np.random.default_rng() - inp1 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])) - ) - inp2 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])) - ) - inp3 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])) - ) - inp4 = gtx.np_as_located_field(IDim, JDim, KDim)( - rng.normal(size=(shape[0], shape[1], shape[2])) - ) + inp1 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp2 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp3 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) + inp4 = gtx.as_field([IDim, JDim, KDim], rng.normal(size=(shape[0], shape[1], shape[2]))) - out = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) dom = { IDim: range(0, shape[0]), @@ -393,11 +400,11 @@ def test_field_of_2_extra_dim_input(program_processor): shape = [5, 7, 9] rng = np.random.default_rng() - inp = gtx.np_as_located_field(IDim, JDim, KDim, None, None)( - rng.normal(size=(shape[0], shape[1], shape[2], 2, 2)) + inp = gtx.as_field( + [IDim, JDim, KDim, None, None], rng.normal(size=(shape[0], shape[1], shape[2], 2, 2)) ) - out = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) dom = { IDim: range(0, shape[0]), diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 3eaefa76de..3f229ef389 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -17,8 +17,8 @@ import gt4py.next as gtx from gt4py.next import errors -from gt4py.next.program_processors.runners import roundtrip +import next_tests.exclusion_matrices as definitions from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( # noqa: F401 # fixtures cartesian_case, @@ -70,7 +70,7 @@ def test_allocate_const(cartesian_case): # noqa: F811 # fixtures assert b == 42.0 -@pytest.mark.parametrize("fieldview_backend", [roundtrip.executor]) +@pytest.mark.parametrize("fieldview_backend", [~definitions.ProgramBackendId.ROUNDTRIP]) def test_verify_fails_with_wrong_reference(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, addition, "a")() b = cases.allocate(cartesian_case, addition, "b")() @@ -81,7 +81,7 @@ def test_verify_fails_with_wrong_reference(cartesian_case): # noqa: F811 # fixt cases.verify(cartesian_case, addition, a, b, out=out, ref=wrong_ref) -@pytest.mark.parametrize("fieldview_backend", [roundtrip.executor]) +@pytest.mark.parametrize("fieldview_backend", [~definitions.ProgramBackendId.ROUNDTRIP]) def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, addition, "a").dtype(np.float32)() b = cases.allocate(cartesian_case, addition, "b")() @@ -91,7 +91,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures cases.verify(cartesian_case, addition, a, b, out=out, ref=a + b) -@pytest.mark.parametrize("fieldview_backend", [roundtrip.executor]) +@pytest.mark.parametrize("fieldview_backend", [~definitions.ProgramBackendId.ROUNDTRIP]) def test_verify_with_default_data_fails_with_wrong_reference( cartesian_case, # noqa: F811 # fixtures ): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 64fb238470..108ee25862 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -192,21 +192,17 @@ def test_setup(): class setup: cell_size = 14 k_size = 10 - z_alpha = gtx.np_as_located_field(Cell, KDim)( - np.random.default_rng().uniform(size=(cell_size, k_size + 1)) + z_alpha = gtx.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) - z_beta = gtx.np_as_located_field(Cell, KDim)( - np.random.default_rng().uniform(size=(cell_size, k_size)) - ) - z_q = gtx.np_as_located_field(Cell, KDim)( - np.random.default_rng().uniform(size=(cell_size, k_size)) - ) - w = gtx.np_as_located_field(Cell, KDim)( - np.random.default_rng().uniform(size=(cell_size, k_size)) + z_beta = gtx.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) + z_q = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) + w = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = gtx.np_as_located_field(Cell, KDim)(np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = gtx.np_as_located_field(Cell, KDim)(np.zeros((cell_size, k_size))) + dummy = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() @@ -239,7 +235,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." ) - if fieldview_backend == roundtrip.executor: + if fieldview_backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)( @@ -275,7 +271,7 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if fieldview_backend == roundtrip.executor: + if fieldview_backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 4e295e92af..829bc497cb 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -90,10 +90,12 @@ def test_anton_toy(program_processor, lift_mode): shape = [5, 7, 9] rng = np.random.default_rng() - inp = gtx.np_as_located_field(IDim, JDim, KDim, origin={IDim: 1, JDim: 1, KDim: 0})( + inp = gtx.as_field( + [IDim, JDim, KDim], rng.normal(size=(shape[0] + 2, shape[1] + 2, shape[2])), + origin={IDim: 1, JDim: 1, KDim: 0}, ) - out = gtx.np_as_located_field(IDim, JDim, KDim)(np.zeros(shape)) + out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) ref = naive_lap(inp) run_processor( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 04cf8c6f9c..d05b14d73d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -56,15 +56,17 @@ def shift_stencil(inp): ( shift_stencil, lambda inp: np.asarray(inp)[1:, 1:], - lambda shape: gtx.np_as_located_field(IDim, KDim)( - np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 1]) + lambda shape: gtx.as_field( + [IDim, KDim], np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 1]) ), ), ( shift_stencil, lambda inp: np.asarray(inp)[1:, 2:], - lambda shape: gtx.np_as_located_field(IDim, KDim, origin={IDim: 0, KDim: 1})( - np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 2]) + lambda shape: gtx.as_field( + [IDim, KDim], + np.fromfunction(lambda i, k: i * 10 + k, [shape[0] + 1, shape[1] + 2]), + origin={IDim: 0, KDim: 1}, ), ), ], @@ -81,11 +83,11 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): shape = [5, 7] inp = ( - gtx.np_as_located_field(IDim, KDim)(np.fromfunction(lambda i, k: i * 10 + k, shape)) + gtx.as_field([IDim, KDim], np.fromfunction(lambda i, k: i * 10 + k, shape)) if inp_fun is None else inp_fun(shape) ) - out = gtx.np_as_located_field(IDim, KDim)(np.zeros(shape)) + out = gtx.as_field([IDim, KDim], np.zeros(shape)) ref = ref_fun(inp) @@ -129,21 +131,21 @@ def k_level_condition_upper_tuple(k_idx, k_level): ( k_level_condition_lower, lambda inp: 0, - lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([[0], inp[:-1]]), ), ( k_level_condition_upper, lambda inp: inp.shape[0] - 1, - lambda k_size: gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([inp[1:], [0]]), ), ( k_level_condition_upper_tuple, lambda inp: inp[0].shape[0] - 1, lambda k_size: ( - gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), - gtx.np_as_located_field(KDim)(np.arange(k_size, dtype=np.int32)), + gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), + gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), ), lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), ), @@ -157,7 +159,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct inp = inp_function(k_size) ref = ref_function(inp) - out = gtx.np_as_located_field(KDim)(np.zeros((5,), dtype=np.int32)) + out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32)) run_processor( fun[{KDim: range(0, k_size)}], @@ -204,8 +206,8 @@ def ksum_fencil(i_size, k_start, k_end, inp, out): def test_ksum_scan(program_processor, lift_mode, kstart, reference): program_processor, validate = program_processor shape = [1, 7] - inp = gtx.np_as_located_field(IDim, KDim)(np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) - out = gtx.np_as_located_field(IDim, KDim)(np.zeros(shape, dtype=inp.dtype)) + inp = gtx.as_field([IDim, KDim], np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) + out = gtx.as_field([IDim, KDim], np.zeros(shape, dtype=inp.dtype)) run_processor( ksum_fencil, @@ -241,8 +243,8 @@ def ksum_back_fencil(i_size, k_size, inp, out): def test_ksum_back_scan(program_processor, lift_mode): program_processor, validate = program_processor shape = [1, 7] - inp = gtx.np_as_located_field(IDim, KDim)(np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) - out = gtx.np_as_located_field(IDim, KDim)(np.zeros(shape, dtype=inp.dtype)) + inp = gtx.as_field([IDim, KDim], np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) + out = gtx.as_field([IDim, KDim], np.zeros(shape, dtype=inp.dtype)) ref = np.asarray([[21, 21, 20, 18, 15, 11, 6]]) @@ -304,11 +306,11 @@ def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): program_processor, validate = program_processor pytest.xfail("structured dtype input/output currently unsupported") shape = [1, 7] - inp0 = gtx.np_as_located_field(IDim, KDim)(np.asarray([list(range(7))], dtype=np.float64)) - inp1 = gtx.np_as_located_field(IDim, KDim)(np.asarray([list(range(7))], dtype=np.int32)) + inp0 = gtx.as_field([IDim, KDim], np.asarray([list(range(7))], dtype=np.float64)) + inp1 = gtx.as_field([IDim, KDim], np.asarray([list(range(7))], dtype=np.int32)) out = ( - gtx.np_as_located_field(IDim, KDim)(np.zeros(shape, dtype=np.float64)), - gtx.np_as_located_field(IDim, KDim)(np.zeros(shape, dtype=np.float32)), + gtx.as_field([IDim, KDim], np.zeros(shape, dtype=np.float64)), + gtx.as_field([IDim, KDim], np.zeros(shape, dtype=np.float32)), ) run_processor( @@ -348,9 +350,9 @@ def test_different_vertical_sizes(program_processor): program_processor, validate = program_processor k_size = 10 - inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) - inp1 = gtx.np_as_located_field(KDim)(np.arange(0, k_size + 1)) - out = gtx.np_as_located_field(KDim)(np.zeros(k_size, dtype=inp0.dtype)) + inp0 = gtx.as_field([KDim], np.arange(0, k_size)) + inp1 = gtx.as_field([KDim], np.arange(0, k_size + 1)) + out = gtx.as_field([KDim], np.zeros(k_size, dtype=inp0.dtype)) ref = inp0.ndarray + inp1.ndarray[1:] run_processor( @@ -387,9 +389,9 @@ def test_different_vertical_sizes_with_origin(program_processor): program_processor, validate = program_processor k_size = 10 - inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) - inp1 = gtx.np_as_located_field(KDim, origin={KDim: 1})(np.arange(0, k_size + 1)) - out = gtx.np_as_located_field(KDim)(np.zeros(k_size, dtype=np.int64)) + inp0 = gtx.as_field([KDim], np.arange(0, k_size)) + inp1 = gtx.as_field([KDim], np.arange(0, k_size + 1), origin={KDim: 1}) + out = gtx.as_field([KDim], np.zeros(k_size, dtype=np.int64)) ref = np.asarray(inp0) + np.asarray(inp1)[:-1] run_processor( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 445b73548b..47867b9a64 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -138,10 +138,10 @@ def test_compute_zavgS(program_processor, lift_mode): program_processor, validate = program_processor setup = nabla_setup() - pp = gtx.np_as_located_field(Vertex)(setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.np_as_located_field(Edge), setup.S_fields)) + pp = gtx.as_field([Vertex], setup.input_field) + S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - zavgS = gtx.np_as_located_field(Edge)(np.zeros((setup.edges_size))) + zavgS = gtx.as_field([Edge], np.zeros((setup.edges_size))) e2v = gtx.NeighborTableOffsetProvider( AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 @@ -197,13 +197,13 @@ def test_compute_zavgS2(program_processor, lift_mode): program_processor, validate = program_processor setup = nabla_setup() - pp = gtx.np_as_located_field(Vertex)(setup.input_field) + pp = gtx.as_field([Vertex], setup.input_field) - S = tuple(gtx.np_as_located_field(Edge)(s) for s in setup.S_fields) + S = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) zavgS = ( - gtx.np_as_located_field(Edge)(np.zeros((setup.edges_size))), - gtx.np_as_located_field(Edge)(np.zeros((setup.edges_size))), + gtx.as_field([Edge], np.zeros((setup.edges_size))), + gtx.as_field([Edge], np.zeros((setup.edges_size))), ) e2v = gtx.NeighborTableOffsetProvider( @@ -236,13 +236,13 @@ def test_nabla(program_processor, lift_mode): pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") setup = nabla_setup() - sign = gtx.np_as_located_field(Vertex, V2EDim)(setup.sign_field) - pp = gtx.np_as_located_field(Vertex)(setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.np_as_located_field(Edge), setup.S_fields)) - vol = gtx.np_as_located_field(Vertex)(setup.vol_field) + sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) + pp = gtx.as_field([Vertex], setup.input_field) + S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + vol = gtx.as_field([Vertex], setup.vol_field) - pnabla_MXX = gtx.np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) - pnabla_MYY = gtx.np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) e2v = gtx.NeighborTableOffsetProvider( AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 @@ -294,13 +294,13 @@ def test_nabla2(program_processor, lift_mode): program_processor, validate = program_processor setup = nabla_setup() - sign = gtx.np_as_located_field(Vertex, V2EDim)(setup.sign_field) - pp = gtx.np_as_located_field(Vertex)(setup.input_field) - S_M = tuple(gtx.np_as_located_field(Edge)(s) for s in setup.S_fields) - vol = gtx.np_as_located_field(Vertex)(setup.vol_field) + sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) + pp = gtx.as_field([Vertex], setup.input_field) + S_M = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) + vol = gtx.as_field([Vertex], setup.vol_field) - pnabla_MXX = gtx.np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) - pnabla_MYY = gtx.np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) e2v = gtx.NeighborTableOffsetProvider( AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 @@ -375,13 +375,13 @@ def test_nabla_sign(program_processor, lift_mode): pytest.xfail("test is broken due to bad lift semantics in iterator IR") setup = nabla_setup() - is_pole_edge = gtx.np_as_located_field(Edge)(setup.is_pole_edge_field) - pp = gtx.np_as_located_field(Vertex)(setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.np_as_located_field(Edge), setup.S_fields)) - vol = gtx.np_as_located_field(Vertex)(setup.vol_field) + is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) + pp = gtx.as_field([Vertex], setup.input_field) + S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + vol = gtx.as_field([Vertex], setup.vol_field) - pnabla_MXX = gtx.np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) - pnabla_MYY = gtx.np_as_located_field(Vertex)(np.zeros((setup.nodes_size))) + pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) e2v = gtx.NeighborTableOffsetProvider( AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index af70dd590f..8aabd18267 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -88,9 +88,9 @@ def test_hdiff(hdiff_reference, program_processor, lift_mode): inp, coeff, out = hdiff_reference shape = (out.shape[0], out.shape[1]) - inp_s = gtx.np_as_located_field(IDim, JDim, origin={IDim: 2, JDim: 2})(inp[:, :, 0]) - coeff_s = gtx.np_as_located_field(IDim, JDim)(coeff[:, :, 0]) - out_s = gtx.np_as_located_field(IDim, JDim)(np.zeros_like(coeff[:, :, 0])) + inp_s = gtx.as_field([IDim, JDim], inp[:, :, 0], origin={IDim: 2, JDim: 2}) + coeff_s = gtx.as_field([IDim, JDim], coeff[:, :, 0]) + out_s = gtx.as_field([IDim, JDim], np.zeros_like(coeff[:, :, 0])) run_processor( hdiff, program_processor, inp_s, coeff_s, out_s, shape[0], shape[1], lift_mode=lift_mode diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index a0471e8baa..29c82442ea 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -122,7 +122,7 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): gtfn.run_gtfn, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, - gtfn_formatters.format_sourcecode, + gtfn_formatters.format_cpp, ] and lift_mode == LiftMode.FORCE_INLINE ): @@ -134,7 +134,7 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference shape = a.shape - as_3d_field = gtx.np_as_located_field(IDim, JDim, KDim) + as_3d_field = gtx.as_field.partial([IDim, JDim, KDim]) a_s = as_3d_field(a) b_s = as_3d_field(b) c_s = as_3d_field(c) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index d475fab3a8..6354e45451 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -52,11 +52,11 @@ def edge_index_field(): # TODO replace by gtx.index_field once supported in bindings - return gtx.np_as_located_field(Edge)(np.arange(e2v_arr.shape[0], dtype=np.int32)) + return gtx.as_field([Edge], np.arange(e2v_arr.shape[0], dtype=np.int32)) def vertex_index_field(): # TODO replace by gtx.index_field once supported in bindings - return gtx.np_as_located_field(Vertex)(np.arange(v2e_arr.shape[0], dtype=np.int32)) + return gtx.as_field([Vertex], np.arange(v2e_arr.shape[0], dtype=np.int32)) @fundef @@ -87,7 +87,7 @@ def sum_edges_to_vertices_reduce(in_edges): def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): program_processor, validate = program_processor inp = edge_index_field() - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(sum(row) for row in v2e_arr)) run_processor( @@ -110,7 +110,7 @@ def map_neighbors(in_edges): def test_map_neighbors(program_processor, lift_mode): program_processor, validate = program_processor inp = edge_index_field() - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) run_processor( @@ -134,7 +134,7 @@ def map_make_const_list(in_edges): def test_map_make_const_list(program_processor, lift_mode): program_processor, validate = program_processor inp = edge_index_field() - out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) + out = gtx.as_field([Vertex], np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) run_processor( @@ -157,7 +157,7 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor, lift_mode): program_processor, validate = program_processor inp = vertex_index_field() - out = gtx.np_as_located_field(Cell)(np.zeros([9], dtype=inp.dtype)) + out = gtx.as_field([Cell], np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(v2e_arr[c[0]][0] for c in c2e_arr)) run_processor( @@ -183,9 +183,9 @@ def sparse_stencil(non_sparse, inp): def test_sparse_input_field(program_processor, lift_mode): program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) - inp = gtx.np_as_located_field(Vertex, V2EDim)(np.asarray([[1, 2, 3, 4]] * 9, dtype=np.int32)) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + non_sparse = gtx.as_field([Edge], np.zeros(18, dtype=np.int32)) + inp = gtx.as_field([Vertex, V2EDim], np.asarray([[1, 2, 3, 4]] * 9, dtype=np.int32)) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = np.ones([9]) * 10 @@ -206,9 +206,9 @@ def test_sparse_input_field(program_processor, lift_mode): def test_sparse_input_field_v2v(program_processor, lift_mode): program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) - inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + non_sparse = gtx.as_field([Edge], np.zeros(18, dtype=np.int32)) + inp = gtx.as_field([Vertex, V2VDim], v2v_arr) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(sum(row) for row in v2v_arr)) @@ -237,8 +237,8 @@ def slice_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields def test_slice_sparse(program_processor, lift_mode): program_processor, validate = program_processor - inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + inp = gtx.as_field([Vertex, V2VDim], v2v_arr) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = v2v_arr[:, 1] @@ -265,8 +265,8 @@ def slice_twice_sparse_stencil(sparse): @pytest.mark.xfail(reason="Field with more than one sparse dimension is not implemented.") def test_slice_twice_sparse(program_processor, lift_mode): program_processor, validate = program_processor - inp = gtx.np_as_located_field(Vertex, V2VDim, V2VDim)(v2v_arr[v2v_arr]) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + inp = gtx.as_field([Vertex, V2VDim, V2VDim], v2v_arr[v2v_arr]) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = v2v_arr[v2v_arr][:, 2, 1] run_processor( @@ -292,8 +292,8 @@ def shift_sliced_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields def test_shift_sliced_sparse(program_processor, lift_mode): program_processor, validate = program_processor - inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + inp = gtx.as_field([Vertex, V2VDim], v2v_arr) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = v2v_arr[:, 1][v2v_arr][:, 0] @@ -320,8 +320,8 @@ def slice_shifted_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields def test_slice_shifted_sparse(program_processor, lift_mode): program_processor, validate = program_processor - inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + inp = gtx.as_field([Vertex, V2VDim], v2v_arr) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = v2v_arr[:, 1][v2v_arr][:, 0] @@ -353,7 +353,7 @@ def lift_stencil(inp): def test_lift(program_processor, lift_mode): program_processor, validate = program_processor inp = vertex_index_field() - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) run_processor( @@ -376,8 +376,8 @@ def sparse_shifted_stencil(inp): @pytest.mark.uses_sparse_fields def test_shift_sparse_input_field(program_processor, lift_mode): program_processor, validate = program_processor - inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + inp = gtx.as_field([Vertex, V2VDim], v2v_arr) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) run_processor( @@ -415,9 +415,9 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): "Bug in bindings/compilation/caching: only the first program seems to be compiled." ) # observed in `cache.Strategy.PERSISTENT` mode inp = vertex_index_field() - inp_sparse = gtx.np_as_located_field(Edge, E2VDim)(e2v_arr) - out1 = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) - out2 = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + inp_sparse = gtx.as_field([Edge, E2VDim], e2v_arr) + out1 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) + out2 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) offset_provider = { "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), @@ -461,8 +461,8 @@ def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") - inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) - out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) + inp = gtx.as_field([Vertex, V2VDim], v2v_arr) + out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) ref = [] for row in v2v_arr: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py index c60079eaf1..d851c5560a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py @@ -32,10 +32,8 @@ def test_different_buffer_sizes(): out_nx = 5 out_ny = 5 - inp = gtx.np_as_located_field(IDim, JDim)( - np.reshape(np.arange(nx * ny, dtype=np.int32), (nx, ny)) - ) - out = gtx.np_as_located_field(IDim, JDim)(np.zeros((out_nx, out_ny), dtype=np.int32)) + inp = gtx.as_field([IDim, JDim], np.reshape(np.arange(nx * ny, dtype=np.int32), (nx, ny))) + out = gtx.as_field([IDim, JDim], np.zeros((out_nx, out_ny), dtype=np.int32)) @gtx.field_operator(backend=gtfn.run_gtfn) def copy(inp: gtx.Field[[IDim, JDim], gtx.int32]) -> gtx.Field[[IDim, JDim], gtx.int32]: diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 747431599a..b43eeb3f91 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -14,16 +14,13 @@ from __future__ import annotations -from dataclasses import dataclass +import dataclasses import pytest import gt4py.next as gtx -from gt4py import eve -from gt4py.next.iterator import ir as itir, pretty_parser, pretty_printer, runtime, transforms +from gt4py.next.iterator import runtime, transforms from gt4py.next.program_processors import processor_interface as ppi -from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters, lisp, type_check -from gt4py.next.program_processors.runners import double_roundtrip, gtfn, roundtrip try: @@ -35,6 +32,7 @@ raise e import next_tests +import next_tests.exclusion_matrices as definitions @pytest.fixture( @@ -49,60 +47,48 @@ def lift_mode(request): return request.param -class _RemoveITIRSymTypes(eve.NodeTranslator): - def visit_Sym(self, node: itir.Sym) -> itir.Sym: - return itir.Sym(id=node.id, dtype=None, kind=None) - - -@ppi.program_formatter -def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str: - # remove types from ITIR as they are not supported for the roundtrip - root = _RemoveITIRSymTypes().visit(root) - pretty = pretty_printer.pformat(root) - parsed = pretty_parser.pparse(pretty) - assert parsed == root - return pretty - - OPTIONAL_PROCESSORS = [] if dace_iterator: - OPTIONAL_PROCESSORS.append((dace_iterator.run_dace_iterator, True)) + OPTIONAL_PROCESSORS.append((definitions.OptionalProgramBackendId.DACE_CPU, True)) @pytest.fixture( params=[ - # (processor, do_validate) (None, True), - (lisp.format_lisp, False), - (pretty_format_and_check, False), - (roundtrip.executor, True), - (type_check.check, False), - (double_roundtrip.executor, True), - (gtfn.run_gtfn, True), - (gtfn.run_gtfn_imperative, True), - (gtfn.run_gtfn_with_temporaries, True), - (gtfn_formatters.format_sourcecode, False), + (definitions.ProgramBackendId.ROUNDTRIP, True), + (definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), + (definitions.ProgramBackendId.GTFN_CPU, True), + (definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), + (definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), + (definitions.ProgramFormatterId.LISP_FORMATTER, False), + (definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), + (definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False), + (definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), ] + OPTIONAL_PROCESSORS, - ids=lambda p: next_tests.get_processor_id(p[0]), + ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) -def program_processor(request): +def program_processor(request) -> tuple[ppi.ProgramProcessor, bool]: """ Fixture creating program processors on-demand for tests. Notes: Check ADR 15 for details on the test-exclusion matrices. """ - backend, _ = request.param - backend_id = next_tests.get_processor_id(backend) + processor_id, is_backend = request.param + if processor_id is None: + return None, is_backend + + processor = processor_id.load() + assert is_backend == ppi.is_program_backend(processor) for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( - backend_id, [] + processor_id, [] ): if request.node.get_closest_marker(marker): - skip_mark(msg.format(marker=marker, backend=backend_id)) + skip_mark(msg.format(marker=marker, backend=processor_id)) - return request.param + return processor, is_backend def run_processor( @@ -119,7 +105,7 @@ def run_processor( raise TypeError(f"program processor kind not recognized: {processor}!") -@dataclass +@dataclasses.dataclass class DummyConnectivity: max_neighbors: int has_skip_values: int diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 95093c8307..8a4b4cbd84 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -20,8 +20,8 @@ import numpy as np import pytest -from gt4py.next import Dimension, common -from gt4py.next.common import Domain, UnitRange +from gt4py.next import common, constructors +from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index ed7daa3cff..232995be58 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -50,14 +50,15 @@ def test_deduce_domain(): def test_embedded_error_on_wrong_domain(): dom = CartesianDomain([("I", range(1))]) - out = gtx.np_as_located_field(I)( + out = gtx.as_field( + [I], np.zeros( 1, - ) + ), ) with pytest.raises(RuntimeError, match="expected `UnstructuredDomain`"): foo[dom]( - gtx.np_as_located_field(I)(np.zeros((1,))), + gtx.as_field([I], np.zeros((1,))), out=out, offset_provider={"bar": connectivity}, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index ae5f582e47..4e865452f6 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -57,7 +57,7 @@ def fencil_example(): ) IDim = gtx.Dimension("I") params = [ - gtx.np_as_located_field(IDim)(np.empty((1,), dtype=np.float32)), + gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14), ] return fencil, params diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py index 6cd8d43c3b..05e982cf0c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py +++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py @@ -14,15 +14,49 @@ import pytest +import gt4py.next.allocators as next_allocators from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.processor_interface import ( + ProgramBackend, ProgramExecutor, ProgramFormatter, + ProgramProcessor, ensure_processor_kind, + is_processor_kind, + is_program_backend, + make_program_processor, program_formatter, ) +def test_make_program_processor(dummy_formatter): + def my_func(program: itir.FencilDefinition, *args, **kwargs) -> None: + return None + + processor = make_program_processor(my_func, ProgramExecutor) + assert is_processor_kind(processor, ProgramExecutor) + assert processor.__name__ == my_func.__name__ + assert processor(None) == my_func(None) + + def other_func(program: itir.FencilDefinition, *args, **kwargs) -> str: + return f"{args}, {kwargs}" + + processor = make_program_processor( + other_func, ProgramFormatter, name="new_name", accept_args=2, accept_kwargs=["a", "b"] + ) + assert is_processor_kind(processor, ProgramFormatter) + assert processor.__name__ == "new_name" + assert processor(None) == other_func(None) + assert processor(1, 2, a="A", b="B") == other_func(1, 2, a="A", b="B") + assert processor(1, 2, 3, 4, a="A", b="B", c="C") != other_func(1, 2, 3, 4, a="A", b="B", c="C") + + with pytest.raises(ValueError, match="accepted arguments cannot be a negative number"): + make_program_processor(my_func, ProgramFormatter, accept_args=-1) + + with pytest.raises(ValueError, match="invalid list of keyword argument names"): + make_program_processor(my_func, ProgramFormatter, accept_kwargs=["a", None]) + + @pytest.fixture def dummy_formatter(): @program_formatter @@ -47,3 +81,22 @@ def undecorated_formatter(fencil: itir.FencilDefinition, *args, **kwargs) -> str def test_wrong_processor_type_is_caught_at_runtime(dummy_formatter): with pytest.raises(TypeError, match="is not a ProgramExecutor"): ensure_processor_kind(dummy_formatter, ProgramExecutor) + + +def test_is_program_backend(): + class DummyProgramExecutor(ProgramExecutor): + def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: + return None + + assert not is_program_backend(DummyProgramExecutor()) + + class DummyAllocatorFactory: + __gt_allocator__ = next_allocators.StandardCPUFieldBufferAllocator() + + assert not is_program_backend(DummyAllocatorFactory()) + + class DummyBackend(DummyProgramExecutor, DummyAllocatorFactory): + def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: + return None + + assert is_program_backend(DummyBackend()) diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py new file mode 100644 index 0000000000..456654c1d0 --- /dev/null +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -0,0 +1,193 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from collections.abc import Sequence +from typing import Optional + +import pytest + +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators +import gt4py.next.common as common +import gt4py.storage.allocators as core_allocators + + +class DummyAllocator(next_allocators.FieldBufferAllocatorProtocol): + __gt_device_type__ = core_defs.DeviceType.CPU + + def __gt_allocate__( + self, + domain: common.Domain, + dtype: core_defs.DType[core_defs.ScalarT], + device_id: int = 0, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, + ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + pass + + +class DummyAllocatorFactory(next_allocators.FieldBufferAllocatorFactoryProtocol): + __gt_allocator__ = DummyAllocator() + + +def test_is_field_allocator(): + # Test with a field allocator + allocator = DummyAllocator() + assert next_allocators.is_field_allocator(allocator) + + # Test with an invalid object + invalid_obj = "not an allocator" + assert not next_allocators.is_field_allocator(invalid_obj) + + +def test_is_field_allocator_for(): + # Test with a valid field allocator for the specified device + assert next_allocators.is_field_allocator_for(DummyAllocator(), core_defs.DeviceType.CPU) + + # Test with a valid field allocator for a different device + assert not next_allocators.is_field_allocator_for(DummyAllocator(), core_defs.DeviceType.CUDA) + + # Test with an invalid field allocator + assert not next_allocators.is_field_allocator_for("not an allocator", core_defs.DeviceType.CPU) + + +def test_is_field_allocator_factory(): + # Test with a field allocator factory + allocator_factory = DummyAllocatorFactory() + assert next_allocators.is_field_allocator_factory(allocator_factory) + + # Test with an invalid object + invalid_obj = "not an allocator" + assert not next_allocators.is_field_allocator_factory(invalid_obj) + + +def test_is_field_allocator_factory_for(): + # Test with a field allocator factory that matches the device type + allocator_factory = DummyAllocatorFactory() + assert next_allocators.is_field_allocator_factory_for( + allocator_factory, core_defs.DeviceType.CPU + ) + + # Test with a field allocator factory that doesn't match the device type + allocator_factory = DummyAllocatorFactory() + assert not next_allocators.is_field_allocator_factory_for( + allocator_factory, core_defs.DeviceType.CUDA + ) + + # Test with an object that is not a field allocator factory + invalid_obj = "not an allocator factory" + assert not next_allocators.is_field_allocator_factory_for(invalid_obj, core_defs.DeviceType.CPU) + + +def test_get_allocator(): + # Test with a field allocator + allocator = DummyAllocator() + assert next_allocators.get_allocator(allocator) == allocator + + # Test with a field allocator factory + allocator_factory = DummyAllocatorFactory() + assert next_allocators.get_allocator(allocator_factory) == allocator_factory.__gt_allocator__ + + # Test with a default allocator + default_allocator = DummyAllocator() + assert next_allocators.get_allocator(None, default=default_allocator) == default_allocator + + # Test with an invalid object and no default allocator + invalid_obj = "not an allocator" + assert next_allocators.get_allocator(invalid_obj) is None + + with pytest.raises( + TypeError, + match=f"Object {invalid_obj} is neither a field allocator nor a field allocator factory", + ): + next_allocators.get_allocator(invalid_obj, strict=True) + + +def test_horizontal_first_layout_mapper(): + from gt4py.next.allocators import horizontal_first_layout_mapper + + # Test with only horizontal dimensions + dims = [ + common.Dimension("D0", common.DimensionKind.HORIZONTAL), + common.Dimension("D1", common.DimensionKind.HORIZONTAL), + common.Dimension("D2", common.DimensionKind.HORIZONTAL), + ] + expected_layout_map = core_allocators.BufferLayoutMap((2, 1, 0)) + assert horizontal_first_layout_mapper(dims) == expected_layout_map + + # Test with no horizontal dimensions + dims = [ + common.Dimension("D0", common.DimensionKind.VERTICAL), + common.Dimension("D1", common.DimensionKind.LOCAL), + common.Dimension("D2", common.DimensionKind.VERTICAL), + ] + expected_layout_map = core_allocators.BufferLayoutMap((2, 0, 1)) + assert horizontal_first_layout_mapper(dims) == expected_layout_map + + # Test with a mix of dimensions + dims = [ + common.Dimension("D2", common.DimensionKind.LOCAL), + common.Dimension("D0", common.DimensionKind.HORIZONTAL), + common.Dimension("D1", common.DimensionKind.VERTICAL), + ] + expected_layout_map = core_allocators.BufferLayoutMap((0, 2, 1)) + assert horizontal_first_layout_mapper(dims) == expected_layout_map + + +class TestInvalidFieldBufferAllocator: + def test_allocate(self): + allocator = next_allocators.InvalidFieldBufferAllocator( + core_defs.DeviceType.CPU, ValueError("test error") + ) + I = common.Dimension("I") + J = common.Dimension("J") + domain = common.domain(((I, (2, 4)), (J, (3, 5)))) + dtype = float + with pytest.raises(ValueError, match="test error"): + allocator.__gt_allocate__(domain, dtype) + + +def test_allocate(): + from gt4py.next.allocators import StandardCPUFieldBufferAllocator, allocate + + I = common.Dimension("I") + J = common.Dimension("J") + domain = common.domain(((I, (0, 2)), (J, (0, 3)))) + dtype = core_defs.dtype(float) + + # Test with a explicit field allocator + allocator = StandardCPUFieldBufferAllocator() + tensor_buffer = allocate(domain, dtype, allocator=allocator) + assert tensor_buffer.shape == domain.shape + assert tensor_buffer.dtype == dtype + assert tensor_buffer.device == core_defs.Device(core_defs.DeviceType.CPU, 0) + + # Test with a device + device = core_defs.Device(core_defs.DeviceType.CPU, 0) + tensor_buffer = allocate(domain, dtype, device=device) + assert tensor_buffer.shape == domain.shape + assert tensor_buffer.dtype == dtype + assert tensor_buffer.device == core_defs.Device(core_defs.DeviceType.CPU, 0) + + # Test with both allocator and device + with pytest.raises(ValueError, match="are incompatible"): + allocate( + domain, + dtype, + allocator=allocator, + device=core_defs.Device(core_defs.DeviceType.CUDA, 0), + ) + + # Test with no device or allocator + with pytest.raises(ValueError, match="No 'device' or 'allocator' specified"): + allocate(domain, dtype) diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py new file mode 100644 index 0000000000..e8b070f0c0 --- /dev/null +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -0,0 +1,175 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +import pytest + +from gt4py import next as gtx +from gt4py._core import definitions as core_defs +from gt4py.next import allocators as next_allocators, common, float32 +from gt4py.next.program_processors.runners import roundtrip + +from next_tests.integration_tests import cases + + +I = gtx.Dimension("I") +J = gtx.Dimension("J") +K = gtx.Dimension("K") + +sizes = {I: 10, J: 10, K: 10} + + +# TODO: parametrize with gpu backend and compare with cupy array +@pytest.mark.parametrize( + "allocator, device", + [ + [next_allocators.StandardCPUFieldBufferAllocator(), None], + [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], + ], +) +def test_empty(allocator, device): + ref = np.empty([sizes[I], sizes[J]]).astype(gtx.float32) + a = gtx.empty( + domain={I: range(sizes[I]), J: range(sizes[J])}, + dtype=core_defs.dtype(np.float32), + allocator=allocator, + device=device, + ) + assert a.shape == ref.shape + + +# TODO: parametrize with gpu backend and compare with cupy array +@pytest.mark.parametrize( + "allocator, device", + [ + [next_allocators.StandardCPUFieldBufferAllocator(), None], + [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], + ], +) +def test_zeros(allocator, device): + a = gtx.zeros( + common.Domain( + dims=(I, J), ranges=(common.UnitRange(0, sizes[I]), common.UnitRange(0, sizes[J])) + ), + dtype=core_defs.dtype(np.float32), + allocator=allocator, + device=device, + ) + ref = np.zeros((sizes[I], sizes[J])).astype(gtx.float32) + + assert np.array_equal(a.ndarray, ref) + + +# TODO: parametrize with gpu backend and compare with cupy array +@pytest.mark.parametrize( + "allocator, device", + [ + [next_allocators.StandardCPUFieldBufferAllocator(), None], + [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], + ], +) +def test_ones(allocator, device): + a = gtx.ones( + common.Domain(dims=(I, J), ranges=(common.UnitRange(0, 10), common.UnitRange(0, 10))), + dtype=core_defs.dtype(np.float32), + allocator=allocator, + device=device, + ) + ref = np.ones((sizes[I], sizes[J])).astype(gtx.float32) + + assert np.array_equal(a.ndarray, ref) + + +# TODO: parametrize with gpu backend and compare with cupy array +@pytest.mark.parametrize( + "allocator, device", + [ + [next_allocators.StandardCPUFieldBufferAllocator(), None], + [None, core_defs.Device(core_defs.DeviceType.CPU, 0)], + ], +) +def test_full(allocator, device): + a = gtx.full( + domain={I: range(sizes[I] - 2), J: (sizes[J] - 2)}, + fill_value=42.0, + dtype=core_defs.dtype(np.float32), + allocator=allocator, + device=device, + ) + ref = np.full((sizes[I] - 2, sizes[J] - 2), 42.0).astype(gtx.float32) + + assert np.array_equal(a.ndarray, ref) + + +def test_as_field(): + ref = np.random.rand(sizes[I]).astype(gtx.float32) + a = gtx.as_field([I], ref) + assert np.array_equal(a.ndarray, ref) + + +def test_as_field_domain(): + ref = np.random.rand(sizes[I] - 1, sizes[J] - 1).astype(gtx.float32) + domain = common.Domain( + dims=(I, J), + ranges=(common.UnitRange(0, sizes[I] - 1), common.UnitRange(0, sizes[J] - 1)), + ) + a = gtx.as_field(domain, ref) + assert np.array_equal(a.ndarray, ref) + + +def test_as_field_origin(): + data = np.random.rand(sizes[I], sizes[J]).astype(gtx.float32) + a = gtx.as_field([I, J], data, origin={I: 1, J: 2}) + domain_range = [(val.start, val.stop) for val in a.domain.ranges] + assert np.allclose(domain_range, [(-1, 9), (-2, 8)]) + + +# check that `as_field()` domain is correct depending on data origin and domain itself +def test_field_wrong_dims(): + with pytest.raises( + ValueError, + match=(r"Cannot construct `Field` from array of shape"), + ): + gtx.as_field([I, J], np.random.rand(sizes[I]).astype(gtx.float32)) + + +def test_field_wrong_domain(): + with pytest.raises( + ValueError, + match=(r"Cannot construct `Field` from array of shape"), + ): + domain = common.Domain( + dims=(I, J), + ranges=(common.UnitRange(0, sizes[I] - 1), common.UnitRange(0, sizes[J] - 1)), + ) + gtx.as_field(domain, np.random.rand(sizes[I], sizes[J]).astype(gtx.float32)) + + +def test_field_wrong_origin(): + with pytest.raises( + ValueError, + match=(r"Origin keys {'J'} not in domain"), + ): + gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), origin={"J": 0}) + + with pytest.raises( + ValueError, + match=(r"Cannot specify origin for domain I"), + ): + gtx.as_field("I", np.random.rand(sizes[J]).astype(gtx.float32), origin={"J": 0}) + + +@pytest.mark.xfail(reason="aligned_index not supported yet") +def test_aligned_index(): + gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), aligned_index=[I, 0]) From c8ff8ed3160164e6d1d1ec74f766a2bcd1366cfd Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 15 Nov 2023 16:47:29 +0100 Subject: [PATCH 30/67] bug[next] Fix broken gpu tox setup (#1358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit And update gpu test which was broken in refactoring. --------- Co-authored-by: Rico Häuselmann --- .../ffront_tests/test_gpu_backend.py | 24 +++++++------------ tox.ini | 4 ++-- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py index 80e9a8e07a..7054597831 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py @@ -15,6 +15,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.program_processors.runners import dace_iterator, gtfn from next_tests.integration_tests import cases @@ -26,26 +27,19 @@ @pytest.mark.requires_gpu @pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace_gpu, gtfn.run_gtfn_gpu]) -def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures +def test_copy(fieldview_backend): # noqa: F811 # fixtures import cupy as cp @gtx.field_operator(backend=fieldview_backend) def testee(a: cases.IJKField) -> cases.IJKField: return a - inp_arr = cp.full(shape=(3, 4, 5), fill_value=3, dtype=cp.int32) - outp_arr = cp.zeros_like(inp_arr) - inp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], inp_arr) - outp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], outp_arr) - - testee(inp, out=outp, offset_provider={}) - assert cp.allclose(inp_arr, outp_arr) - - inp_field = gtx.full( - [cases.IDim, cases.JDim, cases.KDim], fill_value=3, allocator=fieldview_backend - ) - out_field = gtx.zeros( - [cases.IDim, cases.JDim, cases.KDim], outp_arr, allocator=fieldview_backend - ) + domain = { + cases.IDim: common.unit_range(3), + cases.JDim: common.unit_range(4), + cases.KDim: common.unit_range(5), + } + inp_field = gtx.full(domain, fill_value=3, allocator=fieldview_backend, dtype=cp.int32) + out_field = gtx.zeros(domain, allocator=fieldview_backend, dtype=cp.int32) testee(inp_field, out=out_field, offset_provider={}) assert cp.allclose(inp_field.ndarray, out_field.ndarray) diff --git a/tox.ini b/tox.ini index 18a6ff8e84..5b644e7d97 100644 --- a/tox.ini +++ b/tox.ini @@ -82,9 +82,9 @@ set_env = PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} commands = nomesh-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and not requires_gpu" {posargs} tests{/}next_tests - nomesh-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests + nomesh-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests atlas-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and not requires_gpu" {posargs} tests{/}next_tests - atlas-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests + # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next [testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] From b8cda74e2eade6d2cfb8c9ee175e456dafa8adc8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 16 Nov 2023 11:38:10 +0100 Subject: [PATCH 31/67] feat[next]: add `where` to embedded field view (#1316) - unifies unary and binary builtin to general nary in NdArrayField - special case for `where` with tuples --- src/gt4py/next/embedded/nd_array_field.py | 111 ++++++++---------- src/gt4py/next/ffront/fbuiltins.py | 23 +++- .../embedded_tests/test_nd_array_field.py | 54 +++++++++ 3 files changed, 128 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 527197e0bc..ea88948841 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,6 +15,8 @@ from __future__ import annotations import dataclasses +import functools +import operator from collections.abc import Callable, Sequence from types import ModuleType from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar @@ -39,40 +41,38 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] -def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_unary_op(a: NdArrayField) -> common.Field: - xp = a.__class__.array_ns +def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]: + def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: + first = fields[0] + assert isinstance(first, NdArrayField) + xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) - new_data = op(a.ndarray) - return a.__class__.from_array(new_data, domain=a.domain) - - _builtin_unary_op.__name__ = builtin_name - return _builtin_unary_op - - -def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_binary_op(a: NdArrayField, b: common.Field) -> common.Field: - xp = a.__class__.array_ns - op = getattr(xp, array_builtin_name) - if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): - if not a.domain == b.domain: - domain_intersection = a.domain & b.domain - a_broadcasted = _broadcast(a, domain_intersection.dims) - b_broadcasted = _broadcast(b, domain_intersection.dims) - a_slices = _get_slices_from_domain_slice(a_broadcasted.domain, domain_intersection) - b_slices = _get_slices_from_domain_slice(b_broadcasted.domain, domain_intersection) - new_data = op(a_broadcasted.ndarray[a_slices], b_broadcasted.ndarray[b_slices]) - return a.__class__.from_array(new_data, domain=domain_intersection) - new_data = op(a.ndarray, xp.asarray(b.ndarray)) - else: - assert isinstance(b, core_defs.SCALAR_TYPES) - new_data = op(a.ndarray, b) - - return a.__class__.from_array(new_data, domain=a.domain) - - _builtin_binary_op.__name__ = builtin_name - return _builtin_binary_op + domain_intersection = functools.reduce( + operator.and_, + [f.domain for f in fields if common.is_field(f)], + common.Domain(dims=tuple(), ranges=tuple()), + ) + transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] + for f in fields: + if common.is_field(f): + if f.domain == domain_intersection: + transformed.append(xp.asarray(f.ndarray)) + else: + f_broadcasted = _broadcast(f, domain_intersection.dims) + f_slices = _get_slices_from_domain_slice( + f_broadcasted.domain, domain_intersection + ) + transformed.append(xp.asarray(f_broadcasted.ndarray[f_slices])) + else: + assert core_defs.is_scalar_type(f) + transformed.append(f) + + new_data = op(*transformed) + return first.__class__.from_array(new_data, domain=domain_intersection) + + _builtin_op.__name__ = builtin_name + return _builtin_op _Value: TypeAlias = common.Field | core_defs.ScalarT @@ -174,56 +174,50 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __call__ = None # type: ignore[assignment] # TODO: remap - __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") + __abs__ = _make_builtin("abs", "abs") - __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") + __neg__ = _make_builtin("neg", "negative") - __pos__ = _make_unary_array_field_intrinsic_func("pos", "positive") + __add__ = __radd__ = _make_builtin("add", "add") - __add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add") + __pos__ = _make_builtin("pos", "positive") - __sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract") + __sub__ = __rsub__ = _make_builtin("sub", "subtract") - __mul__ = __rmul__ = _make_binary_array_field_intrinsic_func("mul", "multiply") + __mul__ = __rmul__ = _make_builtin("mul", "multiply") - __truediv__ = __rtruediv__ = _make_binary_array_field_intrinsic_func("div", "divide") + __truediv__ = __rtruediv__ = _make_builtin("div", "divide") - __floordiv__ = __rfloordiv__ = _make_binary_array_field_intrinsic_func( - "floordiv", "floor_divide" - ) + __floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide") - __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") + __pow__ = _make_builtin("pow", "power") - __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") + __mod__ = __rmod__ = _make_builtin("mod", "mod") def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( - self, other - ) + return _make_builtin("logical_and", "logical_and")(self, other) raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") __rand__ = __and__ def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) + return _make_builtin("logical_or", "logical_or")(self, other) raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") __ror__ = __or__ def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( - self, other - ) + return _make_builtin("logical_xor", "logical_xor")(self, other) raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") __rxor__ = __xor__ def __invert__(self) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_unary_array_field_intrinsic_func("invert", "invert")(self) + return _make_builtin("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") def _slice( @@ -241,7 +235,7 @@ def _slice( return new_domain, slice_ -# -- Specialized implementations for intrinsic operations on array fields -- +# -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined] @@ -254,19 +248,18 @@ def _slice( ): if name in ["abs", "power", "gamma"]: continue - NdArrayField.register_builtin_func( - getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name) - ) + NdArrayField.register_builtin_func(getattr(fbuiltins, name), _make_builtin(name, name)) NdArrayField.register_builtin_func( - fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined] + fbuiltins.minimum, _make_builtin("minimum", "minimum") # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined] + fbuiltins.maximum, _make_builtin("maximum", "maximum") # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] + fbuiltins.fmod, _make_builtin("fmod", "fmod") # type: ignore[attr-defined] ) +NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) def _np_cp_setitem( diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 52aae34b3f..13c21eb516 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -132,6 +132,27 @@ def builtin_function(fun: Callable[_P, _R]) -> BuiltInFunction[_R, _P]: return BuiltInFunction(fun) +MaskT = TypeVar("MaskT", bound=Field) +FieldT = TypeVar("FieldT", bound=Union[Field, gt4py_defs.Scalar, Tuple]) + + +class WhereBuiltinFunction( + BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT] +): + def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: + if isinstance(true_field, tuple) or isinstance(false_field, tuple): + if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): + raise ValueError( + f"Either both or none can be tuple in {true_field=} and {false_field=}." # TODO(havogt) find a strategy to unify parsing and embedded error messages + ) + if len(true_field) != len(false_field): + raise ValueError( + "Tuple of different size not allowed." + ) # TODO(havogt) find a strategy to unify parsing and embedded error messages + return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return super().__call__(mask, true_field, false_field) + + @builtin_function def neighbor_sum( field: Field, @@ -164,7 +185,7 @@ def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple[Dimension, ...], /) raise NotImplementedError() -@builtin_function +@WhereBuiltinFunction def where( mask: Field, true_field: Field | gt4py_defs.ScalarT | Tuple, diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 8a4b4cbd84..49aeece87e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -98,6 +98,60 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati assert np.allclose(result.ndarray, expected) +def test_where_builtin(nd_array_implementation): + cond = np.asarray([True, False]) + true_ = np.asarray([1.0, 2.0], dtype=np.float32) + false_ = np.asarray([3.0, 4.0], dtype=np.float32) + + field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]] + expected = np.where(cond, true_, false_) + + result = fbuiltins.where(*field_inputs) + assert np.allclose(result.ndarray, expected) + + +def test_where_builtin_different_domain(nd_array_implementation): + cond = np.asarray([True, False]) + true_ = np.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + false_ = np.asarray([7.0, 8.0, 9.0, 10.0], dtype=np.float32) + + cond_field = common.field( + nd_array_implementation.asarray(cond), domain=common.domain({JDim: 2}) + ) + true_field = common.field( + nd_array_implementation.asarray(true_), + domain=common.domain({IDim: common.UnitRange(0, 2), JDim: common.UnitRange(-1, 2)}), + ) + false_field = common.field( + nd_array_implementation.asarray(false_), + domain=common.domain({JDim: common.UnitRange(-1, 3)}), + ) + + expected = np.where(cond[np.newaxis, :], true_[:, 1:], false_[np.newaxis, 1:-1]) + + result = fbuiltins.where(cond_field, true_field, false_field) + assert np.allclose(result.ndarray, expected) + + +def test_where_builtin_with_tuple(nd_array_implementation): + cond = np.asarray([True, False]) + true0 = np.asarray([1.0, 2.0], dtype=np.float32) + false0 = np.asarray([3.0, 4.0], dtype=np.float32) + true1 = np.asarray([11.0, 12.0], dtype=np.float32) + false1 = np.asarray([13.0, 14.0], dtype=np.float32) + + expected0 = np.where(cond, true0, false0) + expected1 = np.where(cond, true1, false1) + + cond_field = _make_field(cond, nd_array_implementation, dtype=bool) + field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1]) + field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1]) + + result = fbuiltins.where(cond_field, field_true, field_false) + assert np.allclose(result[0].ndarray, expected0) + assert np.allclose(result[1].ndarray, expected1) + + def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] From 87832eddc7ee0156ab97290b168e2499d7ab541b Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 16 Nov 2023 13:29:38 +0100 Subject: [PATCH 32/67] presentation slides --- docs/user/next/presentation_slides.md | 411 ++++++++++++++++++++++++++ docs/user/next/scan_operator.png | Bin 0 -> 8760 bytes docs/user/next/simple_offset.png | Bin 0 -> 10292 bytes 3 files changed, 411 insertions(+) create mode 100644 docs/user/next/presentation_slides.md create mode 100644 docs/user/next/scan_operator.png create mode 100644 docs/user/next/simple_offset.png diff --git a/docs/user/next/presentation_slides.md b/docs/user/next/presentation_slides.md new file mode 100644 index 0000000000..87cd2b7787 --- /dev/null +++ b/docs/user/next/presentation_slides.md @@ -0,0 +1,411 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GT4Py workshop + ++++ + +## GT4Py: GridTools for Python + +GT4Py is a Python library for generating high performance implementations of stencil kernels from a high-level definition using regular Python functions. + +GT4Py is part of the GridTools framework: a set of libraries and utilities to develop performance portable applications in the area of weather and climate modeling. + +**NOTE:** The `gt4py.next` subpackage contains a new and currently experimental version of GT4Py. + +## Description + +GT4Py is a Python library for expressing computational motifs as found in weather and climate applications. + +These computations are expressed in a domain specific language (GTScript) which is translated to high-performance implementations for CPUs and GPUs. + +The DSL expresses computations on a 3-dimensional Cartesian grid. The horizontal axes are always computed in parallel, while the vertical can be iterated in sequential, forward or backward, order. + +In addition, GT4Py provides functions to allocate arrays with memory layout suited for a particular backend. + +The following backends are supported: + +- `numpy`: Pure-Python backend +- `gt:cpu_ifirst`: GridTools C++ CPU backend using `I`-first data ordering +- `gt:cpu_kfirst`: GridTools C++ CPU backend using `K`-first data ordering +- `gt:gpu`: GridTools backend for CUDA +- `cuda`: CUDA backend minimally using utilities from GridTools +- `dace:cpu`: Dace code-generated CPU backend +- `dace:gpu`: Dace code-generated GPU backend + ++++ + +## Installation + +You can install the library directly from GitHub using pip: + +```{raw-cell} +pip install --upgrade git+https://github.com/gridtools/gt4py.git +``` + +```{code-cell} ipython3 +import warnings +warnings.filterwarnings('ignore') +``` + +```{code-cell} ipython3 +import numpy as np +import gt4py.next as gtx +from gt4py.next import float64, neighbor_sum, where +from gt4py.next.common import DimensionKind +``` + +## Key concepts and application structure + +- [Fields](#Fields), +- [Field operators](#Field-operators), and +- [Programs](#Programs). + ++++ + +### Fields +Fields are **multi-dimensional array** defined over a set of dimensions and a dtype: `gtx.Field[[dimensions], dtype]` + +The `as_field` builtin is used to define fields + +```{code-cell} ipython3 +CellDim = gtx.Dimension("Cell") +KDim = gtx.Dimension("K", kind=DimensionKind.VERTICAL) +grid_shape = (5, 6) +a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) +b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) + +print("a definition: \n {}".format(a)) +print("a array: \n {}".format(np.asarray(a))) +print("b array: \n {}".format(np.asarray(b))) +``` + +### Field operators + +Field operators perform operations on a set of fields, i.e. elementwise addition or reduction along a dimension. + +They are written as Python functions by using the `@field_operator` decorator. + +```{code-cell} ipython3 +@gtx.field_operator +def add(a: gtx.Field[[CellDim, KDim], float64], + b: gtx.Field[[CellDim, KDim], float64]) -> gtx.Field[[CellDim, KDim], float64]: + return a + b +``` + +Direct calls to field operators require two additional arguments: +- `out`: a field to write the return value to +- `offset_provider`: empty dict for now, explanation will follow + +```{code-cell} ipython3 +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +add(a, b, out=result, offset_provider={}) + +print("result array \n {}".format(np.asarray(result))) +``` + +### Programs + ++++ + +Programs are used to call field operators to mutate their arguments. + +They are written as Python functions by using the `@program` decorator. + +This example below calls the `add` field operator twice: + +```{code-cell} ipython3 +# @gtx.field_operator +# def add(a, b): +# return a + b + +@gtx.program +def run_add(a : gtx.Field[[CellDim, KDim], float64], + b : gtx.Field[[CellDim, KDim], float64], + result : gtx.Field[[CellDim, KDim], float64]): + add(a, b, out=result) # 2.0 + 3.0 = 5.0 + add(b, result, out=result) # 5.0 + 3.0 = 8.0 +``` + +```{code-cell} ipython3 +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +run_add(a, b, result, offset_provider={}) + +print("result array: \n {}".format(np.asarray(result))) +``` + +The fields in the subsequent code snippets are 1-dimensional, either over the cells or over the edges. The corresponding named dimensions are thus the following: + ++++ + +### Offsets +Fields can be offset by a predefined number of indices. + +Take an array with values ranging from 0 to 5: + +```{code-cell} ipython3 +a_off = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) + +print("a_off array: \n {}".format(np.asarray(a_off))) +``` + +Visually, offsetting this field by 1 would result in the following: + +| ![Coff](simple_offset.png) | +| :------------------------: | +| _CellDim Offset (Coff)_ | + ++++ + +Fields can be offeset by a predefined number of indices. + +Take an array with values ranging from 0 to 5: + +```{code-cell} ipython3 +Coff = gtx.FieldOffset("Coff", source=CellDim, target=(CellDim,)) + +@gtx.field_operator +def a_offset(a_off: gtx.Field[[CellDim], float64]) -> gtx.Field[[CellDim], float64]: + return a_off(Coff[1]) + +a_offset(a_off, out=a_off, offset_provider={"Coff": CellDim}) +print("result array: \n {}".format(np.asarray(a_off))) +``` + +## Defining the mesh and its connectivities +Take an unstructured mesh with numbered cells (in red) and edges (in blue). + +| ![grid_topo](connectivity_numbered_grid.svg) | +| :------------------------------------------: | +| _The mesh with the indices_ | + +```{code-cell} ipython3 +CellDim = gtx.Dimension("Cell") +EdgeDim = gtx.Dimension("Edge") +``` + +Connectivityy among mesh elements is expressed through connectivity tables. + +For example, `e2c_table` lists for each edge its adjacent rows. + +Similarly, `c2e_table` lists the edges that are neighbors to a particular cell. + +Note that if an edge is lying at the border, one entry will be filled with -1. + +```{code-cell} ipython3 +e2c_table = np.array([ + [0, -1], # edge 0 (neighbours: cell 0) + [2, -1], # edge 1 + [2, -1], # edge 2 + [3, -1], # edge 3 + [4, -1], # edge 4 + [5, -1], # edge 5 + [0, 5], # edge 6 (neighbours: cell 0, cell 5) + [0, 1], # edge 7 + [1, 2], # edge 8 + [1, 3], # edge 9 + [3, 4], # edge 10 + [4, 5] # edge 11 +]) + +c2e_table = np.array([ + [0, 6, 7], # cell 0 (neighbors: edge 0, edge 6, edge 7) + [7, 8, 9], # cell 1 + [1, 2, 8], # cell 2 + [3, 9, 10], # cell 3 + [4, 10, 11], # cell 4 + [5, 6, 11], # cell 5 +]) +``` + +#### Using connectivities in field operators + +Let's start by defining two fields: one over the cells and another one over the edges. The field over cells serves input for subsequent calculations and is therefore filled up with values, whereas the field over the edges stores the output of the calculations and is therefore left blank. + +```{code-cell} ipython3 +cell_field = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) +edge_field = gtx.as_field([EdgeDim], np.zeros((12,))) +``` + +| ![cell_values](connectivity_cell_field.svg) | +| :-----------------------------------------: | +| _Cell values_ | + ++++ + +`field_offset` is used as an argument to transform fields over one domain to another domain. + +For example, `E2C` can be used to shift a field over cells to edges with the following dimension transformation: + +[CellDim] -> CellDim(E2C) -> [EdgeDim, E2CDim] + +A field with an offset dimension is called a sparse field + +```{code-cell} ipython3 +E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) +E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim, E2CDim)) +``` + +```{code-cell} ipython3 +E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, EdgeDim, CellDim, 2) +``` + +```{code-cell} ipython3 +@gtx.field_operator +def nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: + return cell_field(E2C[0]) # 0th index to isolate edge dimension + +@gtx.program +def run_nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): + nearest_cell_to_edge(cell_field, out=edge_field) + +run_nearest_cell_to_edge(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) + +print("0th adjacent cell's value: {}".format(np.asarray(edge_field))) +``` + +Running the above snippet results in the following edge field: + +| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![grid_topo](connectivity_edge_0th_cell.svg) | +| :----------------------------------------------------: | :-------: | :------------------------------------------: | +| _Domain (edges)_ | | _Edge values_ | + ++++ + +### Using reductions on connected mesh elements + +To sum up all the cells adjacent to an edge the `neighbor_sum` builtin function can be called to operate along the `E2CDim` dimension. + +```{code-cell} ipython3 +@gtx.field_operator +def sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: + return neighbor_sum(cell_field(E2C), axis=E2CDim) + +@gtx.program +def run_sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): + sum_adjacent_cells(cell_field, out=edge_field) + +run_sum_adjacent_cells(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) + +print("sum of adjacent cells: {}".format(np.asarray(edge_field))) +``` + +For the border edges, the results are unchanged compared to the previous example, but the inner edges now contain the sum of the two adjacent cells: + +| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![cell_values](connectivity_edge_cell_sum.svg) | +| :----------------------------------------------------: | :-------: | :--------------------------------------------: | +| _Domain (edges)_ | | _Edge values_ | + ++++ + +#### Using conditionals on fields + +To filter operations such that they are performed on only certain cells instead of the whole field, the `where` builtin was developed. + +This function takes 3 input arguments: +- mask: a field of booleans or an expression evaluating to this type +- true branch: a tuple, a field, or a scalar +- false branch: a tuple, a field, of a scalar + +```{code-cell} ipython3 +mask = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0], dtype=bool)) +result = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0])) +b = 6.0 + +@gtx.field_operator +def conditional(mask: gtx.Field[[CellDim], bool], cell_field: gtx.Field[[CellDim], float64], b: float +) -> gtx.Field[[CellDim], float64]: + return where(mask, cell_field, b) + +conditional(mask, cell_field, b, out=result, offset_provider={}) +print("where return: {}".format(np.asarray(result))) +``` + +#### Using domain on fields + +Another way to filter parts of a field where to perform operations, is to use the `domain` keyword argument when calling the field operator. + +Note: domain needs both dimensions to be included with integer tuple values. + +```{code-cell} ipython3 +# @gtx.field_operator +# def add(a, b): +# return a + b + +@gtx.program +def run_add_domain(a : gtx.Field[[CellDim, KDim], float64], + b : gtx.Field[[CellDim, KDim], float64], + result : gtx.Field[[CellDim, KDim], float64]): + add(a, b, out=result, domain={CellDim: (1, 3), KDim: (1, 4)}) +``` + +```{code-cell} ipython3 +a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) +b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +run_add_domain(a, b, result, offset_provider={}) + +print("result array: \n {}".format(np.asarray(result))) +``` + +#### Scan operators + +Scan operators work in a similar fashion to iterations in Python. + +```{code-cell} ipython3 +x = np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0]) +def x_iteration(x): + for i, x_i in enumerate(x): + if i > 0: + x[i] = x[i-1] + x[i] + return x + +print("result array: \n {}".format(x_iteration(x))) +``` + +Visually, this is what `x_iteration` is doing: + +| ![scan_operator](scan_operator.png) | +| :---------------------------------: | +| _Iterative sum over K_ | + ++++ + +`scan_operators` allow for the same computations and only require a return statement for the operation, for loops and indexing are handled in the background. The return state of the previous iteration is provided as its first argument. + +This decorator takes 3 input arguments: +- `axis`: vertical axis over which operations have to be performed +- `forward`: True if order of operations is from bottom to top, False if from top to bottom +- `init`: initialized decorator value with type float or tuple thereof + +```{code-cell} ipython3 +@gtx.scan_operator(axis=KDim, forward=True, init=0.0) +def add_scan(state: float, k: float) -> float: + return state + k +``` + +```{code-cell} ipython3 +k_field = gtx.as_field([KDim], np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0])) +result = gtx.as_field([KDim], np.zeros(shape=(7,))) + +add_scan(k_field, out=result, offset_provider={}) # Note: `state` is not an input here + +print("result array: \n {}".format(np.asarray(result))) +``` + +Note: `scan_operators` can be called from `field_operators` and `programs`. Likewise, `field_operators` can be called from `scan_operators` + +```{code-cell} ipython3 + +``` diff --git a/docs/user/next/scan_operator.png b/docs/user/next/scan_operator.png new file mode 100644 index 0000000000000000000000000000000000000000..f0c1d03636b2758296da39a29251c2adc5b321d3 GIT binary patch literal 8760 zcmb7q2{_d4`>#YsMPmuAjJ1en7{e%HW{hEGY%^xji_F3dGiHn#Tb3}1BD=IuM9Ef) zNGq=*T2!o>O|GBPnF4uh5XStvIxj*;kzVEM0Z%;S1mFrh3 zC@8486L7xZT&|$7Y#n3;xcd91>w0in7Uk=PQD}Iu@somr%1)^(Ln@37;S0D5CN|iA zuS~4rJW-U?#0F<#Z5=EYTXFcop^{)>lvM;*3Lb&?LJ@}_!sl}SZDS3$wy{JYEDzQ z63(HqX|4iB2%L@){kxiIt|W>tiuljst*sDNpyyI&X`Gn*?^6ypn$HFkTASFo0>%DE zgP=)7Fp%RP&3Xf+SR(%I$p4QfgJ_752!FUWGME+XCBY$3{+2*fOM}?aqW#(8s2IFA zJeVMcW1=`b6Kf3GBNz_17m&SV;o(%8J&MTZxncd7uC|_GG6aGh7sHR?BhkJBES2W% zD+u9J#If;Vwpdvhk>pL2$|SBXESV?6%|{%AqghiVD25B(#Sclqhe&t=JR&M8o*C*H z7wcjt65u0cWIE85AA=YXMMTAhqCD6>@JM%G8#fQL=TY-J{;~N(0VsT$uz?C>V|F=RsidxZWaC zya?y&#tpR>hqFjRFA+s7kTA$pMsTc@iUi99ZApH?_-H#Zcm-$B$Nd+EJEVlQMx(g_S7Is)wzOsClqnJ5A(Qc6OF!?{$@n-)T& zuo!p>n#zwwgL|&KeONG(&asR36$5Z%aCXsnmv97!CZGcg#9`2^7!*0iFC^Z^)y10> zA7SkoN@nrc@h(1WrZpNH%x8-HI6SodQdeJB7C<~c4#BafdEjEgtWmzm2sSE`3ET(9 zRzmd)59hkE@O*bxxEtHr-PblOMCOOKx0Ct#auIgE0v92HC<(Q}ibRn}8ZsJ5Lq+p# zqC(>lWHdSk5evk_6TO1P{&rse-V`?4o{YAUcnIx${q1CQGSwc9CZl6mIJTW_gguu; zh#}MBmmC*f#E{vF2vj)6)0!eCQC*j0BmyK3Z+=b^4NSF9(X&v(8iy` zBQhe9NUj@|=?9k*C9F_)K1l+1W$`6qA75(@A%>476Jz4&{P0LJTFML|c#CbMQQ;9_ z#t3v6INFnAXf8N5TuPA$5w5}hL|Yl3YYXSXQ379qABq$v6btF`NSPPTi-U^tisaxW zY$7Lw%?`si zO57vy#4uT;5P%sI=N&@l3TQa2j6xI=Xm}bG;g07gsJZ2;V9*6Rw2;8I+WK3k7n?1&Z!J&C`{cS0cev#3>QoMxZ$w#9R5$I^ws8GKp zPa%oMMtRv#BLwhpU$KaY_6p(QL%cb0p7xPawDg7VH1aA2zgAt^lu)f9}=R%DQoG4Z8ClAbboxI2VWM++Dr!Y<~+ zu<1Ry<+)7%qlQe2i@Bt``CH00*Im3vNjGHfXB;iS82`vczq@&3ex_`G$E0j>?+p5( z)65?;JNH@=<7eK44oxl&hE66vh&wqc#7WaMZw?mEVoA?6Ocxed64795Uh7}iCQ!Fe zI**oplb;Mq>H9tM#@tdF^M&MYJYrrlYf*~BOm8$xP|;MqLNU(Etv!XqTuxPl?Z#xv z|4>!NG^A>(zF4_ZE>Jbml$dBkuB`Mh=2&cjUNk8rbQgIWjF@l9%dJ1<++FN#2%}?C za%=Y}7{UUR^lZbloj#Gn1{wm6dlZ@9_VqE~+&ysTu;II#e=icFKHMwM-eC85@%Q{W zLZ0E&_@#%ww-2tRsEzcWqckP_lFUuN5`9{npFA5jP~&q&@_yqQogckPf|+MQ$v%q< zzl;JCdJ4DppCNDPT7Jn63SoXnr+yt9s13Vu#_ME*)iKN6D2JFYC*~$qwM@OwdY1Ml z>pR959@Md{NeVfi7JdJ?T?}%M5ws!bW|=v?fn*bM@03|ttl@ydm$uJVJ;!WoPrK}= z>-;mU_)&!`^=49HSbwz!xzwS^n%jdLO8B8RP{*XG4R>a0721S)u3h)>;YqJcp%2g8 z$=+bG#lo)!2bOttE@~)ohQH<|M6M^OrPYt#N?9(Qygm~a^hXQJ8CK_koBeUQ!IR(r zB=GS3&(Chy1L{?;uf|Zm4d2EFC?dB#vY^$S(lqh#ZckM{7yM^BQuOAU_oljd;j7D4 z;MV7)Q)1Hrp;8CX;*(P@X}r7nCQaO49Sk9P_T7O2J9BLM>X&JmG{h+==5QizDjA2& zRXvT8GbAw+FI`W&<(!dzdc6NkWx}uTVL7vZ?>2<9>XoGWVAE>3Dop{yq`1twnznrx{k53pW^ zcb=c^7cIErx3RYN@!Up62jp{XGY?zNf2gK&Jm3GRGTa zNt=z_4tw87_|>0tVP{nFdu>SbF6u4wv4t<$Mv#XmcX^3!{4!1a+OFc!^X|_Bl$kHD zw}N$y4P|7_#>|%HJUQz{IlK|wj}m|#dHLj4`qz`mi{`%2LJWY-5(^;yCDqT3AXv^)aDtKsUlgo(SmOLxT7Y;Difn|H!> zrYx^7&vjT+Xt}#MernVO*h$%R!X;Rzsm|X0I+p%@P5dVq-8YSFm7;A--~Hdf%*_9& z{zTSQ4WDW4fbqSU9VZ$q%29G7e3rM_Zd^yqn>e%KYo$Bm^1s}ekFcFoSkn^eJvTj3 zd+N4o8$Nse%cuwGP%n3YCG7KEQ)2^%#-7#_Lm!s;Eq{3_bj|(J9rV|f;X(y0FGVNG zVA+X`R9dEN&m+*K^{e*C06i!>Hn{~@?U<^^vxb0Af3Dd@)%bQ>H;4*X*Msi^(5+J6 z=6N|t04eV%#~!E#+td{QbBtP-IQ=4YY!TL?q>cH!xTkVno>x5DR8g)Yf7)(X|LdZ< zQEbxzbW2&?DYeavZGW`pe%;sMj!-O@x$F#G0Cu{?t8C{Ov!XdqW^KSc%&u9xSy{H! z8*zSPxnh|U;omkv_W;KaX#e39e1Cf2$PTQ0M zFUDm4jKa02FBT`n)&r!eKd|HUE3LV= z!?r{5ep@znR1;gB^$uOs+ocOp#+YIjKDe)jm4y#mKKg}gy1n05Q#BytLyfy8b{3)` zG&LL~m5yOHru6OSH5`M;FK)c5*u^<*oCaT|!4w@Db{d+&$?dRz&%7-itvHjUscO5Q zMJHNAzM;DQMBv(YQ2q{sm_frbIF%mxF0+Oj%=5=|NmssQJ$sydA=5TJqai%wf4$~C z!>OAF-e3ylHwrzKI`(ZOkuZPx1T_W@PK{nHiw<|~IFd2CN|*jP@NmG?re~(#pIbFZ zHe4@3$-h58d~UDi_AU~)Q0vHpl`zpGagh#W`9(9*l(_y^_Q&r@warHedgs>_5P*6D z>s+rNcHRPADZgQ10G)GE)&zu+(ifwwgc&;PY#7|XHDme^`JLK`6YK~JFvF=!ou9^Q z(c{iZlh4cdwD4q^^1;jkLbo5>APd=a)vB7K1689UWg`KxNoTlqr1#N4f7m) zZ^3>PUtjq3@#MLD*o)-!djT&l?S3)N%6Z)_S*@=%a{ASks8DmSRmS&mI%ak&ET+p> z$r}O7s?mGDt=YBUoU?1O)kDiJdi*){$n03{d8+v+;XU9IzfXTj4{1R1s2X?Iz8jMx zFLvK^h#0gl56v_yQ`BoYBl+at*=Bk9S8R3 z?Q~3>n|hgsQiP2wzYR_upP;#vg_SfI{NA=r#Xj9(+5SFVeXV=?`KN=f#pcyf`|=C~ z&l)x*{{Hz)`t_}g@OgzFa;a2H=i?{WJtbTQ2$Cp;D$=y!SimCVI^ru=&1Q zbp3)RgKIw??7)`AOy7mG(A<}^4$A& z`UN~_89EC8qRVlr58G3_f*JQt1bL{%`R!))*(#uo=XC`EnVERy20S%Au6C97mQ(7c zYU`1Gn+g7FDF(>Lbj|3BOXNtIgQtZi{;@{-AwvF|BWA{4&n)o`Ip1AUY^oy=iZFpuu~k6L;;qzM zQXnHx$fFc7Al@E^Ml$>7rIo97BUkJz>h8`nYDVnAX0UfwY>ptgy;OOtn5MEOdL$eM zaO^yoNRk?K3fJ}e>dj~qi_EL)_(93PWzFk{KR(a@q=S;To?G`hv%89Tq(AbFP?cw# zbv4I*cA>J_Y~jn*$$mET`5mURigt-$_(o$vM|$7g+k;Qlw#Xlic5biLUNs|s zH+A^;se;`$PgS%buaj&R!p)palX%;~^lh^xv&#$`wnEbjfi0x;ZC9=cq+7!~!##{O zSIWP!EXQQ|dYEsBJW~CQKH)-Zzh$7J*_*pg0r2l_CQwsAgNZd{0aa1xf>RcC*Kh+ z>v(f^jbti4d0oc!Tg3w`n!@&c$V#S zUuP3SeTL7tV-OB#t(fNtY=p$n*RsGI-t6eFtY5o&Ju+i%X3`7b(`Ps0bw@f>6&_sp zq>gD*WbnxRtCfornYFfavG@5v$7U0( zQ#gY(3T$#m^!TZK6QcPlgfY2A*M2!HbGXQY?rZY=>82BY+iyWl>vA$}?GEW~Hx+b2 zH|3c>KbNR97mRI{y!Vi8iJPBj;osV~?A&Zk;Ptk-kUip!+Y11b=cGpqRWvoqisfey z9-96g-DK9}@Tf6&U$b6nj%4l zA)i7tdbWcw2i@NLz`S+Sf#g4zO%-@KJt`AN8+R6kJS?kE81E^zBwpV;Lo+jJb!bnJ9tp!kCA!3m_VH*USTzPIRw!%otkGwhkRC#A`!p<{vle8&`c5o?$|tZkt6Q@JDe z;4qt|xWz6JB8Lrhc6F*bo4zzg$n&ocb+UJ8<=!4M%7hgz$Q6UPEja#t1>i4Uomshx zFue%(8_^*zpbN);D+Vu|TXnU9GpP$nJkaN)@ItGntS`m@U=to}&^hlN)T);~QWp(7 zR%_~>aI&vx^JV;9Wf(o%zUf(9wxjx=glZxWTjG8CFkJ5&!hLGpK_@36yx_;rUL)^JJ4lX?gNNlxg&NfAohkP96-OpRKT*7zL~ALg#U)|E7YA6NW->v zAX)Ud+S7CUMooyy_f8z8diS+s&fTTXAemHNJ`C(V^JLOoCJ1M7SEDCxv*_(l`+K$ym6jaEOD-6=iduIkTGBWNNsk0b1e>g5SC)7dFJN# z9*b>jm#-|JEwL|OF@Lkd_0Z#yy9GrzUxU!J|7Q5J2J%sx(CEj0oo&nZL;W*yGSwZU z4{gO^Di1ZU%s1_bUYQ)_u@9!LueC|MZ64k)d{m)){o9=l<_8epb=kW{kp+7z!$BUj z7uiHKax#s!ot_-h?|Gg)(yh@`V3sq7h z!?J*NUtx`?vr0?c@km)^G2}wwexcF=QDNwLBfF?dZ=14%boA7#q7#QH76wkHp4~?s zWaiCb1Ll4O8b)~Uhh>gi1KOmXZEfcwoR?;3Ennee%E*u0(thz?@FGw&#t zJM&VFX!<&G0GuhhBJ<)^-ic1T+JNfFCsCY-Z|Xj%d$uHwO6y+T>WX(@&D-w`vZ^n& ze|_W;y0sc?E!3jDCJYNG^4`&c_x3gW6k!FH6Cc2+)oO?6z)Dj3EG{?AuLX1`WL&FR zvgE+Vs@LZM+`|4G_H$h_^MjTXNlSFMLFc^rN`M1Ecjr&7nxy|%x;sAbzti2eNQe0_ z^&MZ0H9JrhUMHQVzV~i_67|sWqnKbv*M+>$Y&$F`Bpl?>`9;N+yL~|kBVf&AdwF+{ zf5RW>msg_9vNxjhL2mRuUQENL`A^n$yySZjC zXj_>sa_WAU@zd(Ox$}Y7<9qPx7J&~=pxL(WKR(_u7`z&D{=1dSiod_SzG`v6EG5_0 zjyGPLB`v%6v%S;m*n8%3qx-;JuX|FU_O5U%rQgN4o;3`~sTPh?hWt1E`oPw-(#n6E zv)S=j@irz#Hr9_gW4Hk`zWg0ObUS1qBdgWzc|*XC_p^m2tMYxWSS-8YH42tJD_FaY zk^7b#JNq@5YbL0<0Lwx`XOHYrElyM2do1B6qA};qMZNGRXVij!+Aw$S7x!2D^mNI0 zE3Q{f1VQ`54Wv8}>r|?nfJphHC%H2(L`%ojI-72lC}9$Gm(s(q$rux?uMQq#P)n^_8cv44Mn&&E=am%*CW zI8VzjnoVsYr1YtmPd?Gut8jNVL^imwl-2uLr1lN*jh_5q;C`dc3Btl{zapeMcIHBfX~CXk0%%g+7KD@TV2@_Tdg<}(Y1uL|30k57ts<(J5O8H zu~TyEoAT=Z827<_a(-hrWlhk^-`c;$e?0kP??JT zkYwzuZx?YwB2(PA0VQ~wg{bB`*HwO^C50VM-v(J?iQXgI%_`GZWS#XoNs9>RezWG` z`pT@ow2h!0R>MmLzNyhieT;^{>op6OUnj1J-zi!MRFImOUP~O+)Rn59)d1_<*|Ftp z((kz*WS4pP)4Fm{5b_eCi7M=!m93PN>^Itac>lD%! zv#}Yu*%>@F$>U1*ZS#=u+^rf08BnLAEQHSDi)HUxwkESav6MPMj#cw}s7>~{A7v5q zWe~}}v%1-9M4JumkgvIt)fi{#eQKXWNy&v17hOuSZjM&k#JtndoQ90{e*SG2G3foZ z9i#3UxG$fQ<(#qGW~~K{I&^rcwAm8zdJCwaortUo+wj)0+;Nv|_IvnmQYbTeld1;w zEm!54)nNLir4nCZ*SAJ?*XZLaqTmp8TJ&n5e)o}o^2N9IqBg*cxHewy3nNaQUx%&y z=~lgT_vil1_bo$7rwv^XR{O2#MV`H21iR-4{VD!a^Z3)RsjJ>uSSM>R%-r-C;oqqzfwYcEiPy?mVpJ?ZGbD z(*s%%+nW$oD&W?IlGT|YOuXIIxke$9oC2@BHd_LbCvAoPY`Za7hpJBnMzOuDz6&KSsG@pDV9n} zASq#T`XtZl$5pvrqlY)9)hmH00+CU0wE<*~u6EPM2fj=_6BTvNsQz@#_=8ebB3|1s z+z#!z#QuLw6cx)u#h;?L0g7#_i=}x1U*VWWY+LaLU1n`1-dffgt9eygm)aH1z)Uv23bz1w{F!UwNnHtOyjgQ>8j>>IAj zEmRq9_m0^fI)>zT(AI2%$OY@NT}B5MF2Z&uE!YD_$$NS}$P1)Lw1cX0ftVD1D;2)^ zM{hdRUF&oIxzo5Ed#CF>HUJmQ$#|{1sw4K^_6?Y;IKtLLN@I$!4I2wKT`(2=;N`IbopEbHB{- z_iEZ5b1qnQX&YR`ft2l@Z|G3NjJHAj?9=w2*&6SJ=+v9gt-K+{qUSky-wa@LH< Tz9PZDf+)DVdg2-|ds6=g`@|Ku literal 0 HcmV?d00001 diff --git a/docs/user/next/simple_offset.png b/docs/user/next/simple_offset.png new file mode 100644 index 0000000000000000000000000000000000000000..660abe87642151d390abb723a0c40e5ee9c22e00 GIT binary patch literal 10292 zcma)i2UJtr@_xXpAS#Lq8mSQh=_G*!1W^(QC834TODLg*1OlN25Ks^sXcQ5Gf(79! z2q=051*E7Lny5&XB2rYONEd1U9rWG%-dn%5-hX8!=bU|Zojvo-H?wob%EEY?&^{pu z1hUQ4#E=Mq@Ik>BCb$Wd;C0@2fD0d!XpDgry_Wb2fo!q}HF5~0NBa2FydiQ(?5{gH zgoa-bGgJ;~D2G6JFc@kSe-Gaf4?0sV&^r_qf%o(viocJ)H|19yga!hs3Wux0Q6xBA z4r!o;1V5TcHIxRz=~sOZU+=)*9hwJ4`O|0~atNH3ng;0Vpt*;ae_&`3(@zeG0nes^ zq24rb1;yak(gyrEfG=Ew4A&x~6~V0mgF*AQ_x8m5gW2G;kZ3g|3KSp2n~=;2atJKA zr}+nXgNw1ZS3nT23Dz$phz?3{aJY_|2JZt3Z9IHDLj3=&3*O4SLp^+dt%~XtVc{4W zhNpT)nt2BV2HH4k(SNNb+&hHn9~Ag|@d!1z8feLD9Li#N|0<<;hx>bh0TFUYBOut{ z5(HHOJ-|HHC2IwQq6+`j(Eqn2qtNE|XnHV%OtbgX)<$djN16rw>ICf+LiA)XBXFi3 zK7PzlytzHr#=+2u5ekPhHA0ACKmr&|x?v6Q@Z$|XD zh-4v*0wXoy)<`qFHPwsZNHRnjA_;aDIAfp~7-J188eFIZuw^(V&CZDBi43=oG|)tl z;HGwF7)!J-9HU{ZK_uWEiNRDIPqdl0E!@gHjAq5#OA8}3l0vnmc?2`Oy}YcnoiuI2 zHMD$?NK|Mz(<4gD!N@0ABizu;$;6lHW9SD)(zFfIw)Z8HJk5hG;Ix1MYl4F#&K7h< z471k3qbMi}4vxlQLp*6=rVa>AqhMPWJ{ZsNWQ1$mVS`aVpeDiC){$Zu#WD|u(@18- zD5Oy!2IJ&Qcf|XVtVm$`*hs9Ct%+}AUTfW%qZU?Mzi9W)FGgh-4B9O;N7 zSQD(7RzB96p}vu!7(WA`ShxVIC?5w_qyfeQWr)zgn;4lBtWB&gql3m6Z*PA`dLYv(&<JhW5e&4WFNxsG3PAW9n1in9aGa+nHJBbuhof!1 zeXTHbB&d&#Y? zorfulrD2N1vm9`wC3@Hjr1Q$z`k2MMhl$dBm4faHbVC}W(Xo9hmPbk8~mVxs! zv%v;nsZI_~lyHsUU>&$9N@F!OW-P~yQ1q}rL{fX(6H;Y=gCZ2*Z5 zo{gyXREnb(uZ<}yl0*t65qSm~7@>s;^e1T9S~_a`*hJWcFif;8!w`{nRy1qP0Lu^y z4LFmxzi2HjLyM?TxQ;fF8Xgsfqz3CyyiIioek@B?co@Th7G-5`8MuWtDVa@^Y@*oXid z%}kHo8n$<_*0KioPD=+omNk$Q4DeLUkf0Zac@8GW6Zy>M-&VsdqI>)_MqzT&kt zovInVNu7m0804qNn%0Ze-OaDrtC@fKFbGXq6)P(%;e!k=TZk;k8JjYRdl>NQiqviy zir#6ziY=JBGJ^+}7uaWdJ$&+~bZIds6O)4z^avL?$pRu>M&TPlo9>hEF0nViyle8= zlxs+(HJ1o-ejG+sVlZ`&4!g-4b0aTa%%a%Fv+MI+w8sTGjEb`%d5MI}kEI?R#K^~9 zFfV4a;q(jX^&$v23Pr(DREAQbw8ughBQH^W>zX2LS8kz7`lnK@Yoa?51CeQAc}?K` zt8lm7--S4my5a796|9WI(I|OB#z)^1F$uTrjDy%WtDKZK`9>GB9Uu9L96EI9O;^{M z`T215v(v7fRA{9^(rSL?-pi7AU=NQig{wH0{PC^5@cAxFJ!_GRY`7?1iTj3Z~I zpEbL40|lX$3R}wS^WyEfdb4k^89Ge)r%#`1gpIzZA@c*$T6a|mYX=lyv&aI{>NBYK zXVBa~hKgZ}I-#z@gyU-3qmk~i`P-^yJjQBNn{RG@O$|{~cCFAM3 zM_9j5YMNQv&|5G+aC8mX)h`Q6O)nXkTUn=7i ze zL_v(-hhE*|CkUuk^%)E4Y(+7Ice0PZ;K}bmndX*4{0DY686~n62OljU^M>j_Wz17MOu^UFdN+IMQAE}D#zj8h0-{n&WoI`upo1;enU{ucNFfq6S z`AM~@38F3E6?_WzI2%TJBy{J))$JIMCkL8)-k;dmew-Coc?GWw$xr)k)wpXrc6(FC{x5#of6NV7Y+>%zg1c3dMvz5SP-f^OsromacO`e zSc?ToP4nX)9#HCc-g7P$w$z(4G}?rz+d_5@+h18yfPnt|(2(?QEwGRL4?bL;7#}wx zlcixcgdS0T=uhWl&RJ^`Td@|@p}ur{8mj2ioN2^8xL@AL#((jK5C*$y)^Qb6dr+7s zA!f;q ztjsy79sUo5Af<(vdSP&?Ii)fxU1a}}JTaB~!c9xxL=mhEEccG>oLJ&0NAru!Nw7;^ z8=AndL%}1vTYU$HiBgSbd|*B^8F|C{o1uleiYKgeNq*zG_Te)I2YXUCH($Rs63|EX zK^UDpu@)={UCT`Bl!fJlrErq7%JEL>;(u8Vs5i<@)DMZASk{AX&(ExSx4BWye`Lg< z>oFK}+vV9gk!G%w%BRPx>R$Do%vE{+slw)C#q>J?$f^E3;x5ppF zXh~1`c#=_yOYozu8-q`vLB^-xOCvy0^`6onDbE;~aPRDY1&kxPp`R%J&_&5M z7MZXM;iOy!?<5)gvK3A;>ar(2h}Y9Vk`!!LQG~v)1=p{QBHv0fxPtHgcsA!SENAW_ z`>hhtuY)S-w>ejx?pGgLmu>ZarRSmRvaMi?|!KGlxWb#Ydg_+T|0=P%*A8Edc z$!7W;4x))k%fr37bvNC41_r$MeZRN+Rop7}qI(z{!YBJuf{uN2W>iJQJ66S3j;g>Y zr0KgbFD@A)FTmYBP`nS}d;J*dX?gk7>e8e%+r9tUu^&SUM;z;^;S)V)F3KAF=-UW# zLK@|;Jagwm82td|-t&YRJYBCa7(ia?=Gz=Qb&l*l`{ixgT0}MUQ0PHes^jqxbnMsS zxIw*zJ4)$izuRx)MAf*=BYa;==uLBF)7D}?>-JmwHxTk%0C%wcA@&QS9ejS=k2#aBfkA)USD$YPF={3o~W7Y zAYTxX#ocMyimAJrsKnhnozAZRy|$tH)Tu$~wX*4;PcP2&z21t+%f0+=+F&D;2SFc* zWn5v6$zmZy^uk0TMRlU4JfP^wh*#eeqHoQKjTPmWElK#|F-`(|b;Rr1XSe31f)mNn z;syl@Q4!-K$BbqCEFlJVFBIFED{t5eCUYAOH8wLysBj1xDOR1fE{kaJMQKlkl(_h$= z?QTx`8Nfrx%en^n=~F+iLy^XXAY`KdsQSWP zw<_1e)qnM;2J2t)`Rz~T-yX>GYFYwN136pH@?kv*mxQe56;17fFd-EeZ?y(N%J`~n zAgg<;;~?5ws>%F6f*@<8@o%vZblcq0c8Jub@3gsv8Ay-wN>_7`^;y>K!Ouj_$(>kYjZ6;*KT$RuQWWBieOwRtH1-oCSqAujR*B_a#s z%%VgNpF2$pzZd|yz|ZI>mP%B=Y#V!(W*TRun&t-i4v8OVBjkx8&>J$wo{33RLxEd^ zFucgya2TBO&HaLFlGTcB&0Th_CiBiNX3wD2@2QT%W(5ib-xs4H%K5o1WuvQ2Sr9w^ zz5NhGeTJ|}k!3$8`R^_7vvYejA2G99cdfom{d2f@d0-XgM*0Ww;H86cTR#3AJs!Ay z#^IT6Nv3J7&3#bZFD~tHn{LCbI|Rsex&bTE|=W8ngwp|K-|n)HfrwRZATqv zxutDej~rxmW3>e&9m04M-17R7t?Ymi#MllU(>sHlXn&Z%j;^coHWXW+2y$*&czqZI zj52Qb@Fnq|5qxEuOLb3KWx7T)1#g1mLb}d2Mp+c-Oh`fEY>!R6tW2vuygJ#SE^u#% zPq)5q@-D<}5cr6T>{{{_d{64zPll3-8_cUKi@Ar_rbXga&HJ25ORMtsn6}Bmt?#SK z5BH8uygNFsxqZ-dBRVd3*A@J`Ci8r}$}0{qfRMX)j+%i;QU5;yf?{Ud|t!=rUdANhr=#-D8S>UmIP-klDV z0QBvHcsG~V2A>^lUiv1N+VwO{St@QhxMF#_QG4wEbH)hA_tRPBqH69%*|tra_@Q%$ zkfAbwR|1IjVfluv>4w(@Q|*XT4?YiN}voP1lb|lh=NB z*tUxxlonkEn+*J#{$Yp(Cz`SkM}LeDsokvb@v==%j{Fk2v@c3oOvNF-t-yZp#Fide zxh`@dA$&Yv6)A8N&O;oYR~MyFJ9>J1Q-G)0bY!WkYr6g+%jFW z?RwuKMdcHRTteR(30z+Ol5>GE_Tbg0gzfs(MXu$HFE=L&k94^u5?I5R;Bk&>EVcBr z)Ak!VLu&8Z?8zy)_;e}u$7Xc@3qXGuJ;ASPa=SA3z;djM-=SvGUT0Ljo3>JQaF@(j zu-^!cILBQ6G1Xb(y(RG3qEF1~%=K4!=2!32x)3}{w9p$N9r*O_s`BY~jwamOtpzo2 zISFqv0RQ_;$U7Kk%MK2RU771pj<|#h4DHvSzS>cHa*J|Hy+p|ArLvwebxdAGNPPzI zwt84ly^S5bD=b=LvSP+!xc{QDl%rzWqaX#7*#neJ3@*yMlj9NPX4vSC4p_P0p9p843$G&BSae7jxMoa5o1Fuy!Rd81P-xjzTYAeC{UqS2F z=l<(%`u~nOrp3Z-HvS?jJY1k+)gH}{ta!`sax%%iUi>@a$hAK>CdMQFwZaIIo1XdZpW8oyGvK@J+St^3F|YKsszNq{w~5-{pGsVy@`*<|#TDd6&tU*Vou@ z^Xtn2c_-P%4-*HiWGI5rLjKWtFi{6JQzZH?_v2x}5W}9l@9)3x^4y;CmdND{^ezR9 z7hnIpUNoJ|`J0J@sF_PyC%g-N(c1x=r9B^3cYeGT>CO(OQ#$tM{+1E(8y!+21NJ z7k?!kS7AIpAMx@Ri=Kpymt#eLjr}38d|O2SF?4p!v?&irJ4=J4UuzF;C2$ZMv8RM> zx0rD+JF7vb(}#(|!$gtQd>FIKJN@RvhcjM{5jR*rzy1l2n)MR}Ih(YY(G)UZzKvL( z3~n_1>VJ!R#g3~edhPbu?;BzO`+w11==^{&fGN|wS6NvJi+`5^4AQz_-~kBFhty%I zSb3YoDk@jgY2h*b1EUH4inKb^pkUr*_|G(WvA-x~exmo>cKsiaQ_`{8xRiZiZ(=O< z7bz8U?arN*(P&)q-iwPN1n%Y1DSRCA6#e2hq;I~_A^_Ha;>x5}3*Bw4e_Fnp9&X$& zxHMW&>c}!ot#X!_*7@I11}z=Q&X$T9RwmbEY6S>i`WW56r5505;DhB{fin-e>Xz^U zpqG%bd9esUc+}>msQ_<8qFZ+R{8gU220E%<%!>rvMo0RKXB#HuQtJbB6lGF7c4H<1 znMfp*!EoHfZ&@uI5O!*@a=o7L)Ocs3oT|mXQU%KUryxx*l{!U|0*R;LOAY-p`#>HD zgA+CWUj}nYyed0!YlS!l7a+6nUulx0i1f$^|LTXQ|4A(kWG8DfZMKLx2#l-@XA|5- z(mbB>(gF1!M}so|PG)`W6$6oVBH+8ZnZ(+UBP)eR^ZxL7C;WA!zN1FSq*36fc*GE> z%^{TW9ZVwR6z_OkLJJ%#$P$!I_1KZ0qL#B^orO87-K_=DeiJbHO>a9vA`T|o+;c+X z<=_EcE+R7YgP;9+F2d;izj6^iPe}R8KAS*l!;ax~|96@PSh#E&@aNj^4#@1?~5w9ob3e*^7yVJ@xznI!&-N=-rWv$F%wPYoikFK#V*I;xZ` zp+g)8u(jGNC(ARwuqKeO+lOE+74)q$q}PYOsm#vJNe(r4KJa=FyAmZ4`(r?R>p(8Mw&QZ#>ZcR0DkKG^X-Fu(UUJh`j6zh z8J2U&avO+ocOrC-<_Q9K2t@}FJ2I$YE44L|)>^cD2o4 z7W(6)f$i1d)hk2Uf#lWc^asGfZevJ7ZDcG6G<3flbkPDM&$OTl-z-O|raxdJ|(EDp|Z?E?1 zk3r=lYhR3y)LvG!(Kz-qRE7NP8rxd&M&PwYGv6k1;Br(F!Fn}cW~O$F`p0X`*9Oti zx`%Xix0`(bMqb;eaE-wq6%}>gw7yHEOog1 z9Csab#<`o8y`Ky{aZzd@2eGFxmp5hYKv2DH?GMd;Q9WcbP^`&pjBaTE$@8gRs+8rH zSHSQVdIUKVNa&hJX_aw=2#hjKM6bo>4r5h7#u^6P)){-{z#uHUt7Q;Sl!G zH?=-kcfs3?|1mj^r`NB9=z0=^=l_#Na>AeS{D010ow3pDxS`2>IC;7snhAKuR6yCw zz9!)QbA*1=5IkK4#M8V`>)qTHs@diQ0C{Xhn3CQ$#oy)GlygPvl?Uo1{>lFh_2%XY znKcT45t;yRXiIYO8*g4}B*;Hvn;jwl9v5G1)v5X`t+3BtpOp8q(e>*k`|*~-Zljfj z$B!Sc`ew88O6tH}!;yuMs)y4{x)WvCPwKE%Z6Y{_*vZ5m+AQf3D8NvN?=+BAlJ5TS zp(8vzd_v(|b4M07Y3s#EIj7=gxsgK!G6;@~sKy)C&gx4?Q5vpOyKmcvw;E{EyC3<2 zW2)DCiZcb;mEF7hhg>@!e7HhE2WG*&1FerHi7$<|USM0}zZ7qs%8$ydu`z zANAxiKjPNh&n@ZhQTr*$PQ&pJA0I-e=Xcb;`f~Z9c|lcG6)Lr{-Sm7R@iK_?R9|&v zj%Bjp{*5(V$r}j#oS(E;KW3dRTMkir1=-*8rJ$#QzOhUL{}rR+~(GF&F?WX z&*pi`W4XI?9(Gt`o`}m;mRuFC1u?6As_bafqpF1AiA$*Ld-~D4Hd2VRW$LlyzYfE8 zop~sBU0y70dwXN&C3o@)Hda4d`#e8F5fwWVrrgs9^5T^< zhi3lsNDTSh@1Irsyn)ZkClbC0M`g#2pCK-JU>yxwL!#Qzu{Jz`7Us%dlsw8c^tn?! ztljCd{p0COVJ2V7KXdO`+;i+mBgDl?U~lP{?R+uFmkK}6GzfefCJJ#+w~HT}jfpwl zMA>xn>ju>>a0s?y?D2G~JgblZjuX;)LwDpIOy6tPMA^H6e`AcEbn4b$!W9~lynHI- z+3D|M;hk_rCFy?&r1OeT=3!W>DGD{67eANM4|YeBIg{rqBPoJ7zE^Zyvk)u)vQ8cL z{4mF_#in4(uD>^?Bi~Wy^*IjV{3qf0B5-WcYLIcuW#3pV+PTUfkRI{yP8=U+w|JFD zd0PM2X3p{>$=So_{$7TzV9p6PG%@e0Q^~(ZMdrrwfpm<}4i0wfjrhOiaprSWjnHY( z(5O7P4*cWnH1EFvzhvQjUqCF$OWv#Iq{(NH23S`v;%!uZx_39Z&y~33Ock&`3Z@@) z*Hc{T8XHct_+s#wr7KDEjmG+ZS(tc1)n1dYlIEUpQcZS9^WXO_Xiwn^ zHH)xexYy65Z^QjKrK1}&mEH*M@JB*trgtJnR6qS8YkOn0y!2KA{*yRtm;WD3#}QGJ zk_^8A3^(I~frW)$=v2@aV(mR@z)@XdfbULXHiwtS>n`Af8@ Ubnz(o8z98g$ilD)<9_1*0a&l~Y5)KL literal 0 HcmV?d00001 From 8039a900dbfb57526fe6dc2df5a4404bfa23965f Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 16 Nov 2023 13:35:49 +0100 Subject: [PATCH 33/67] reverted accidental commit --- docs/user/next/presentation_slides.md | 411 -------------------------- docs/user/next/scan_operator.png | Bin 8760 -> 0 bytes docs/user/next/simple_offset.png | Bin 10292 -> 0 bytes 3 files changed, 411 deletions(-) delete mode 100644 docs/user/next/presentation_slides.md delete mode 100644 docs/user/next/scan_operator.png delete mode 100644 docs/user/next/simple_offset.png diff --git a/docs/user/next/presentation_slides.md b/docs/user/next/presentation_slides.md deleted file mode 100644 index 87cd2b7787..0000000000 --- a/docs/user/next/presentation_slides.md +++ /dev/null @@ -1,411 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.15.2 -kernelspec: - display_name: Python 3 (ipykernel) - language: python - name: python3 ---- - -# GT4Py workshop - -+++ - -## GT4Py: GridTools for Python - -GT4Py is a Python library for generating high performance implementations of stencil kernels from a high-level definition using regular Python functions. - -GT4Py is part of the GridTools framework: a set of libraries and utilities to develop performance portable applications in the area of weather and climate modeling. - -**NOTE:** The `gt4py.next` subpackage contains a new and currently experimental version of GT4Py. - -## Description - -GT4Py is a Python library for expressing computational motifs as found in weather and climate applications. - -These computations are expressed in a domain specific language (GTScript) which is translated to high-performance implementations for CPUs and GPUs. - -The DSL expresses computations on a 3-dimensional Cartesian grid. The horizontal axes are always computed in parallel, while the vertical can be iterated in sequential, forward or backward, order. - -In addition, GT4Py provides functions to allocate arrays with memory layout suited for a particular backend. - -The following backends are supported: - -- `numpy`: Pure-Python backend -- `gt:cpu_ifirst`: GridTools C++ CPU backend using `I`-first data ordering -- `gt:cpu_kfirst`: GridTools C++ CPU backend using `K`-first data ordering -- `gt:gpu`: GridTools backend for CUDA -- `cuda`: CUDA backend minimally using utilities from GridTools -- `dace:cpu`: Dace code-generated CPU backend -- `dace:gpu`: Dace code-generated GPU backend - -+++ - -## Installation - -You can install the library directly from GitHub using pip: - -```{raw-cell} -pip install --upgrade git+https://github.com/gridtools/gt4py.git -``` - -```{code-cell} ipython3 -import warnings -warnings.filterwarnings('ignore') -``` - -```{code-cell} ipython3 -import numpy as np -import gt4py.next as gtx -from gt4py.next import float64, neighbor_sum, where -from gt4py.next.common import DimensionKind -``` - -## Key concepts and application structure - -- [Fields](#Fields), -- [Field operators](#Field-operators), and -- [Programs](#Programs). - -+++ - -### Fields -Fields are **multi-dimensional array** defined over a set of dimensions and a dtype: `gtx.Field[[dimensions], dtype]` - -The `as_field` builtin is used to define fields - -```{code-cell} ipython3 -CellDim = gtx.Dimension("Cell") -KDim = gtx.Dimension("K", kind=DimensionKind.VERTICAL) -grid_shape = (5, 6) -a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) -b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) - -print("a definition: \n {}".format(a)) -print("a array: \n {}".format(np.asarray(a))) -print("b array: \n {}".format(np.asarray(b))) -``` - -### Field operators - -Field operators perform operations on a set of fields, i.e. elementwise addition or reduction along a dimension. - -They are written as Python functions by using the `@field_operator` decorator. - -```{code-cell} ipython3 -@gtx.field_operator -def add(a: gtx.Field[[CellDim, KDim], float64], - b: gtx.Field[[CellDim, KDim], float64]) -> gtx.Field[[CellDim, KDim], float64]: - return a + b -``` - -Direct calls to field operators require two additional arguments: -- `out`: a field to write the return value to -- `offset_provider`: empty dict for now, explanation will follow - -```{code-cell} ipython3 -result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) -add(a, b, out=result, offset_provider={}) - -print("result array \n {}".format(np.asarray(result))) -``` - -### Programs - -+++ - -Programs are used to call field operators to mutate their arguments. - -They are written as Python functions by using the `@program` decorator. - -This example below calls the `add` field operator twice: - -```{code-cell} ipython3 -# @gtx.field_operator -# def add(a, b): -# return a + b - -@gtx.program -def run_add(a : gtx.Field[[CellDim, KDim], float64], - b : gtx.Field[[CellDim, KDim], float64], - result : gtx.Field[[CellDim, KDim], float64]): - add(a, b, out=result) # 2.0 + 3.0 = 5.0 - add(b, result, out=result) # 5.0 + 3.0 = 8.0 -``` - -```{code-cell} ipython3 -result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) -run_add(a, b, result, offset_provider={}) - -print("result array: \n {}".format(np.asarray(result))) -``` - -The fields in the subsequent code snippets are 1-dimensional, either over the cells or over the edges. The corresponding named dimensions are thus the following: - -+++ - -### Offsets -Fields can be offset by a predefined number of indices. - -Take an array with values ranging from 0 to 5: - -```{code-cell} ipython3 -a_off = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) - -print("a_off array: \n {}".format(np.asarray(a_off))) -``` - -Visually, offsetting this field by 1 would result in the following: - -| ![Coff](simple_offset.png) | -| :------------------------: | -| _CellDim Offset (Coff)_ | - -+++ - -Fields can be offeset by a predefined number of indices. - -Take an array with values ranging from 0 to 5: - -```{code-cell} ipython3 -Coff = gtx.FieldOffset("Coff", source=CellDim, target=(CellDim,)) - -@gtx.field_operator -def a_offset(a_off: gtx.Field[[CellDim], float64]) -> gtx.Field[[CellDim], float64]: - return a_off(Coff[1]) - -a_offset(a_off, out=a_off, offset_provider={"Coff": CellDim}) -print("result array: \n {}".format(np.asarray(a_off))) -``` - -## Defining the mesh and its connectivities -Take an unstructured mesh with numbered cells (in red) and edges (in blue). - -| ![grid_topo](connectivity_numbered_grid.svg) | -| :------------------------------------------: | -| _The mesh with the indices_ | - -```{code-cell} ipython3 -CellDim = gtx.Dimension("Cell") -EdgeDim = gtx.Dimension("Edge") -``` - -Connectivityy among mesh elements is expressed through connectivity tables. - -For example, `e2c_table` lists for each edge its adjacent rows. - -Similarly, `c2e_table` lists the edges that are neighbors to a particular cell. - -Note that if an edge is lying at the border, one entry will be filled with -1. - -```{code-cell} ipython3 -e2c_table = np.array([ - [0, -1], # edge 0 (neighbours: cell 0) - [2, -1], # edge 1 - [2, -1], # edge 2 - [3, -1], # edge 3 - [4, -1], # edge 4 - [5, -1], # edge 5 - [0, 5], # edge 6 (neighbours: cell 0, cell 5) - [0, 1], # edge 7 - [1, 2], # edge 8 - [1, 3], # edge 9 - [3, 4], # edge 10 - [4, 5] # edge 11 -]) - -c2e_table = np.array([ - [0, 6, 7], # cell 0 (neighbors: edge 0, edge 6, edge 7) - [7, 8, 9], # cell 1 - [1, 2, 8], # cell 2 - [3, 9, 10], # cell 3 - [4, 10, 11], # cell 4 - [5, 6, 11], # cell 5 -]) -``` - -#### Using connectivities in field operators - -Let's start by defining two fields: one over the cells and another one over the edges. The field over cells serves input for subsequent calculations and is therefore filled up with values, whereas the field over the edges stores the output of the calculations and is therefore left blank. - -```{code-cell} ipython3 -cell_field = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) -edge_field = gtx.as_field([EdgeDim], np.zeros((12,))) -``` - -| ![cell_values](connectivity_cell_field.svg) | -| :-----------------------------------------: | -| _Cell values_ | - -+++ - -`field_offset` is used as an argument to transform fields over one domain to another domain. - -For example, `E2C` can be used to shift a field over cells to edges with the following dimension transformation: - -[CellDim] -> CellDim(E2C) -> [EdgeDim, E2CDim] - -A field with an offset dimension is called a sparse field - -```{code-cell} ipython3 -E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) -E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim, E2CDim)) -``` - -```{code-cell} ipython3 -E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, EdgeDim, CellDim, 2) -``` - -```{code-cell} ipython3 -@gtx.field_operator -def nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: - return cell_field(E2C[0]) # 0th index to isolate edge dimension - -@gtx.program -def run_nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): - nearest_cell_to_edge(cell_field, out=edge_field) - -run_nearest_cell_to_edge(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) - -print("0th adjacent cell's value: {}".format(np.asarray(edge_field))) -``` - -Running the above snippet results in the following edge field: - -| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![grid_topo](connectivity_edge_0th_cell.svg) | -| :----------------------------------------------------: | :-------: | :------------------------------------------: | -| _Domain (edges)_ | | _Edge values_ | - -+++ - -### Using reductions on connected mesh elements - -To sum up all the cells adjacent to an edge the `neighbor_sum` builtin function can be called to operate along the `E2CDim` dimension. - -```{code-cell} ipython3 -@gtx.field_operator -def sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: - return neighbor_sum(cell_field(E2C), axis=E2CDim) - -@gtx.program -def run_sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): - sum_adjacent_cells(cell_field, out=edge_field) - -run_sum_adjacent_cells(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) - -print("sum of adjacent cells: {}".format(np.asarray(edge_field))) -``` - -For the border edges, the results are unchanged compared to the previous example, but the inner edges now contain the sum of the two adjacent cells: - -| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![cell_values](connectivity_edge_cell_sum.svg) | -| :----------------------------------------------------: | :-------: | :--------------------------------------------: | -| _Domain (edges)_ | | _Edge values_ | - -+++ - -#### Using conditionals on fields - -To filter operations such that they are performed on only certain cells instead of the whole field, the `where` builtin was developed. - -This function takes 3 input arguments: -- mask: a field of booleans or an expression evaluating to this type -- true branch: a tuple, a field, or a scalar -- false branch: a tuple, a field, of a scalar - -```{code-cell} ipython3 -mask = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0], dtype=bool)) -result = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0])) -b = 6.0 - -@gtx.field_operator -def conditional(mask: gtx.Field[[CellDim], bool], cell_field: gtx.Field[[CellDim], float64], b: float -) -> gtx.Field[[CellDim], float64]: - return where(mask, cell_field, b) - -conditional(mask, cell_field, b, out=result, offset_provider={}) -print("where return: {}".format(np.asarray(result))) -``` - -#### Using domain on fields - -Another way to filter parts of a field where to perform operations, is to use the `domain` keyword argument when calling the field operator. - -Note: domain needs both dimensions to be included with integer tuple values. - -```{code-cell} ipython3 -# @gtx.field_operator -# def add(a, b): -# return a + b - -@gtx.program -def run_add_domain(a : gtx.Field[[CellDim, KDim], float64], - b : gtx.Field[[CellDim, KDim], float64], - result : gtx.Field[[CellDim, KDim], float64]): - add(a, b, out=result, domain={CellDim: (1, 3), KDim: (1, 4)}) -``` - -```{code-cell} ipython3 -a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) -b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) -result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) -run_add_domain(a, b, result, offset_provider={}) - -print("result array: \n {}".format(np.asarray(result))) -``` - -#### Scan operators - -Scan operators work in a similar fashion to iterations in Python. - -```{code-cell} ipython3 -x = np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0]) -def x_iteration(x): - for i, x_i in enumerate(x): - if i > 0: - x[i] = x[i-1] + x[i] - return x - -print("result array: \n {}".format(x_iteration(x))) -``` - -Visually, this is what `x_iteration` is doing: - -| ![scan_operator](scan_operator.png) | -| :---------------------------------: | -| _Iterative sum over K_ | - -+++ - -`scan_operators` allow for the same computations and only require a return statement for the operation, for loops and indexing are handled in the background. The return state of the previous iteration is provided as its first argument. - -This decorator takes 3 input arguments: -- `axis`: vertical axis over which operations have to be performed -- `forward`: True if order of operations is from bottom to top, False if from top to bottom -- `init`: initialized decorator value with type float or tuple thereof - -```{code-cell} ipython3 -@gtx.scan_operator(axis=KDim, forward=True, init=0.0) -def add_scan(state: float, k: float) -> float: - return state + k -``` - -```{code-cell} ipython3 -k_field = gtx.as_field([KDim], np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0])) -result = gtx.as_field([KDim], np.zeros(shape=(7,))) - -add_scan(k_field, out=result, offset_provider={}) # Note: `state` is not an input here - -print("result array: \n {}".format(np.asarray(result))) -``` - -Note: `scan_operators` can be called from `field_operators` and `programs`. Likewise, `field_operators` can be called from `scan_operators` - -```{code-cell} ipython3 - -``` diff --git a/docs/user/next/scan_operator.png b/docs/user/next/scan_operator.png deleted file mode 100644 index f0c1d03636b2758296da39a29251c2adc5b321d3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8760 zcmb7q2{_d4`>#YsMPmuAjJ1en7{e%HW{hEGY%^xji_F3dGiHn#Tb3}1BD=IuM9Ef) zNGq=*T2!o>O|GBPnF4uh5XStvIxj*;kzVEM0Z%;S1mFrh3 zC@8486L7xZT&|$7Y#n3;xcd91>w0in7Uk=PQD}Iu@somr%1)^(Ln@37;S0D5CN|iA zuS~4rJW-U?#0F<#Z5=EYTXFcop^{)>lvM;*3Lb&?LJ@}_!sl}SZDS3$wy{JYEDzQ z63(HqX|4iB2%L@){kxiIt|W>tiuljst*sDNpyyI&X`Gn*?^6ypn$HFkTASFo0>%DE zgP=)7Fp%RP&3Xf+SR(%I$p4QfgJ_752!FUWGME+XCBY$3{+2*fOM}?aqW#(8s2IFA zJeVMcW1=`b6Kf3GBNz_17m&SV;o(%8J&MTZxncd7uC|_GG6aGh7sHR?BhkJBES2W% zD+u9J#If;Vwpdvhk>pL2$|SBXESV?6%|{%AqghiVD25B(#Sclqhe&t=JR&M8o*C*H z7wcjt65u0cWIE85AA=YXMMTAhqCD6>@JM%G8#fQL=TY-J{;~N(0VsT$uz?C>V|F=RsidxZWaC zya?y&#tpR>hqFjRFA+s7kTA$pMsTc@iUi99ZApH?_-H#Zcm-$B$Nd+EJEVlQMx(g_S7Is)wzOsClqnJ5A(Qc6OF!?{$@n-)T& zuo!p>n#zwwgL|&KeONG(&asR36$5Z%aCXsnmv97!CZGcg#9`2^7!*0iFC^Z^)y10> zA7SkoN@nrc@h(1WrZpNH%x8-HI6SodQdeJB7C<~c4#BafdEjEgtWmzm2sSE`3ET(9 zRzmd)59hkE@O*bxxEtHr-PblOMCOOKx0Ct#auIgE0v92HC<(Q}ibRn}8ZsJ5Lq+p# zqC(>lWHdSk5evk_6TO1P{&rse-V`?4o{YAUcnIx${q1CQGSwc9CZl6mIJTW_gguu; zh#}MBmmC*f#E{vF2vj)6)0!eCQC*j0BmyK3Z+=b^4NSF9(X&v(8iy` zBQhe9NUj@|=?9k*C9F_)K1l+1W$`6qA75(@A%>476Jz4&{P0LJTFML|c#CbMQQ;9_ z#t3v6INFnAXf8N5TuPA$5w5}hL|Yl3YYXSXQ379qABq$v6btF`NSPPTi-U^tisaxW zY$7Lw%?`si zO57vy#4uT;5P%sI=N&@l3TQa2j6xI=Xm}bG;g07gsJZ2;V9*6Rw2;8I+WK3k7n?1&Z!J&C`{cS0cev#3>QoMxZ$w#9R5$I^ws8GKp zPa%oMMtRv#BLwhpU$KaY_6p(QL%cb0p7xPawDg7VH1aA2zgAt^lu)f9}=R%DQoG4Z8ClAbboxI2VWM++Dr!Y<~+ zu<1Ry<+)7%qlQe2i@Bt``CH00*Im3vNjGHfXB;iS82`vczq@&3ex_`G$E0j>?+p5( z)65?;JNH@=<7eK44oxl&hE66vh&wqc#7WaMZw?mEVoA?6Ocxed64795Uh7}iCQ!Fe zI**oplb;Mq>H9tM#@tdF^M&MYJYrrlYf*~BOm8$xP|;MqLNU(Etv!XqTuxPl?Z#xv z|4>!NG^A>(zF4_ZE>Jbml$dBkuB`Mh=2&cjUNk8rbQgIWjF@l9%dJ1<++FN#2%}?C za%=Y}7{UUR^lZbloj#Gn1{wm6dlZ@9_VqE~+&ysTu;II#e=icFKHMwM-eC85@%Q{W zLZ0E&_@#%ww-2tRsEzcWqckP_lFUuN5`9{npFA5jP~&q&@_yqQogckPf|+MQ$v%q< zzl;JCdJ4DppCNDPT7Jn63SoXnr+yt9s13Vu#_ME*)iKN6D2JFYC*~$qwM@OwdY1Ml z>pR959@Md{NeVfi7JdJ?T?}%M5ws!bW|=v?fn*bM@03|ttl@ydm$uJVJ;!WoPrK}= z>-;mU_)&!`^=49HSbwz!xzwS^n%jdLO8B8RP{*XG4R>a0721S)u3h)>;YqJcp%2g8 z$=+bG#lo)!2bOttE@~)ohQH<|M6M^OrPYt#N?9(Qygm~a^hXQJ8CK_koBeUQ!IR(r zB=GS3&(Chy1L{?;uf|Zm4d2EFC?dB#vY^$S(lqh#ZckM{7yM^BQuOAU_oljd;j7D4 z;MV7)Q)1Hrp;8CX;*(P@X}r7nCQaO49Sk9P_T7O2J9BLM>X&JmG{h+==5QizDjA2& zRXvT8GbAw+FI`W&<(!dzdc6NkWx}uTVL7vZ?>2<9>XoGWVAE>3Dop{yq`1twnznrx{k53pW^ zcb=c^7cIErx3RYN@!Up62jp{XGY?zNf2gK&Jm3GRGTa zNt=z_4tw87_|>0tVP{nFdu>SbF6u4wv4t<$Mv#XmcX^3!{4!1a+OFc!^X|_Bl$kHD zw}N$y4P|7_#>|%HJUQz{IlK|wj}m|#dHLj4`qz`mi{`%2LJWY-5(^;yCDqT3AXv^)aDtKsUlgo(SmOLxT7Y;Difn|H!> zrYx^7&vjT+Xt}#MernVO*h$%R!X;Rzsm|X0I+p%@P5dVq-8YSFm7;A--~Hdf%*_9& z{zTSQ4WDW4fbqSU9VZ$q%29G7e3rM_Zd^yqn>e%KYo$Bm^1s}ekFcFoSkn^eJvTj3 zd+N4o8$Nse%cuwGP%n3YCG7KEQ)2^%#-7#_Lm!s;Eq{3_bj|(J9rV|f;X(y0FGVNG zVA+X`R9dEN&m+*K^{e*C06i!>Hn{~@?U<^^vxb0Af3Dd@)%bQ>H;4*X*Msi^(5+J6 z=6N|t04eV%#~!E#+td{QbBtP-IQ=4YY!TL?q>cH!xTkVno>x5DR8g)Yf7)(X|LdZ< zQEbxzbW2&?DYeavZGW`pe%;sMj!-O@x$F#G0Cu{?t8C{Ov!XdqW^KSc%&u9xSy{H! z8*zSPxnh|U;omkv_W;KaX#e39e1Cf2$PTQ0M zFUDm4jKa02FBT`n)&r!eKd|HUE3LV= z!?r{5ep@znR1;gB^$uOs+ocOp#+YIjKDe)jm4y#mKKg}gy1n05Q#BytLyfy8b{3)` zG&LL~m5yOHru6OSH5`M;FK)c5*u^<*oCaT|!4w@Db{d+&$?dRz&%7-itvHjUscO5Q zMJHNAzM;DQMBv(YQ2q{sm_frbIF%mxF0+Oj%=5=|NmssQJ$sydA=5TJqai%wf4$~C z!>OAF-e3ylHwrzKI`(ZOkuZPx1T_W@PK{nHiw<|~IFd2CN|*jP@NmG?re~(#pIbFZ zHe4@3$-h58d~UDi_AU~)Q0vHpl`zpGagh#W`9(9*l(_y^_Q&r@warHedgs>_5P*6D z>s+rNcHRPADZgQ10G)GE)&zu+(ifwwgc&;PY#7|XHDme^`JLK`6YK~JFvF=!ou9^Q z(c{iZlh4cdwD4q^^1;jkLbo5>APd=a)vB7K1689UWg`KxNoTlqr1#N4f7m) zZ^3>PUtjq3@#MLD*o)-!djT&l?S3)N%6Z)_S*@=%a{ASks8DmSRmS&mI%ak&ET+p> z$r}O7s?mGDt=YBUoU?1O)kDiJdi*){$n03{d8+v+;XU9IzfXTj4{1R1s2X?Iz8jMx zFLvK^h#0gl56v_yQ`BoYBl+at*=Bk9S8R3 z?Q~3>n|hgsQiP2wzYR_upP;#vg_SfI{NA=r#Xj9(+5SFVeXV=?`KN=f#pcyf`|=C~ z&l)x*{{Hz)`t_}g@OgzFa;a2H=i?{WJtbTQ2$Cp;D$=y!SimCVI^ru=&1Q zbp3)RgKIw??7)`AOy7mG(A<}^4$A& z`UN~_89EC8qRVlr58G3_f*JQt1bL{%`R!))*(#uo=XC`EnVERy20S%Au6C97mQ(7c zYU`1Gn+g7FDF(>Lbj|3BOXNtIgQtZi{;@{-AwvF|BWA{4&n)o`Ip1AUY^oy=iZFpuu~k6L;;qzM zQXnHx$fFc7Al@E^Ml$>7rIo97BUkJz>h8`nYDVnAX0UfwY>ptgy;OOtn5MEOdL$eM zaO^yoNRk?K3fJ}e>dj~qi_EL)_(93PWzFk{KR(a@q=S;To?G`hv%89Tq(AbFP?cw# zbv4I*cA>J_Y~jn*$$mET`5mURigt-$_(o$vM|$7g+k;Qlw#Xlic5biLUNs|s zH+A^;se;`$PgS%buaj&R!p)palX%;~^lh^xv&#$`wnEbjfi0x;ZC9=cq+7!~!##{O zSIWP!EXQQ|dYEsBJW~CQKH)-Zzh$7J*_*pg0r2l_CQwsAgNZd{0aa1xf>RcC*Kh+ z>v(f^jbti4d0oc!Tg3w`n!@&c$V#S zUuP3SeTL7tV-OB#t(fNtY=p$n*RsGI-t6eFtY5o&Ju+i%X3`7b(`Ps0bw@f>6&_sp zq>gD*WbnxRtCfornYFfavG@5v$7U0( zQ#gY(3T$#m^!TZK6QcPlgfY2A*M2!HbGXQY?rZY=>82BY+iyWl>vA$}?GEW~Hx+b2 zH|3c>KbNR97mRI{y!Vi8iJPBj;osV~?A&Zk;Ptk-kUip!+Y11b=cGpqRWvoqisfey z9-96g-DK9}@Tf6&U$b6nj%4l zA)i7tdbWcw2i@NLz`S+Sf#g4zO%-@KJt`AN8+R6kJS?kE81E^zBwpV;Lo+jJb!bnJ9tp!kCA!3m_VH*USTzPIRw!%otkGwhkRC#A`!p<{vle8&`c5o?$|tZkt6Q@JDe z;4qt|xWz6JB8Lrhc6F*bo4zzg$n&ocb+UJ8<=!4M%7hgz$Q6UPEja#t1>i4Uomshx zFue%(8_^*zpbN);D+Vu|TXnU9GpP$nJkaN)@ItGntS`m@U=to}&^hlN)T);~QWp(7 zR%_~>aI&vx^JV;9Wf(o%zUf(9wxjx=glZxWTjG8CFkJ5&!hLGpK_@36yx_;rUL)^JJ4lX?gNNlxg&NfAohkP96-OpRKT*7zL~ALg#U)|E7YA6NW->v zAX)Ud+S7CUMooyy_f8z8diS+s&fTTXAemHNJ`C(V^JLOoCJ1M7SEDCxv*_(l`+K$ym6jaEOD-6=iduIkTGBWNNsk0b1e>g5SC)7dFJN# z9*b>jm#-|JEwL|OF@Lkd_0Z#yy9GrzUxU!J|7Q5J2J%sx(CEj0oo&nZL;W*yGSwZU z4{gO^Di1ZU%s1_bUYQ)_u@9!LueC|MZ64k)d{m)){o9=l<_8epb=kW{kp+7z!$BUj z7uiHKax#s!ot_-h?|Gg)(yh@`V3sq7h z!?J*NUtx`?vr0?c@km)^G2}wwexcF=QDNwLBfF?dZ=14%boA7#q7#QH76wkHp4~?s zWaiCb1Ll4O8b)~Uhh>gi1KOmXZEfcwoR?;3Ennee%E*u0(thz?@FGw&#t zJM&VFX!<&G0GuhhBJ<)^-ic1T+JNfFCsCY-Z|Xj%d$uHwO6y+T>WX(@&D-w`vZ^n& ze|_W;y0sc?E!3jDCJYNG^4`&c_x3gW6k!FH6Cc2+)oO?6z)Dj3EG{?AuLX1`WL&FR zvgE+Vs@LZM+`|4G_H$h_^MjTXNlSFMLFc^rN`M1Ecjr&7nxy|%x;sAbzti2eNQe0_ z^&MZ0H9JrhUMHQVzV~i_67|sWqnKbv*M+>$Y&$F`Bpl?>`9;N+yL~|kBVf&AdwF+{ zf5RW>msg_9vNxjhL2mRuUQENL`A^n$yySZjC zXj_>sa_WAU@zd(Ox$}Y7<9qPx7J&~=pxL(WKR(_u7`z&D{=1dSiod_SzG`v6EG5_0 zjyGPLB`v%6v%S;m*n8%3qx-;JuX|FU_O5U%rQgN4o;3`~sTPh?hWt1E`oPw-(#n6E zv)S=j@irz#Hr9_gW4Hk`zWg0ObUS1qBdgWzc|*XC_p^m2tMYxWSS-8YH42tJD_FaY zk^7b#JNq@5YbL0<0Lwx`XOHYrElyM2do1B6qA};qMZNGRXVij!+Aw$S7x!2D^mNI0 zE3Q{f1VQ`54Wv8}>r|?nfJphHC%H2(L`%ojI-72lC}9$Gm(s(q$rux?uMQq#P)n^_8cv44Mn&&E=am%*CW zI8VzjnoVsYr1YtmPd?Gut8jNVL^imwl-2uLr1lN*jh_5q;C`dc3Btl{zapeMcIHBfX~CXk0%%g+7KD@TV2@_Tdg<}(Y1uL|30k57ts<(J5O8H zu~TyEoAT=Z827<_a(-hrWlhk^-`c;$e?0kP??JT zkYwzuZx?YwB2(PA0VQ~wg{bB`*HwO^C50VM-v(J?iQXgI%_`GZWS#XoNs9>RezWG` z`pT@ow2h!0R>MmLzNyhieT;^{>op6OUnj1J-zi!MRFImOUP~O+)Rn59)d1_<*|Ftp z((kz*WS4pP)4Fm{5b_eCi7M=!m93PN>^Itac>lD%! zv#}Yu*%>@F$>U1*ZS#=u+^rf08BnLAEQHSDi)HUxwkESav6MPMj#cw}s7>~{A7v5q zWe~}}v%1-9M4JumkgvIt)fi{#eQKXWNy&v17hOuSZjM&k#JtndoQ90{e*SG2G3foZ z9i#3UxG$fQ<(#qGW~~K{I&^rcwAm8zdJCwaortUo+wj)0+;Nv|_IvnmQYbTeld1;w zEm!54)nNLir4nCZ*SAJ?*XZLaqTmp8TJ&n5e)o}o^2N9IqBg*cxHewy3nNaQUx%&y z=~lgT_vil1_bo$7rwv^XR{O2#MV`H21iR-4{VD!a^Z3)RsjJ>uSSM>R%-r-C;oqqzfwYcEiPy?mVpJ?ZGbD z(*s%%+nW$oD&W?IlGT|YOuXIIxke$9oC2@BHd_LbCvAoPY`Za7hpJBnMzOuDz6&KSsG@pDV9n} zASq#T`XtZl$5pvrqlY)9)hmH00+CU0wE<*~u6EPM2fj=_6BTvNsQz@#_=8ebB3|1s z+z#!z#QuLw6cx)u#h;?L0g7#_i=}x1U*VWWY+LaLU1n`1-dffgt9eygm)aH1z)Uv23bz1w{F!UwNnHtOyjgQ>8j>>IAj zEmRq9_m0^fI)>zT(AI2%$OY@NT}B5MF2Z&uE!YD_$$NS}$P1)Lw1cX0ftVD1D;2)^ zM{hdRUF&oIxzo5Ed#CF>HUJmQ$#|{1sw4K^_6?Y;IKtLLN@I$!4I2wKT`(2=;N`IbopEbHB{- z_iEZ5b1qnQX&YR`ft2l@Z|G3NjJHAj?9=w2*&6SJ=+v9gt-K+{qUSky-wa@LH< Tz9PZDf+)DVdg2-|ds6=g`@|Ku diff --git a/docs/user/next/simple_offset.png b/docs/user/next/simple_offset.png deleted file mode 100644 index 660abe87642151d390abb723a0c40e5ee9c22e00..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10292 zcma)i2UJtr@_xXpAS#Lq8mSQh=_G*!1W^(QC834TODLg*1OlN25Ks^sXcQ5Gf(79! z2q=051*E7Lny5&XB2rYONEd1U9rWG%-dn%5-hX8!=bU|Zojvo-H?wob%EEY?&^{pu z1hUQ4#E=Mq@Ik>BCb$Wd;C0@2fD0d!XpDgry_Wb2fo!q}HF5~0NBa2FydiQ(?5{gH zgoa-bGgJ;~D2G6JFc@kSe-Gaf4?0sV&^r_qf%o(viocJ)H|19yga!hs3Wux0Q6xBA z4r!o;1V5TcHIxRz=~sOZU+=)*9hwJ4`O|0~atNH3ng;0Vpt*;ae_&`3(@zeG0nes^ zq24rb1;yak(gyrEfG=Ew4A&x~6~V0mgF*AQ_x8m5gW2G;kZ3g|3KSp2n~=;2atJKA zr}+nXgNw1ZS3nT23Dz$phz?3{aJY_|2JZt3Z9IHDLj3=&3*O4SLp^+dt%~XtVc{4W zhNpT)nt2BV2HH4k(SNNb+&hHn9~Ag|@d!1z8feLD9Li#N|0<<;hx>bh0TFUYBOut{ z5(HHOJ-|HHC2IwQq6+`j(Eqn2qtNE|XnHV%OtbgX)<$djN16rw>ICf+LiA)XBXFi3 zK7PzlytzHr#=+2u5ekPhHA0ACKmr&|x?v6Q@Z$|XD zh-4v*0wXoy)<`qFHPwsZNHRnjA_;aDIAfp~7-J188eFIZuw^(V&CZDBi43=oG|)tl z;HGwF7)!J-9HU{ZK_uWEiNRDIPqdl0E!@gHjAq5#OA8}3l0vnmc?2`Oy}YcnoiuI2 zHMD$?NK|Mz(<4gD!N@0ABizu;$;6lHW9SD)(zFfIw)Z8HJk5hG;Ix1MYl4F#&K7h< z471k3qbMi}4vxlQLp*6=rVa>AqhMPWJ{ZsNWQ1$mVS`aVpeDiC){$Zu#WD|u(@18- zD5Oy!2IJ&Qcf|XVtVm$`*hs9Ct%+}AUTfW%qZU?Mzi9W)FGgh-4B9O;N7 zSQD(7RzB96p}vu!7(WA`ShxVIC?5w_qyfeQWr)zgn;4lBtWB&gql3m6Z*PA`dLYv(&<JhW5e&4WFNxsG3PAW9n1in9aGa+nHJBbuhof!1 zeXTHbB&d&#Y? zorfulrD2N1vm9`wC3@Hjr1Q$z`k2MMhl$dBm4faHbVC}W(Xo9hmPbk8~mVxs! zv%v;nsZI_~lyHsUU>&$9N@F!OW-P~yQ1q}rL{fX(6H;Y=gCZ2*Z5 zo{gyXREnb(uZ<}yl0*t65qSm~7@>s;^e1T9S~_a`*hJWcFif;8!w`{nRy1qP0Lu^y z4LFmxzi2HjLyM?TxQ;fF8Xgsfqz3CyyiIioek@B?co@Th7G-5`8MuWtDVa@^Y@*oXid z%}kHo8n$<_*0KioPD=+omNk$Q4DeLUkf0Zac@8GW6Zy>M-&VsdqI>)_MqzT&kt zovInVNu7m0804qNn%0Ze-OaDrtC@fKFbGXq6)P(%;e!k=TZk;k8JjYRdl>NQiqviy zir#6ziY=JBGJ^+}7uaWdJ$&+~bZIds6O)4z^avL?$pRu>M&TPlo9>hEF0nViyle8= zlxs+(HJ1o-ejG+sVlZ`&4!g-4b0aTa%%a%Fv+MI+w8sTGjEb`%d5MI}kEI?R#K^~9 zFfV4a;q(jX^&$v23Pr(DREAQbw8ughBQH^W>zX2LS8kz7`lnK@Yoa?51CeQAc}?K` zt8lm7--S4my5a796|9WI(I|OB#z)^1F$uTrjDy%WtDKZK`9>GB9Uu9L96EI9O;^{M z`T215v(v7fRA{9^(rSL?-pi7AU=NQig{wH0{PC^5@cAxFJ!_GRY`7?1iTj3Z~I zpEbL40|lX$3R}wS^WyEfdb4k^89Ge)r%#`1gpIzZA@c*$T6a|mYX=lyv&aI{>NBYK zXVBa~hKgZ}I-#z@gyU-3qmk~i`P-^yJjQBNn{RG@O$|{~cCFAM3 zM_9j5YMNQv&|5G+aC8mX)h`Q6O)nXkTUn=7i ze zL_v(-hhE*|CkUuk^%)E4Y(+7Ice0PZ;K}bmndX*4{0DY686~n62OljU^M>j_Wz17MOu^UFdN+IMQAE}D#zj8h0-{n&WoI`upo1;enU{ucNFfq6S z`AM~@38F3E6?_WzI2%TJBy{J))$JIMCkL8)-k;dmew-Coc?GWw$xr)k)wpXrc6(FC{x5#of6NV7Y+>%zg1c3dMvz5SP-f^OsromacO`e zSc?ToP4nX)9#HCc-g7P$w$z(4G}?rz+d_5@+h18yfPnt|(2(?QEwGRL4?bL;7#}wx zlcixcgdS0T=uhWl&RJ^`Td@|@p}ur{8mj2ioN2^8xL@AL#((jK5C*$y)^Qb6dr+7s zA!f;q ztjsy79sUo5Af<(vdSP&?Ii)fxU1a}}JTaB~!c9xxL=mhEEccG>oLJ&0NAru!Nw7;^ z8=AndL%}1vTYU$HiBgSbd|*B^8F|C{o1uleiYKgeNq*zG_Te)I2YXUCH($Rs63|EX zK^UDpu@)={UCT`Bl!fJlrErq7%JEL>;(u8Vs5i<@)DMZASk{AX&(ExSx4BWye`Lg< z>oFK}+vV9gk!G%w%BRPx>R$Do%vE{+slw)C#q>J?$f^E3;x5ppF zXh~1`c#=_yOYozu8-q`vLB^-xOCvy0^`6onDbE;~aPRDY1&kxPp`R%J&_&5M z7MZXM;iOy!?<5)gvK3A;>ar(2h}Y9Vk`!!LQG~v)1=p{QBHv0fxPtHgcsA!SENAW_ z`>hhtuY)S-w>ejx?pGgLmu>ZarRSmRvaMi?|!KGlxWb#Ydg_+T|0=P%*A8Edc z$!7W;4x))k%fr37bvNC41_r$MeZRN+Rop7}qI(z{!YBJuf{uN2W>iJQJ66S3j;g>Y zr0KgbFD@A)FTmYBP`nS}d;J*dX?gk7>e8e%+r9tUu^&SUM;z;^;S)V)F3KAF=-UW# zLK@|;Jagwm82td|-t&YRJYBCa7(ia?=Gz=Qb&l*l`{ixgT0}MUQ0PHes^jqxbnMsS zxIw*zJ4)$izuRx)MAf*=BYa;==uLBF)7D}?>-JmwHxTk%0C%wcA@&QS9ejS=k2#aBfkA)USD$YPF={3o~W7Y zAYTxX#ocMyimAJrsKnhnozAZRy|$tH)Tu$~wX*4;PcP2&z21t+%f0+=+F&D;2SFc* zWn5v6$zmZy^uk0TMRlU4JfP^wh*#eeqHoQKjTPmWElK#|F-`(|b;Rr1XSe31f)mNn z;syl@Q4!-K$BbqCEFlJVFBIFED{t5eCUYAOH8wLysBj1xDOR1fE{kaJMQKlkl(_h$= z?QTx`8Nfrx%en^n=~F+iLy^XXAY`KdsQSWP zw<_1e)qnM;2J2t)`Rz~T-yX>GYFYwN136pH@?kv*mxQe56;17fFd-EeZ?y(N%J`~n zAgg<;;~?5ws>%F6f*@<8@o%vZblcq0c8Jub@3gsv8Ay-wN>_7`^;y>K!Ouj_$(>kYjZ6;*KT$RuQWWBieOwRtH1-oCSqAujR*B_a#s z%%VgNpF2$pzZd|yz|ZI>mP%B=Y#V!(W*TRun&t-i4v8OVBjkx8&>J$wo{33RLxEd^ zFucgya2TBO&HaLFlGTcB&0Th_CiBiNX3wD2@2QT%W(5ib-xs4H%K5o1WuvQ2Sr9w^ zz5NhGeTJ|}k!3$8`R^_7vvYejA2G99cdfom{d2f@d0-XgM*0Ww;H86cTR#3AJs!Ay z#^IT6Nv3J7&3#bZFD~tHn{LCbI|Rsex&bTE|=W8ngwp|K-|n)HfrwRZATqv zxutDej~rxmW3>e&9m04M-17R7t?Ymi#MllU(>sHlXn&Z%j;^coHWXW+2y$*&czqZI zj52Qb@Fnq|5qxEuOLb3KWx7T)1#g1mLb}d2Mp+c-Oh`fEY>!R6tW2vuygJ#SE^u#% zPq)5q@-D<}5cr6T>{{{_d{64zPll3-8_cUKi@Ar_rbXga&HJ25ORMtsn6}Bmt?#SK z5BH8uygNFsxqZ-dBRVd3*A@J`Ci8r}$}0{qfRMX)j+%i;QU5;yf?{Ud|t!=rUdANhr=#-D8S>UmIP-klDV z0QBvHcsG~V2A>^lUiv1N+VwO{St@QhxMF#_QG4wEbH)hA_tRPBqH69%*|tra_@Q%$ zkfAbwR|1IjVfluv>4w(@Q|*XT4?YiN}voP1lb|lh=NB z*tUxxlonkEn+*J#{$Yp(Cz`SkM}LeDsokvb@v==%j{Fk2v@c3oOvNF-t-yZp#Fide zxh`@dA$&Yv6)A8N&O;oYR~MyFJ9>J1Q-G)0bY!WkYr6g+%jFW z?RwuKMdcHRTteR(30z+Ol5>GE_Tbg0gzfs(MXu$HFE=L&k94^u5?I5R;Bk&>EVcBr z)Ak!VLu&8Z?8zy)_;e}u$7Xc@3qXGuJ;ASPa=SA3z;djM-=SvGUT0Ljo3>JQaF@(j zu-^!cILBQ6G1Xb(y(RG3qEF1~%=K4!=2!32x)3}{w9p$N9r*O_s`BY~jwamOtpzo2 zISFqv0RQ_;$U7Kk%MK2RU771pj<|#h4DHvSzS>cHa*J|Hy+p|ArLvwebxdAGNPPzI zwt84ly^S5bD=b=LvSP+!xc{QDl%rzWqaX#7*#neJ3@*yMlj9NPX4vSC4p_P0p9p843$G&BSae7jxMoa5o1Fuy!Rd81P-xjzTYAeC{UqS2F z=l<(%`u~nOrp3Z-HvS?jJY1k+)gH}{ta!`sax%%iUi>@a$hAK>CdMQFwZaIIo1XdZpW8oyGvK@J+St^3F|YKsszNq{w~5-{pGsVy@`*<|#TDd6&tU*Vou@ z^Xtn2c_-P%4-*HiWGI5rLjKWtFi{6JQzZH?_v2x}5W}9l@9)3x^4y;CmdND{^ezR9 z7hnIpUNoJ|`J0J@sF_PyC%g-N(c1x=r9B^3cYeGT>CO(OQ#$tM{+1E(8y!+21NJ z7k?!kS7AIpAMx@Ri=Kpymt#eLjr}38d|O2SF?4p!v?&irJ4=J4UuzF;C2$ZMv8RM> zx0rD+JF7vb(}#(|!$gtQd>FIKJN@RvhcjM{5jR*rzy1l2n)MR}Ih(YY(G)UZzKvL( z3~n_1>VJ!R#g3~edhPbu?;BzO`+w11==^{&fGN|wS6NvJi+`5^4AQz_-~kBFhty%I zSb3YoDk@jgY2h*b1EUH4inKb^pkUr*_|G(WvA-x~exmo>cKsiaQ_`{8xRiZiZ(=O< z7bz8U?arN*(P&)q-iwPN1n%Y1DSRCA6#e2hq;I~_A^_Ha;>x5}3*Bw4e_Fnp9&X$& zxHMW&>c}!ot#X!_*7@I11}z=Q&X$T9RwmbEY6S>i`WW56r5505;DhB{fin-e>Xz^U zpqG%bd9esUc+}>msQ_<8qFZ+R{8gU220E%<%!>rvMo0RKXB#HuQtJbB6lGF7c4H<1 znMfp*!EoHfZ&@uI5O!*@a=o7L)Ocs3oT|mXQU%KUryxx*l{!U|0*R;LOAY-p`#>HD zgA+CWUj}nYyed0!YlS!l7a+6nUulx0i1f$^|LTXQ|4A(kWG8DfZMKLx2#l-@XA|5- z(mbB>(gF1!M}so|PG)`W6$6oVBH+8ZnZ(+UBP)eR^ZxL7C;WA!zN1FSq*36fc*GE> z%^{TW9ZVwR6z_OkLJJ%#$P$!I_1KZ0qL#B^orO87-K_=DeiJbHO>a9vA`T|o+;c+X z<=_EcE+R7YgP;9+F2d;izj6^iPe}R8KAS*l!;ax~|96@PSh#E&@aNj^4#@1?~5w9ob3e*^7yVJ@xznI!&-N=-rWv$F%wPYoikFK#V*I;xZ` zp+g)8u(jGNC(ARwuqKeO+lOE+74)q$q}PYOsm#vJNe(r4KJa=FyAmZ4`(r?R>p(8Mw&QZ#>ZcR0DkKG^X-Fu(UUJh`j6zh z8J2U&avO+ocOrC-<_Q9K2t@}FJ2I$YE44L|)>^cD2o4 z7W(6)f$i1d)hk2Uf#lWc^asGfZevJ7ZDcG6G<3flbkPDM&$OTl-z-O|raxdJ|(EDp|Z?E?1 zk3r=lYhR3y)LvG!(Kz-qRE7NP8rxd&M&PwYGv6k1;Br(F!Fn}cW~O$F`p0X`*9Oti zx`%Xix0`(bMqb;eaE-wq6%}>gw7yHEOog1 z9Csab#<`o8y`Ky{aZzd@2eGFxmp5hYKv2DH?GMd;Q9WcbP^`&pjBaTE$@8gRs+8rH zSHSQVdIUKVNa&hJX_aw=2#hjKM6bo>4r5h7#u^6P)){-{z#uHUt7Q;Sl!G zH?=-kcfs3?|1mj^r`NB9=z0=^=l_#Na>AeS{D010ow3pDxS`2>IC;7snhAKuR6yCw zz9!)QbA*1=5IkK4#M8V`>)qTHs@diQ0C{Xhn3CQ$#oy)GlygPvl?Uo1{>lFh_2%XY znKcT45t;yRXiIYO8*g4}B*;Hvn;jwl9v5G1)v5X`t+3BtpOp8q(e>*k`|*~-Zljfj z$B!Sc`ew88O6tH}!;yuMs)y4{x)WvCPwKE%Z6Y{_*vZ5m+AQf3D8NvN?=+BAlJ5TS zp(8vzd_v(|b4M07Y3s#EIj7=gxsgK!G6;@~sKy)C&gx4?Q5vpOyKmcvw;E{EyC3<2 zW2)DCiZcb;mEF7hhg>@!e7HhE2WG*&1FerHi7$<|USM0}zZ7qs%8$ydu`z zANAxiKjPNh&n@ZhQTr*$PQ&pJA0I-e=Xcb;`f~Z9c|lcG6)Lr{-Sm7R@iK_?R9|&v zj%Bjp{*5(V$r}j#oS(E;KW3dRTMkir1=-*8rJ$#QzOhUL{}rR+~(GF&F?WX z&*pi`W4XI?9(Gt`o`}m;mRuFC1u?6As_bafqpF1AiA$*Ld-~D4Hd2VRW$LlyzYfE8 zop~sBU0y70dwXN&C3o@)Hda4d`#e8F5fwWVrrgs9^5T^< zhi3lsNDTSh@1Irsyn)ZkClbC0M`g#2pCK-JU>yxwL!#Qzu{Jz`7Us%dlsw8c^tn?! ztljCd{p0COVJ2V7KXdO`+;i+mBgDl?U~lP{?R+uFmkK}6GzfefCJJ#+w~HT}jfpwl zMA>xn>ju>>a0s?y?D2G~JgblZjuX;)LwDpIOy6tPMA^H6e`AcEbn4b$!W9~lynHI- z+3D|M;hk_rCFy?&r1OeT=3!W>DGD{67eANM4|YeBIg{rqBPoJ7zE^Zyvk)u)vQ8cL z{4mF_#in4(uD>^?Bi~Wy^*IjV{3qf0B5-WcYLIcuW#3pV+PTUfkRI{yP8=U+w|JFD zd0PM2X3p{>$=So_{$7TzV9p6PG%@e0Q^~(ZMdrrwfpm<}4i0wfjrhOiaprSWjnHY( z(5O7P4*cWnH1EFvzhvQjUqCF$OWv#Iq{(NH23S`v;%!uZx_39Z&y~33Ock&`3Z@@) z*Hc{T8XHct_+s#wr7KDEjmG+ZS(tc1)n1dYlIEUpQcZS9^WXO_Xiwn^ zHH)xexYy65Z^QjKrK1}&mEH*M@JB*trgtJnR6qS8YkOn0y!2KA{*yRtm;WD3#}QGJ zk_^8A3^(I~frW)$=v2@aV(mR@z)@XdfbULXHiwtS>n`Af8@ Ubnz(o8z98g$ilD)<9_1*0a&l~Y5)KL From da1da20b0d6bde48e1a15b6ca6dee9e7d065f337 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 16 Nov 2023 15:10:27 +0100 Subject: [PATCH 34/67] feat[next]: DaCe support for can_deref (#1356) This small PR adds support for can_deref operator in DaCe backend. It also improves the code for preprocess ITIR transformations. --- pyproject.toml | 1 - .../runners/dace_iterator/__init__.py | 32 +++++++++++++++---- .../runners/dace_iterator/itir_to_tasklet.py | 30 +++++++++++++++-- .../iterator_tests/test_builtins.py | 1 - 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e2d2a7dfe9..7690ae583e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -330,7 +330,6 @@ markers = [ 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', 'uses_applied_shifts: tests that require backend support for applied-shifts', - 'uses_can_deref: tests that require backend support for can_deref', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_if_stmts: tests that require backend support for if-statements', diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 9f67cb26da..e3fba87571 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -65,14 +65,29 @@ def convert_arg(arg: Any): return arg -def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[str, Any]): - program = apply_common_transforms( +def preprocess_program( + program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode +): + node = apply_common_transforms( program, - offset_provider=offset_provider, - lift_mode=LiftMode.FORCE_INLINE, common_subexpression_elimination=False, + lift_mode=lift_mode, + offset_provider=offset_provider, + unroll_reduce=False, ) - return program + # If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG. + # In this case, just retry with unrolled reductions. + if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): + fencil_definition = node + else: + fencil_definition = apply_common_transforms( + program, + common_subexpression_elimination=False, + lift_mode=lift_mode, + offset_provider=offset_provider, + unroll_reduce=True, + ) + return fencil_definition def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: @@ -156,11 +171,14 @@ def get_cache_id( def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: # build parameters auto_optimize = kwargs.get("auto_optimize", False) + build_cache = kwargs.get("build_cache", None) build_type = kwargs.get("build_type", "RelWithDebInfo") run_on_gpu = kwargs.get("run_on_gpu", False) - build_cache = kwargs.get("build_cache", None) # ITIR parameters column_axis = kwargs.get("column_axis", None) + lift_mode = ( + LiftMode.FORCE_INLINE + ) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] @@ -173,7 +191,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: sdfg = sdfg_program.sdfg else: # visit ITIR and generate SDFG - program = preprocess_program(program, offset_provider) + program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 5d47cad909..5b240ea2b7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -265,6 +265,29 @@ def builtin_neighbors( return [ValueExpr(result_access, iterator.dtype)] +def builtin_can_deref( + transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] +) -> list[ValueExpr]: + # first visit shift, to get set of indices for deref + can_deref_callable = node_args[0] + assert isinstance(can_deref_callable, itir.FunCall) + shift_callable = can_deref_callable.fun + assert isinstance(shift_callable, itir.FunCall) + assert isinstance(shift_callable.fun, itir.SymRef) + assert shift_callable.fun.id == "shift" + iterator = transformer._visit_shift(can_deref_callable) + + # create tasklet to check that field indices are non-negative (-1 is invalid) + args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions] + internals = [f"{arg.value.data}_v" for arg in args] + expr_code = " && ".join([f"{v} >= 0" for v in internals]) + + # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution + return transformer.add_expr_tasklet( + list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref" + ) + + def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -318,11 +341,12 @@ def builtin_undefined(*args: Any) -> Any: _GENERAL_BUILTIN_MAPPING: dict[ str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]] ] = { - "make_tuple": builtin_make_tuple, - "tuple_get": builtin_tuple_get, - "if_": builtin_if, + "can_deref": builtin_can_deref, "cast_": builtin_cast, + "if_": builtin_if, + "make_tuple": builtin_make_tuple, "neighbors": builtin_neighbors, + "tuple_get": builtin_tuple_get, } diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index d5d57c9024..2bcd0f8367 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -250,7 +250,6 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) -@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor From 67a618856331a97530560b1607a7d11c3c3ef802 Mon Sep 17 00:00:00 2001 From: ninaburg <83002751+ninaburg@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:49:20 +0100 Subject: [PATCH 35/67] feat[next]: Extend astype to work with tuples (#1352) * Extend astype() for tuples * Adapt existing test for arg types of astype() * Adress requested style change * Add extra type check * Use apply_to_primitive_constituents function on (nested) tuples * Adress 'nitpicking' change * Remove previous test and add integration test for casting (nested) tuples * Adapt visit_astype method with recursive func for nested tuples * Fix integration test * Call 'with_altered_scalar_kind' only once * Recursive 'process_elements' func to apply a func on the elts of a tuple * Fix execution tests * Adapt visit_astype for foast.Call and foast.Name * Fix tests * Rename args and refactor 'process_elements' * Fix tests --------- Co-authored-by: Nina Burgdorfer --- src/gt4py/next/ffront/fbuiltins.py | 6 +- .../ffront/foast_passes/type_deduction.py | 11 ++- src/gt4py/next/ffront/foast_to_itir.py | 35 ++++++++-- .../ffront_tests/test_execution.py | 70 +++++++++++++++++++ .../ffront_tests/test_type_deduction.py | 4 +- 5 files changed, 114 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 13c21eb516..7b96de8e89 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -196,7 +196,11 @@ def where( @builtin_function -def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: +def astype( + field: Field | gt4py_defs.ScalarT | Tuple[Field, ...], + type_: type, + /, +) -> Field | Tuple[Field, ...]: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 605b83a5f0..95c9128f87 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: + return_type: ts.TupleType | ts.ScalarType | ts.FieldType value, new_type = node.args assert isinstance( - value.type, (ts.FieldType, ts.ScalarType) + value.type, (ts.FieldType, ts.ScalarType, ts.TupleType) ) # already checked using generic mechanism + if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [ kind.name for kind in ts.ScalarKind ]: @@ -835,8 +837,11 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", ) - return_type = with_altered_scalar_kind( - value.type, getattr(ts.ScalarKind, new_type.id.upper()) + return_type = type_info.apply_to_primitive_constituents( + value.type, + lambda primitive_type: with_altered_scalar_kind( + primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) + ), ) return foast.Call( diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 1902d71b3c..816b8581f1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -317,12 +317,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, dtype = node.args[0], node.args[1].id - - # TODO check that we test astype that results in a itir.map_ operation - return self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), - obj, + obj, new_type = node.args[0], node.args[1].id + return self._process_elements( + lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -403,6 +400,32 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + def _process_elements( + self, + process_func: Callable[[itir.Expr], itir.Expr], + obj: foast.Expr, + current_el_type: ts.TypeSpec, + current_el_expr: itir.Expr = im.ref("expr"), + ): + """Recursively applies a processing function to all primitive constituents of a tuple.""" + if isinstance(current_el_type, ts.TupleType): + # TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element. + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[ + self._process_elements( + process_func, + obj, + current_el_type.types[i], + im.tuple_get(i, current_el_expr), + ) + for i in range(len(current_el_type.types)) + ] + ) + elif type_info.contains_local_field(current_el_type): + raise NotImplementedError("Processing fields with local dimension is not implemented.") + else: + return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) + class FieldOperatorLoweringError(Exception): ... diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index d381a2242a..58181fd7a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -325,6 +325,76 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +@pytest.mark.uses_tuple_returns +def test_astype_on_tuples(cartesian_case): # noqa: F811 # fixtures + @gtx.field_operator + def field_op_returning_a_tuple( + a: cases.IFloatField, b: cases.IFloatField + ) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]: + tup = (a, b) + return tup + + @gtx.field_operator + def cast_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype(field_op_returning_a_tuple(a, b), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1] == b_casted_to_int_outside_of_gt4py, + ) + + @gtx.field_operator + def cast_nested_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype((a, field_op_returning_a_tuple(a, b)), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1][0] == a_casted_to_int_outside_of_gt4py, + result[1][1] == b_casted_to_int_outside_of_gt4py, + ) + + a = cases.allocate(cartesian_case, cast_tuple, "a")() + b = cases.allocate(cartesian_case, cast_tuple, "b")() + a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) + b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) + out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() + out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() + + cases.verify( + cartesian_case, + cast_tuple, + a, + b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, + out=out_tuple, + ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), + ) + + cases.verify( + cartesian_case, + cast_nested_tuple, + a, + b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, + out=out_nested_tuple, + ref=( + np.full_like(a, True, dtype=bool), + np.full_like(a, True, dtype=bool), + np.full_like(b, True, dtype=bool), + ), + ) + + def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 7800a30e41..dfa710e038 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]): def test_astype_wrong_value_type(): def simple_astype(a: Field[[TDim], float64]): - # we just use a tuple here but anything that is not a field or scalar works - return astype((1, 2), bool) + # we just use broadcast here but anything that is not a field, scalar or tuple thereof works + return astype(broadcast, bool) with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(simple_astype) From 39d1c0958c06da22973417660035ec8e023b6956 Mon Sep 17 00:00:00 2001 From: ninaburg <83002751+ninaburg@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:17:43 +0100 Subject: [PATCH 36/67] fix[next]: Names of variable in tests (#1362) --- .../ffront_tests/test_execution.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 58181fd7a8..8787b7d7bc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -338,33 +338,33 @@ def field_op_returning_a_tuple( def cast_tuple( a: cases.IFloatField, b: cases.IFloatField, - a_casted_to_int_outside_of_gt4py: cases.IField, - b_casted_to_int_outside_of_gt4py: cases.IField, + a_asint: cases.IField, + b_asint: cases.IField, ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: result = astype(field_op_returning_a_tuple(a, b), int32) return ( - result[0] == a_casted_to_int_outside_of_gt4py, - result[1] == b_casted_to_int_outside_of_gt4py, + result[0] == a_asint, + result[1] == b_asint, ) @gtx.field_operator def cast_nested_tuple( a: cases.IFloatField, b: cases.IFloatField, - a_casted_to_int_outside_of_gt4py: cases.IField, - b_casted_to_int_outside_of_gt4py: cases.IField, + a_asint: cases.IField, + b_asint: cases.IField, ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: result = astype((a, field_op_returning_a_tuple(a, b)), int32) return ( - result[0] == a_casted_to_int_outside_of_gt4py, - result[1][0] == a_casted_to_int_outside_of_gt4py, - result[1][1] == b_casted_to_int_outside_of_gt4py, + result[0] == a_asint, + result[1][0] == a_asint, + result[1][1] == b_asint, ) a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) - b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) + a_asint = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) + b_asint = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -373,8 +373,8 @@ def cast_nested_tuple( cast_tuple, a, b, - a_casted_to_int_outside_of_gt4py, - b_casted_to_int_outside_of_gt4py, + a_asint, + b_asint, out=out_tuple, ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), ) @@ -384,8 +384,8 @@ def cast_nested_tuple( cast_nested_tuple, a, b, - a_casted_to_int_outside_of_gt4py, - b_casted_to_int_outside_of_gt4py, + a_asint, + b_asint, out=out_nested_tuple, ref=( np.full_like(a, True, dtype=bool), From ecd0b68a4492a6f01ac3f9a66e888da7c992a0c0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 17 Nov 2023 22:31:43 +0100 Subject: [PATCH 37/67] feat[next] Enable embedded field view in ffront_tests (#1361) Enables field view in ffront_tests New exclusion markers for some cases - cartesian and unstructured shifts - scan - check for a very concrete error message in parsing: we should match this later in embedded Adds the following features to embedded: - support for scalar broadcast, astype, binary functions - adds `__ne__` and `__eq__` to Field TODOs: - full comparison operators for UnitRange - full comparison operators for Fields --- pyproject.toml | 6 +- src/gt4py/_core/definitions.py | 3 + src/gt4py/next/common.py | 20 ++- src/gt4py/next/constructors.py | 2 + src/gt4py/next/embedded/nd_array_field.py | 30 +++-- src/gt4py/next/ffront/decorator.py | 55 +++++--- src/gt4py/next/ffront/fbuiltins.py | 117 +++++++++++------- src/gt4py/next/iterator/embedded.py | 12 ++ tests/next_tests/exclusion_matrices.py | 12 ++ .../ffront_tests/ffront_test_utils.py | 19 ++- .../ffront_tests/test_arg_call_interface.py | 4 + .../ffront_tests/test_execution.py | 66 ++++++++-- .../ffront_tests/test_external_local_field.py | 3 + .../ffront_tests/test_gt4py_builtins.py | 7 ++ .../test_math_builtin_execution.py | 3 + .../ffront_tests/test_math_unary_builtins.py | 17 ++- .../ffront_tests/test_program.py | 4 +- .../ffront_tests/test_icon_like_scan.py | 3 + .../ffront_tests/test_laplacian.py | 4 + .../embedded_tests/test_nd_array_field.py | 2 +- tests/next_tests/unit_tests/test_common.py | 16 +++ 21 files changed, 293 insertions(+), 112 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7690ae583e..041448e17d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -343,7 +343,11 @@ markers = [ 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple arguments', 'uses_tuple_returns: tests that require backend support for tuple results', - 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields' + 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields', + 'uses_cartesian_shift: tests that use a Cartesian connectivity', + 'uses_unstructured_shift: tests that use a unstructured connectivity', + 'uses_scan: tests that uses scan', + 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 7b318bc2de..79543a1849 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -446,6 +446,9 @@ def shape(self) -> tuple[int, ...]: def dtype(self) -> Any: ... + def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: + ... + def __getitem__(self, item: Any) -> NDArrayObject: ... diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index ffaa410563..66766be76b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -133,7 +133,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: else: raise IndexError("UnitRange index out of range") - def __and__(self, other: Set[Any]) -> UnitRange: + def __and__(self, other: Set[int]) -> UnitRange: if isinstance(other, UnitRange): start = max(self.start, other.start) stop = min(self.stop, other.stop) @@ -141,6 +141,16 @@ def __and__(self, other: Set[Any]) -> UnitRange: else: raise NotImplementedError("Can only find the intersection between UnitRange instances.") + def __le__(self, other: Set[int]): + if isinstance(other, UnitRange): + return self.start >= other.start and self.stop <= other.stop + elif len(self) == Infinity.positive(): + return False + else: + return Set.__le__(self, other) + + __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented + def __str__(self) -> str: return f"({self.start}:{self.stop})" @@ -486,6 +496,14 @@ def __neg__(self) -> Field: def __invert__(self) -> Field: """Only defined for `Field` of value type `bool`.""" + @abc.abstractmethod + def __eq__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool` + ... + + @abc.abstractmethod + def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool` + ... + @abc.abstractmethod def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 30ef8452aa..42b0bcda90 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -82,6 +82,8 @@ def empty( (3, 3) """ dtype = core_defs.dtype(dtype) + if allocator is None and device is None: + device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0) buffer = next_allocators.allocate( domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device ) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ea88948841..51e613ef81 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -135,25 +135,22 @@ def from_array( /, *, domain: common.DomainLike, - dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike + dtype: Optional[core_defs.DTypeLike] = None, ) -> NdArrayField: domain = common.domain(domain) xp = cls.array_ns - xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) + xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) array = xp.asarray(data, dtype=xp_dtype) - if dtype_like is not None: - assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type + if dtype is not None: + assert array.dtype.type == core_defs.dtype(dtype).scalar_type assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim - assert all( - len(r) == s or (s == 1 and r == common.UnitRange.infinity()) - for r, s in zip(domain.ranges, array.shape) - ) + assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape)) return cls(domain, array) @@ -194,6 +191,10 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __mod__ = __rmod__ = _make_builtin("mod", "mod") + __ne__ = _make_builtin("not_equal", "not_equal") # type: ignore[assignment] # mypy wants return `bool` + + __eq__ = _make_builtin("equal", "equal") # type: ignore[assignment] # mypy wants return `bool` + def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_and", "logical_and")(self, other) @@ -285,7 +286,7 @@ def _np_cp_setitem( _nd_array_implementations = [np] -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, eq=False) class NumPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = np @@ -298,7 +299,7 @@ class NumPyArrayField(NdArrayField): if cp: _nd_array_implementations.append(cp) - @dataclasses.dataclass(frozen=True) + @dataclasses.dataclass(frozen=True, eq=False) class CuPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = cp @@ -310,7 +311,7 @@ class CuPyArrayField(NdArrayField): if jnp: _nd_array_implementations.append(jnp) - @dataclasses.dataclass(frozen=True) + @dataclasses.dataclass(frozen=True, eq=False) class JaxArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = jnp @@ -351,6 +352,13 @@ def _builtins_broadcast( NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) +def _astype(field: NdArrayField, type_: type) -> NdArrayField: + return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) + + +NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field + + def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 2d12331513..107415eb06 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,7 +32,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators +from gt4py.next import allocators as next_allocators, common from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import ( dialect_ast_enums, @@ -171,14 +171,14 @@ class Program: past_node: past.Program closure_vars: dict[str, Any] definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = None + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = None, + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, grid_type: Optional[GridType] = None, ) -> Program: source_def = SourceDefinition.from_function(definition) @@ -282,27 +282,23 @@ def itir(self) -> itir.FencilDefinition: ) def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None: - if ( - self.backend is None and DEFAULT_BACKEND is None - ): # TODO(havogt): for now enable embedded execution by setting DEFAULT_BACKEND to None - self.definition(*args, **kwargs) - return - rewritten_args, size_args, kwargs = self._process_args(args, kwargs) - if not self.backend: + if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.itir.id}': Using default ({DEFAULT_BACKEND}) backend." + f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend." ) ) - backend = self.backend or DEFAULT_BACKEND - ppi.ensure_processor_kind(backend, ppi.ProgramExecutor) + self.definition(*rewritten_args, **kwargs) + return + + ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) if "debug" in kwargs: debug(self.itir) - backend( + self.backend( self.itir, *rewritten_args, *size_args, @@ -547,14 +543,14 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): foast_node: OperatorNodeT closure_vars: dict[str, Any] definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = None + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = None, + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, grid_type: Optional[GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, @@ -687,9 +683,9 @@ def __call__( # if we are reaching this from a program call. if "out" in kwargs: out = kwargs.pop("out") - if "offset_provider" in kwargs: + offset_provider = kwargs.pop("offset_provider", None) + if self.backend is not None: # "out" and "offset_provider" -> field_operator as program - offset_provider = kwargs.pop("offset_provider") args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given # deduce argument types @@ -705,13 +701,34 @@ def __call__( ) else: # "out" -> field_operator called from program in embedded execution - out.ndarray[:] = self.definition(*args, **kwargs).ndarray[:] + # TODO(egparedes): put offset_provider in ctxt var here when implementing remap + domain = kwargs.pop("domain", None) + res = self.definition(*args, **kwargs) + _tuple_assign_field( + out, res, domain=None if domain is None else common.domain(domain) + ) return else: # field_operator called from other field_operator in embedded execution + assert self.backend is None return self.definition(*args, **kwargs) +def _tuple_assign_field( + target: tuple[common.Field | tuple, ...] | common.Field, + source: tuple[common.Field | tuple, ...] | common.Field, + domain: Optional[common.Domain], +): + if isinstance(target, tuple): + if not isinstance(source, tuple): + raise RuntimeError(f"Cannot assign {source} to {target}.") + for t, s in zip(target, source): + _tuple_assign_field(t, s, domain) + else: + domain = domain or target.domain + target[domain] = source[domain] + + @typing.overload def field_operator( definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 7b96de8e89..706b6a4606 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -28,10 +28,12 @@ cast, ) +import numpy as np from numpy import float32, float64, int32, int64 -from gt4py._core import definitions as gt4py_defs -from gt4py.next.common import Dimension, DimensionKind, Field +from gt4py._core import definitions as core_defs +from gt4py.next import common +from gt4py.next.common import Dimension, Field # direct import for TYPE_BUILTINS from gt4py.next.ffront.experimental import as_offset # noqa F401 from gt4py.next.iterator import runtime from gt4py.next.type_system import type_specifications as ts @@ -40,7 +42,14 @@ PYTHON_TYPE_BUILTINS = [bool, int, float, tuple] PYTHON_TYPE_BUILTIN_NAMES = [t.__name__ for t in PYTHON_TYPE_BUILTINS] -TYPE_BUILTINS = [Field, Dimension, int32, int64, float32, float64] + PYTHON_TYPE_BUILTINS +TYPE_BUILTINS = [ + Field, + Dimension, + int32, + int64, + float32, + float64, +] + PYTHON_TYPE_BUILTINS TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] # Be aware: Type aliases are not fully supported in the frontend yet, e.g. `IndexType(1)` will not @@ -54,11 +63,11 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSpec], ...]: - if t is Field: + if t is common.Field: return ts.FieldType - elif t is Dimension: + elif t is common.Dimension: return ts.DimensionType - elif t is gt4py_defs.ScalarT: + elif t is core_defs.ScalarT: return ts.ScalarType elif t is type: return ( @@ -128,12 +137,8 @@ def __gt_type__(self) -> ts.FunctionType: ) -def builtin_function(fun: Callable[_P, _R]) -> BuiltInFunction[_R, _P]: - return BuiltInFunction(fun) - - -MaskT = TypeVar("MaskT", bound=Field) -FieldT = TypeVar("FieldT", bound=Union[Field, gt4py_defs.Scalar, Tuple]) +MaskT = TypeVar("MaskT", bound=common.Field) +FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) class WhereBuiltinFunction( @@ -153,55 +158,71 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: return super().__call__(mask, true_field, false_field) -@builtin_function +@BuiltInFunction def neighbor_sum( - field: Field, + field: common.Field, /, - axis: Dimension, -) -> Field: + axis: common.Dimension, +) -> common.Field: raise NotImplementedError() -@builtin_function +@BuiltInFunction def max_over( - field: Field, + field: common.Field, /, - axis: Dimension, -) -> Field: + axis: common.Dimension, +) -> common.Field: raise NotImplementedError() -@builtin_function +@BuiltInFunction def min_over( - field: Field, + field: common.Field, /, - axis: Dimension, -) -> Field: + axis: common.Dimension, +) -> common.Field: raise NotImplementedError() -@builtin_function -def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple[Dimension, ...], /) -> Field: - raise NotImplementedError() +@BuiltInFunction +def broadcast( + field: common.Field | core_defs.ScalarT, + dims: tuple[common.Dimension, ...], + /, +) -> common.Field: + assert core_defs.is_scalar_type( + field + ) # default implementation for scalars, Fields are handled via dispatch + return common.field( + np.asarray(field)[ + tuple([np.newaxis] * len(dims)) + ], # TODO(havogt) use FunctionField once available + domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinity()] * len(dims))), + ) @WhereBuiltinFunction def where( - mask: Field, - true_field: Field | gt4py_defs.ScalarT | Tuple, - false_field: Field | gt4py_defs.ScalarT | Tuple, + mask: common.Field, + true_field: common.Field | core_defs.ScalarT | Tuple, + false_field: common.Field | core_defs.ScalarT | Tuple, /, -) -> Field | Tuple: +) -> common.Field | Tuple: raise NotImplementedError() -@builtin_function +@BuiltInFunction def astype( - field: Field | gt4py_defs.ScalarT | Tuple[Field, ...], + value: Field | core_defs.ScalarT | Tuple, type_: type, /, -) -> Field | Tuple[Field, ...]: - raise NotImplementedError() +) -> Field | core_defs.ScalarT | Tuple: + if isinstance(value, tuple): + return tuple(astype(v, type_) for v in value) + # default implementation for scalars, Fields are handled via dispatch + assert core_defs.is_scalar_type(value) + return core_defs.dtype(type_).scalar_type(value) UNARY_MATH_NUMBER_BUILTIN_NAMES = ["abs"] @@ -233,11 +254,14 @@ def astype( def _make_unary_math_builtin(name): - def impl(value: Field | gt4py_defs.ScalarT, /) -> Field | gt4py_defs.ScalarT: + def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: + # TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`) + # assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # noqa: E800 # commented code + # return getattr(math, name)(value)# noqa: E800 # commented code raise NotImplementedError() impl.__name__ = name - globals()[name] = builtin_function(impl) + globals()[name] = BuiltInFunction(impl) for f in ( @@ -252,14 +276,17 @@ def impl(value: Field | gt4py_defs.ScalarT, /) -> Field | gt4py_defs.ScalarT: def _make_binary_math_builtin(name): def impl( - lhs: Field | gt4py_defs.ScalarT, - rhs: Field | gt4py_defs.ScalarT, + lhs: common.Field | core_defs.ScalarT, + rhs: common.Field | core_defs.ScalarT, /, - ) -> Field | gt4py_defs.ScalarT: - raise NotImplementedError() + ) -> common.Field | core_defs.ScalarT: + # default implementation for scalars, Fields are handled via dispatch + assert core_defs.is_scalar_type(lhs) + assert core_defs.is_scalar_type(rhs) + return getattr(np, name)(lhs, rhs) impl.__name__ = name - globals()[name] = builtin_function(impl) + globals()[name] = BuiltInFunction(impl) for f in BINARY_MATH_NUMBER_BUILTIN_NAMES: @@ -295,12 +322,12 @@ def impl( # guidelines for decision. @dataclasses.dataclass(frozen=True) class FieldOffset(runtime.Offset): - source: Dimension - target: tuple[Dimension] | tuple[Dimension, Dimension] + source: common.Dimension + target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] connectivity: Optional[Any] = None # TODO def __post_init__(self): - if len(self.target) == 2 and self.target[1].kind != DimensionKind.LOCAL: + if len(self.target) == 2 and self.target[1].kind != common.DimensionKind.LOCAL: raise ValueError("Second dimension in offset must be a local dimension.") def __gt_type__(self): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 674f99f61c..44294a3a71 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1093,6 +1093,12 @@ def __neg__(self) -> common.Field: def __invert__(self) -> common.Field: raise NotImplementedError() + def __eq__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + + def __ne__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() @@ -1194,6 +1200,12 @@ def __neg__(self) -> common.Field: def __invert__(self) -> common.Field: raise NotImplementedError() + def __eq__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + + def __ne__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index ddea04649f..249e17d358 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -98,6 +98,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_TUPLE_ARGS = "uses_tuple_args" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" +USES_CARTESIAN_SHIFT = "uses_cartesian_shift" +USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift" +USES_SCAN = "uses_scan" +CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" @@ -114,10 +118,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] +EMBEDDED_SKIP_LIST = [ + (USES_CARTESIAN_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_UNSTRUCTURED_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), +] #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) BACKEND_SKIP_TEST_MATRIX = { + None: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: GTFN_SKIP_TEST_LIST + [ (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 386e64451d..fb753bf169 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -53,6 +53,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non definitions.ProgramBackendId.GTFN_CPU, definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, + None, ] + OPTIONAL_PROCESSORS, ids=lambda p: p.short_id() if p is not None else "None", @@ -65,19 +66,15 @@ def fieldview_backend(request): Check ADR 15 for details on the test-exclusion matrices. """ backend_id = request.param - if backend_id is None: - backend = None - else: - backend = backend_id.load() - - for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( - backend_id, [] - ): - if request.node.get_closest_marker(marker): - skip_mark(msg.format(marker=marker, backend=backend_id)) + backend = None if backend_id is None else backend_id.load() - backup_backend = decorator.DEFAULT_BACKEND + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) + backup_backend = decorator.DEFAULT_BACKEND decorator.DEFAULT_BACKEND = no_backend yield backend decorator.DEFAULT_BACKEND = backup_backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index deb1382dfb..6957e628bb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -158,6 +158,7 @@ def testee( ) +@pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator def test_call_scan_operator_from_field_operator(cartesian_case): @scan_operator(axis=KDim, forward=True, init=0.0) @@ -183,6 +184,7 @@ def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField: cases.verify(cartesian_case, testee, a, b, out=out, ref=expected) +@pytest.mark.uses_scan def test_call_scan_operator_from_program(cartesian_case): @scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan(state: float, x: float, y: float) -> float: @@ -222,6 +224,7 @@ def testee( ) +@pytest.mark.uses_scan def test_scan_wrong_return_type(cartesian_case): with pytest.raises( errors.DSLError, @@ -239,6 +242,7 @@ def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float testee_scan(qc, param_1, param_2, scalar, out=(qc, param_1, param_2)) +@pytest.mark.uses_scan def test_scan_wrong_state_type(cartesian_case): with pytest.raises( errors.DSLError, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8787b7d7bc..8036c22670 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -76,6 +76,7 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases. cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b: (a, b)) +@pytest.mark.uses_cartesian_shift def test_cartesian_shift(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IJKField) -> cases.IJKField: @@ -87,6 +88,7 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify(cartesian_case, testee, a, out=out, ref=a[1:]) +@pytest.mark.uses_unstructured_shift def test_unstructured_shift(unstructured_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.VField) -> cases.EField: @@ -99,6 +101,7 @@ def testee(a: cases.VField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator def composed_shift_unstructured_flat(inp: cases.VField) -> cases.CField: @@ -143,6 +146,7 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: ) +@pytest.mark.uses_cartesian_shift def test_fold_shifts(cartesian_case): # noqa: F811 # fixtures """Shifting the result of an addition should work.""" @@ -206,6 +210,7 @@ def testee(a: int32) -> cases.VField: @pytest.mark.uses_index_fields +@pytest.mark.uses_cartesian_shift def test_scalar_arg_with_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IJKField, b: int32) -> cases.IJKField: @@ -246,6 +251,7 @@ def testee(size: gtx.IndexType, out: gtx.Field[[IDim], gtx.IndexType]): ) +@pytest.mark.uses_scan def test_scalar_scan(cartesian_case): # noqa: F811 # fixtures @gtx.scan_operator(axis=KDim, forward=True, init=(0.0)) def testee_scan(state: float, qc_in: float, scalar: float) -> float: @@ -264,6 +270,7 @@ def testee(qc: cases.IKFloatField, scalar: float): cases.verify(cartesian_case, testee, qc, scalar, inout=qc, ref=expected) +@pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator def test_tuple_scalar_scan(cartesian_case): # noqa: F811 # fixtures @gtx.scan_operator(axis=KDim, forward=True, init=0.0) @@ -285,6 +292,7 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_scan @pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures @gtx.scan_operator(axis=KDim, forward=True, init=(0.0)) @@ -363,8 +371,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_asint = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) - b_asint = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) + a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32)) + b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -483,6 +491,7 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator @@ -504,6 +513,7 @@ def testee(a: cases.EField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift @pytest.mark.xfail(reason="Not yet supported in lowering, requires `map_`ing of inner reduce op.") def test_nested_reduction_shift_first(unstructured_case): @gtx.field_operator @@ -524,6 +534,7 @@ def testee(inp: cases.EField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_tuple_returns def test_tuple_return_2(unstructured_case): @gtx.field_operator @@ -543,6 +554,7 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_constant_fields def test_tuple_with_local_field_in_reduction_shifted(unstructured_case): @gtx.field_operator @@ -572,6 +584,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I ) +@pytest.mark.uses_scan @pytest.mark.parametrize("forward", [True, False]) def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 @@ -592,6 +605,7 @@ def simple_scan_operator(carry: float) -> float: cases.verify(cartesian_case, simple_scan_operator, out=out, ref=expected) +@pytest.mark.uses_scan @pytest.mark.uses_lift_expressions def test_solve_triag(cartesian_case): if cartesian_case.backend in [ @@ -680,6 +694,7 @@ def testee( ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): @gtx.field_operator @@ -698,6 +713,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: ) +@pytest.mark.uses_scan def test_ternary_scan(cartesian_case): if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @@ -720,6 +736,7 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) +@pytest.mark.uses_scan @pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: @@ -745,13 +762,14 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): cartesian_case, testee, ref=lambda: (expected + 1.0, (expected + 2.0, expected + 3.0)), - comparison=lambda ref, out: np.all(out[0] == ref[0]) - and np.all(out[1][0] == ref[1][0]) - and np.all(out[1][1] == ref[1][1]), + comparison=lambda ref, out: np.all(np.asarray(out[0]) == ref[0]) + and np.all(np.asarray(out[1][0]) == ref[1][0]) + and np.all(np.asarray(out[1][1]) == ref[1][1]), ) @pytest.mark.uses_tuple_args +@pytest.mark.uses_scan def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] @@ -824,7 +842,10 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - cases.verify(cartesian_case, program_domain, a, out, inout=out[1:9], ref=a[1:9] * 2) + ref = out.ndarray.copy() # ensure we are not overwriting out outside of the domain + ref[1:9] = a[1:9] * 2 + + cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) def test_domain_input_bounds(cartesian_case): @@ -855,6 +876,9 @@ def program_domain( inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() + ref = out.ndarray.copy() + ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2 + cases.verify( cartesian_case, program_domain, @@ -862,8 +886,8 @@ def program_domain( out, lower_i, upper_i, - inout=out[lower_i : int(upper_i / 2)], - ref=inp[lower_i : int(upper_i / 2)] * 2, + inout=out, + ref=ref, ) @@ -895,6 +919,11 @@ def program_domain( a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() + ref = out.ndarray.copy() + ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = ( + a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 + ) + cases.verify( cartesian_case, program_domain, @@ -904,8 +933,8 @@ def program_domain( upper_i, lower_j, upper_j, - inout=out[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j], - ref=a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2, + inout=out, + ref=ref, ) @@ -930,6 +959,11 @@ def program_domain_tuple( out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() + ref0 = out0.ndarray.copy() + ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6] + ref1 = out1.ndarray.copy() + ref1[1:9, 4:6] = inp1[1:9, 4:6] + cases.verify( cartesian_case, program_domain_tuple, @@ -937,11 +971,12 @@ def program_domain_tuple( inp1, out0, out1, - inout=(out0[1:9, 4:6], out1[1:9, 4:6]), - ref=(inp0[1:9, 4:6] + inp1[1:9, 4:6], inp1[1:9, 4:6]), + inout=(out0, out1), + ref=(ref0, ref1), ) +@pytest.mark.uses_cartesian_shift def test_where_k_offset(cartesian_case): @gtx.field_operator def fieldop_where_k_offset( @@ -1079,6 +1114,13 @@ def _invalid_unpack() -> tuple[int32, float64, int32]: def test_constant_closure_vars(cartesian_case): + if cartesian_case.backend is None: + # >>> field = gtx.zeros(domain) + # >>> np.int32(1)*field # steals the buffer from the field + # array([0.]) + + # TODO(havogt): remove `__array__`` from `NdArrayField` + pytest.xfail("Bug: Binary operation between np datatype and Field returns ndarray.") from gt4py.eve.utils import FrozenNamespace constants = FrozenNamespace( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 04b27c6c17..5135b3d47a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -26,6 +26,9 @@ ) +pytestmark = pytest.mark.uses_unstructured_shift + + def test_external_local_field(unstructured_case): @gtx.field_operator def testee( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 8213f54a45..1eba95e880 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -39,6 +39,7 @@ ) +@pytest.mark.uses_unstructured_shift @pytest.mark.parametrize( "strategy", [cases.UniqueInitializer(1), cases.UniqueInitializer(-100)], @@ -65,6 +66,7 @@ def testee(edge_f: cases.EField) -> cases.VField: cases.verify(unstructured_case, testee, inp, ref=ref, out=out) +@pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: @@ -77,6 +79,7 @@ def minover(edge_f: cases.EField) -> cases.VField: ) +@pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: @@ -93,6 +96,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_constant_fields def test_reduction_expression_in_call(unstructured_case): @gtx.field_operator @@ -113,6 +117,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: @@ -191,6 +196,7 @@ def broadcast_two_fields(inp1: cases.IField, inp2: gtx.Field[[JDim], int32]) -> ) +@pytest.mark.uses_cartesian_shift def test_broadcast_shifted(cartesian_case): @gtx.field_operator def simple_broadcast(inp: cases.IField) -> cases.IJField: @@ -249,6 +255,7 @@ def conditional_promotion(a: cases.IFloatField) -> cases.IFloatField: ) +@pytest.mark.uses_cartesian_shift def test_conditional_shifted(cartesian_case): @gtx.field_operator def conditional_shifted( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index a5d2b92719..a1839b8e17 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -116,6 +116,9 @@ def make_builtin_field_operator(builtin_name: str): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inputs): + if cartesian_case.backend is None: + # TODO(havogt) find a way that works for embedded + pytest.xfail("Test does not have a field view program.") if builtin_name == "gamma": # numpy has no gamma function ref_impl: Callable = np.vectorize(math.gamma) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 5a277f9440..59e11a7de8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -173,17 +173,14 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: def test_unary_not(cartesian_case): - @gtx.field_operator - def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: - return not inp1 + pytest.xfail( + "We accidentally supported `not` on fields. This is wrong, we should raise an error." + ) + with pytest.raises: # TODO `not` on a field should be illegal - size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, not_fieldop, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - out = cases.allocate(cartesian_case, not_fieldop, cases.RETURN)() - cases.verify(cartesian_case, not_fieldop, inp1, out=out, ref=~inp1) + @gtx.field_operator + def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: + return not inp1 # Trig builtins diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 7a1c827a0d..545abd2825 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -51,6 +51,7 @@ def test_identity_fo_execution(cartesian_case, identity_def): ) +@pytest.mark.uses_cartesian_shift def test_shift_by_one_execution(cartesian_case): @gtx.field_operator def shift_by_one(in_field: cases.IFloatField) -> cases.IFloatField: @@ -230,6 +231,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): assert re.search(msg, exc_info.value.__cause__.args[0]) is not None +@pytest.mark.checks_specific_error def test_dimensions_domain(cartesian_case): @gtx.field_operator def empty_domain_fieldop(a: cases.IJField): @@ -246,4 +248,4 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField): ValueError, match=(r"Dimensions in out field and field domain are not equivalent"), ): - empty_domain_program(a, out_field, offset_provider={}) + cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 108ee25862..eaae9a2a3e 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -25,6 +25,9 @@ ) +pytestmark = pytest.mark.uses_unstructured_shift + + Cell = gtx.Dimension("Cell") KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) Koff = gtx.FieldOffset("Koff", KDim, (KDim,)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index d275a977dd..9a1e968de0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import numpy as np +import pytest import gt4py.next as gtx @@ -23,6 +24,9 @@ ) +pytestmark = pytest.mark.uses_cartesian_shift + + @gtx.field_operator def lap(in_field: gtx.Field[[IDim, JDim], "float"]) -> gtx.Field[[IDim, JDim], "float"]: return ( diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 49aeece87e..00dbf68274 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -259,7 +259,7 @@ def test_mixed_fields(product_nd_array_implementation): def test_non_dispatched_function(): - @fbuiltins.builtin_function + @fbuiltins.BuiltInFunction def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: return a * b + c diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 31e35221ab..84008eb99c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import operator from typing import Optional, Pattern import pytest @@ -150,6 +151,21 @@ def test_mixed_infinity_range(): assert len(mixed_inf_range) == Infinity.positive() +@pytest.mark.parametrize( + "op, rng1, rng2, expected", + [ + (operator.le, UnitRange(-1, 2), UnitRange(-2, 3), True), + (operator.le, UnitRange(-1, 2), {-1, 0, 1}, True), + (operator.le, UnitRange(-1, 2), {-1, 0}, False), + (operator.le, UnitRange(-1, 2), {-2, -1, 0, 1, 2}, True), + (operator.le, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True), + (operator.le, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False), + ], +) +def test_range_comparison(op, rng1, rng2, expected): + assert op(rng1, rng2) == expected + + @pytest.mark.parametrize( "named_rng_like", [ From 42912cc9d14e409801c1c71fc99a98f46e7c4a1b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 20 Nov 2023 11:13:36 +0100 Subject: [PATCH 38/67] feat[next] Enable GPU backend tests (#1357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - connectivities are implicitly copied to GPU if they are not already on GPU, this might be removed later - changes to cases: ensure we don't pass arrays to ConstInitializer --------- Co-authored-by: Rico Häuselmann --- src/gt4py/next/embedded/nd_array_field.py | 5 +- .../codegens/gtfn/codegen.py | 59 +++++++------- .../next/program_processors/runners/gtfn.py | 30 +++++-- tests/next_tests/exclusion_matrices.py | 5 ++ tests/next_tests/integration_tests/cases.py | 18 ++++- .../ffront_tests/ffront_test_utils.py | 1 + .../ffront_tests/test_execution.py | 33 ++++---- .../ffront_tests/test_external_local_field.py | 8 +- .../ffront_tests/test_gt4py_builtins.py | 18 ++--- .../test_math_builtin_execution.py | 4 +- .../ffront_tests/test_math_unary_builtins.py | 35 +++----- .../ffront_tests/test_program.py | 2 +- .../ffront_tests/test_icon_like_scan.py | 79 ++++++++++++------- .../ffront_tests/test_laplacian.py | 2 +- tests/next_tests/unit_tests/conftest.py | 1 + tox.ini | 2 +- 16 files changed, 176 insertions(+), 126 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 51e613ef81..9357570b05 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -121,7 +121,10 @@ def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: - return np.asarray(self._ndarray, dtype) + if self.array_ns == cp: + return np.asarray(cp.asnumpy(self._ndarray), dtype) + else: + return np.asarray(self._ndarray, dtype) @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 645d1f742f..23165854de 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -179,6 +179,10 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): """ ) + def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): + expr_ = "return " + self.visit(node.expr) + return self.generic_visit(node, expr_=expr_) + FunctionDefinition = as_mako( """ struct ${id} { @@ -206,24 +210,6 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): """ ) - def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): - expr_ = "return " + self.visit(node.expr) - return self.generic_visit(node, expr_=expr_) - - def visit_FencilDefinition( - self, node: gtfn_ir.FencilDefinition, **kwargs: Any - ) -> Union[str, Collection[str]]: - self.is_cartesian = node.grid_type == common.GridType.CARTESIAN - self.user_defined_function_ids = list( - str(fundef.id) for fundef in node.function_definitions - ) - return self.generic_visit( - node, - grid_type_str=self._grid_type_str[node.grid_type], - block_sizes=self._block_sizes(node.offset_definitions), - **kwargs, - ) - def visit_TemporaryAllocation(self, node, **kwargs): # TODO(tehrengruber): Revisit. We are currently converting an itir.NamedRange with # start and stop values into an gtfn_ir.(Cartesian|Unstructured)Domain with @@ -244,6 +230,20 @@ def visit_TemporaryAllocation(self, node, **kwargs): "auto {id} = gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {tmp_sizes});" ) + def visit_FencilDefinition( + self, node: gtfn_ir.FencilDefinition, **kwargs: Any + ) -> Union[str, Collection[str]]: + self.is_cartesian = node.grid_type == common.GridType.CARTESIAN + self.user_defined_function_ids = list( + str(fundef.id) for fundef in node.function_definitions + ) + return self.generic_visit( + node, + grid_type_str=self._grid_type_str[node.grid_type], + block_sizes=self._block_sizes(node.offset_definitions), + **kwargs, + ) + FencilDefinition = as_mako( """ #include @@ -277,16 +277,19 @@ def visit_TemporaryAllocation(self, node, **kwargs): ) def _block_sizes(self, offset_definitions: list[gtfn_ir.TagDefinition]) -> str: - block_dims = [] - block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2) - for i, tag in enumerate(offset_definitions): - if tag.alias is None: - block_dims.append( - f"gridtools::meta::list<{tag.name.id}_t, " - f"gridtools::integral_constant>" - ) - sizes_str = ",\n".join(block_dims) - return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;" + if self.is_cartesian: + block_dims = [] + block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2) + for i, tag in enumerate(offset_definitions): + if tag.alias is None: + block_dims.append( + f"gridtools::meta::list<{tag.name.id}_t, " + f"gridtools::integral_constant>" + ) + sizes_str = ",\n".join(block_dims) + return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;" + else: + return "using block_sizes_t = gridtools::meta::list>, gridtools::meta::list>>;" @classmethod def apply(cls, root: Any, **kwargs: Any) -> str: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7233e7a893..5d4b450d39 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import functools +import warnings from typing import Any import numpy.typing as npt @@ -42,12 +44,14 @@ def convert_arg(arg: Any) -> Any: return arg -def convert_args(inp: stages.CompiledProgram) -> stages.CompiledProgram: +def convert_args( + inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU +) -> stages.CompiledProgram: def decorated_program( *args, offset_provider: dict[str, common.Connectivity | common.Dimension] ): converted_args = [convert_arg(arg) for arg in args] - conn_args = extract_connectivity_args(offset_provider) + conn_args = extract_connectivity_args(offset_provider, device) return inp( *converted_args, *conn_args, @@ -56,8 +60,22 @@ def decorated_program( return decorated_program +def _ensure_is_on_device( + connectivity_arg: npt.NDArray, device: core_defs.DeviceType +) -> npt.NDArray: + if device == core_defs.DeviceType.CUDA: + import cupy as cp + + if not isinstance(connectivity_arg, cp.ndarray): + warnings.warn( + "Copying connectivity to device. For performance make sure connectivity is provided on device." + ) + return cp.asarray(connectivity_arg) + return connectivity_arg + + def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType ) -> list[tuple[npt.NDArray, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] @@ -67,7 +85,9 @@ def extract_connectivity_args( raise NotImplementedError( "Only `NeighborTable` connectivities implemented at this point." ) - args.append((conn.table, tuple([0] * 2))) + # copying to device here is a fallback for easy testing and might be removed later + conn_arg = _ensure_is_on_device(conn.table, device) + args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass else: @@ -126,7 +146,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: translation=GTFN_GPU_TRANSLATION_STEP, bindings=nanobind.bind_source, compilation=GTFN_DEFAULT_COMPILE_STEP, - decoration=convert_args, + decoration=functools.partial(convert_args, device=core_defs.DeviceType.CUDA), ) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 249e17d358..ef30a61687 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -50,6 +50,7 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): GTFN_CPU_WITH_TEMPORARIES = ( "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" ) + GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.backend" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -148,6 +149,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 634d85e64c..730ce18fd5 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -25,6 +25,7 @@ import pytest import gt4py.next as gtx +from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self from gt4py.next import common, constructors @@ -73,7 +74,7 @@ E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) C2E = gtx.FieldOffset("E2V", source=Edge, target=(Cell, C2EDim)) -ScalarValue: TypeAlias = np.int32 | np.int64 | np.float32 | np.float64 | np.generic +ScalarValue: TypeAlias = core_defs.Scalar FieldValue: TypeAlias = gtx.Field FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...] FieldViewInout: TypeAlias = FieldValue | tuple["FieldViewInout", ...] @@ -117,12 +118,19 @@ def from_case( return self -@dataclasses.dataclass +@dataclasses.dataclass(init=False) class ConstInitializer(DataInitializer): """Initialize with a given value across the coordinate space.""" value: ScalarValue + def __init__(self, value: ScalarValue): + if not core_defs.is_scalar_type(value): + raise ValueError( + "`ConstInitializer` can not be used with non-scalars. Use `Case.as_field` instead." + ) + self.value = value + @property def scalar_value(self) -> ScalarValue: return self.value @@ -460,7 +468,7 @@ def verify_with_default_data( ``comparison(ref, )`` and should return a boolean. """ inps, kwfields = get_default_data(case, fieldop) - ref_args = tuple(i.ndarray if hasattr(i, "ndarray") else i for i in inps) + ref_args = tuple(i.__array__() if common.is_field(i) else i for i in inps) verify( case, fieldop, @@ -598,3 +606,7 @@ class Case: offset_provider: dict[str, common.Connectivity | gtx.Dimension] default_sizes: dict[gtx.Dimension, int] grid_type: common.GridType + + @property + def as_field(self): + return constructors.as_field.partial(allocator=self.backend) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index fb753bf169..01c78cf950 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -53,6 +53,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non definitions.ProgramBackendId.GTFN_CPU, definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, + pytest.param(definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu), None, ] + OPTIONAL_PROCESSORS, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8036c22670..fe18bda9e3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -371,8 +371,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32)) - b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32)) + a_asint = cartesian_case.as_field([IDim], np.asarray(a).astype(int32)) + b_asint = cartesian_case.as_field([IDim], np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -589,7 +589,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[IDim], 1) - out = gtx.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) + out = cartesian_case.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) if not forward: expected = np.flip(expected) @@ -610,6 +610,7 @@ def simple_scan_operator(carry: float) -> float: def test_solve_triag(cartesian_case): if cartesian_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -723,8 +724,8 @@ def simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 k_size = cartesian_case.default_sizes[KDim] - a = gtx.as_field([KDim], 4.0 * np.ones((k_size,))) - out = gtx.as_field([KDim], np.zeros((k_size,))) + a = cartesian_case.as_field([KDim], 4.0 * np.ones((k_size,))) + out = cartesian_case.as_field([KDim], np.zeros((k_size,))) cases.verify( cartesian_case, @@ -773,16 +774,19 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] - inp1 = gtx.as_field([KDim], np.ones((k_size,))) - inp2 = gtx.as_field([KDim], np.arange(0.0, k_size, 1)) - out = gtx.as_field([KDim], np.zeros((k_size,))) + + inp1_np = np.ones((k_size,)) + inp2_np = np.arange(0.0, k_size, 1) + inp1 = cartesian_case.as_field([KDim], inp1_np) + inp2 = cartesian_case.as_field([KDim], inp2_np) + out = cartesian_case.as_field([KDim], np.zeros((k_size,))) def prev_levels_iterator(i): return range(i + 1) expected = np.asarray( [ - reduce(lambda prev, i: prev + inp1[i] + inp2[i], prev_levels_iterator(i), init) + reduce(lambda prev, i: prev + inp1_np[i] + inp2_np[i], prev_levels_iterator(i), init) for i in range(k_size) ] ) @@ -842,7 +846,7 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.ndarray.copy() # ensure we are not overwriting out outside of the domain + ref = np.asarray(out).copy() # ensure we are not overwriting `out` outside of the domain ref[1:9] = a[1:9] * 2 cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) @@ -851,6 +855,7 @@ def program_domain(a: cases.IField, out: cases.IField): def test_domain_input_bounds(cartesian_case): if cartesian_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -876,7 +881,7 @@ def program_domain( inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() - ref = out.ndarray.copy() + ref = np.asarray(out).copy() ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2 cases.verify( @@ -919,7 +924,7 @@ def program_domain( a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.ndarray.copy() + ref = np.asarray(out).copy() ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = ( a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 ) @@ -959,9 +964,9 @@ def program_domain_tuple( out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() - ref0 = out0.ndarray.copy() + ref0 = np.asarray(out0).copy() ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6] - ref1 = out1.ndarray.copy() + ref1 = np.asarray(out1).copy() ref1[1:9, 4:6] = inp1[1:9, 4:6] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 5135b3d47a..05adc63a45 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -38,7 +38,9 @@ def testee( inp * ones(V2E), axis=V2EDim ) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported - inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) + inp = unstructured_case.as_field( + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() cases.verify( @@ -59,7 +61,9 @@ def test_external_local_field_only(unstructured_case): def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32]: return neighbor_sum(inp, axis=V2EDim) - inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) + inp = unstructured_case.as_field( + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + ) cases.verify( unstructured_case, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 1eba95e880..8bc325d276 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -48,6 +48,7 @@ def test_maxover_execution_(unstructured_case, strategy): if unstructured_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -142,10 +143,7 @@ def conditional_nested_tuple( return where(mask, ((a, b), (b, a)), ((5.0, 7.0), (7.0, 5.0))) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional_nested_tuple, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=size)) a = cases.allocate(cartesian_case, conditional_nested_tuple, "a")() b = cases.allocate(cartesian_case, conditional_nested_tuple, "b")() @@ -216,10 +214,7 @@ def conditional( return where(mask, a, b) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional, "a")() b = cases.allocate(cartesian_case, conditional, "b")() out = cases.allocate(cartesian_case, conditional, cases.RETURN)() @@ -233,10 +228,7 @@ def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases return where(mask, a, 10.0) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional_promotion, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_promotion, "a")() out = cases.allocate(cartesian_case, conditional_promotion, cases.RETURN)() @@ -274,7 +266,7 @@ def conditional_program( conditional_shifted(mask, a, b, out=out) size = cartesian_case.default_sizes[IDim] + 1 - mask = gtx.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_program, "a").extend({IDim: (0, 1)})() b = cases.allocate(cartesian_case, conditional_program, "b").extend({IDim: (0, 1)})() out = cases.allocate(cartesian_case, conditional_shifted, cases.RETURN)() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index a1839b8e17..937b05e087 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -125,9 +125,9 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp else: ref_impl: Callable = getattr(np, builtin_name) - inps = [gtx.as_field([IDim], np.asarray(input)) for input in inputs] + inps = [cartesian_case.as_field([IDim], np.asarray(input)) for input in inputs] expected = ref_impl(*inputs) - out = gtx.as_field([IDim], np.zeros_like(expected)) + out = cartesian_case.as_field([IDim], np.zeros_like(expected)) builtin_field_op = make_builtin_field_operator(builtin_name).with_backend( cartesian_case.backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 59e11a7de8..8660ecfdbd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -72,6 +72,7 @@ def test_floordiv(cartesian_case): gtfn.run_gtfn, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn_gpu, ]: pytest.xfail( "FloorDiv not yet supported." @@ -90,7 +91,7 @@ def test_mod(cartesian_case): def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 - inp1 = gtx.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) + inp1 = cartesian_case.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)() cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) @@ -102,13 +103,8 @@ def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolFie return inp1 ^ inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, binary_xor, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, binary_xor, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, binary_xor, cases.RETURN)() cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1 ^ inp2) @@ -119,13 +115,8 @@ def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: return inp1 & inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, bit_and, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, bit_and, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, bit_and, cases.RETURN)() cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1 & inp2) @@ -136,13 +127,8 @@ def bit_or(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: return inp1 | inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, bit_or, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, bit_or, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, bit_or, cases.RETURN)() cases.verify(cartesian_case, bit_or, inp1, inp2, out=out, ref=inp1 | inp2) @@ -164,10 +150,7 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: return ~inp1 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, tilde_fieldop, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, tilde_fieldop, cases.RETURN)() cases.verify(cartesian_case, tilde_fieldop, inp1, out=out, ref=~inp1) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 545abd2825..b82cae25a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -215,7 +215,7 @@ def prog( def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) - inp = gtx.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) + inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() with pytest.raises(TypeError) as exc_info: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index eaae9a2a3e..cd948ffa02 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -18,8 +18,11 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.program_processors.runners import gtfn, roundtrip +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import Cell, KDim, Koff from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, ) @@ -190,80 +193,97 @@ def reference( @pytest.fixture -def test_setup(): +def test_setup(fieldview_backend): + test_case = cases.Case( + fieldview_backend, + offset_provider={"Koff": KDim}, + default_sizes={Cell: 14, KDim: 10}, + grid_type=common.GridType.UNSTRUCTURED, + ) + @dataclass(frozen=True) class setup: - cell_size = 14 - k_size = 10 - z_alpha = gtx.as_field( + case: cases.Case = test_case + cell_size = case.default_sizes[Cell] + k_size = case.default_sizes[KDim] + z_alpha = case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) - z_beta = gtx.as_field( + z_beta = case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) - w = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) + z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) + w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size))) + dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): - if fieldview_backend in [ +def test_solve_nonhydro_stencil_52_like_z_q(test_setup): + if test_setup.case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: pytest.xfail("Needs implementation of scan projector.") - solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)( + cases.verify( + test_setup.case, + solve_nonhydro_stencil_52_like_z_q, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.z_q_out, - offset_provider={"Koff": KDim}, + ref=test_setup.z_q_ref, + inout=test_setup.z_q_out, + comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." ) - if fieldview_backend == roundtrip.backend: + if test_setup.case.backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)( + cases.verify( + test_setup.case, + solve_nonhydro_stencil_52_like_z_q_tup, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.z_q_out, - offset_provider={"Koff": KDim}, + ref=test_setup.z_q_ref, + inout=test_setup.z_q_out, + comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) - assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) - @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( + + cases.run( + test_setup.case, + solve_nonhydro_stencil_52_like, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.dummy, - offset_provider={"Koff": KDim}, ) assert np.allclose(test_setup.z_q_ref, test_setup.z_q) @@ -271,18 +291,19 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if fieldview_backend == roundtrip.backend: + if test_setup.case.backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)( + cases.run( + test_setup.case, + solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, - offset_provider={"Koff": KDim}, ) assert np.allclose(test_setup.z_q_ref, test_setup.z_q) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 9a1e968de0..4f4d4969a9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -86,5 +86,5 @@ def test_ffront_lap(cartesian_case): in_field, out_field, inout=out_field[2:-2, 2:-2], - ref=lap_ref(lap_ref(np.asarray(in_field.ndarray))), + ref=lap_ref(lap_ref(in_field.array_ns.asarray(in_field.ndarray))), ) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index b43eeb3f91..372062d08a 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -60,6 +60,7 @@ def lift_mode(request): (definitions.ProgramBackendId.GTFN_CPU, True), (definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), (definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), + # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (definitions.ProgramFormatterId.LISP_FORMATTER, False), (definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False), diff --git a/tox.ini b/tox.ini index 5b644e7d97..44dc912c8a 100644 --- a/tox.ini +++ b/tox.ini @@ -84,7 +84,7 @@ commands = nomesh-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and not requires_gpu" {posargs} tests{/}next_tests nomesh-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests atlas-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and not requires_gpu" {posargs} tests{/}next_tests - # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist + # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next [testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] From 6375445b4edd93ab734124325c1adfae42b2bb84 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 21 Nov 2023 12:50:30 +0100 Subject: [PATCH 39/67] feat[next] Embedded field remove __array__ (#1366) Add `.asnumpy` to `Field`. Implicit conversion via `__array__` creates a problem, because expression `np.float*field` will return ndarray instead of field, because `np.float`'s multiply operator will `asarray(rhs)`. Update all tests to do an explicit conversion to ndarray if needed. --------- Co-authored-by: nfarabullini --- src/gt4py/next/common.py | 4 ++ src/gt4py/next/embedded/nd_array_field.py | 6 +-- src/gt4py/next/iterator/embedded.py | 6 +++ src/gt4py/next/utils.py | 40 +++++++++++++++- tests/next_tests/integration_tests/cases.py | 13 +++--- .../ffront_tests/test_arg_call_interface.py | 14 +++--- .../ffront_tests/test_execution.py | 46 +++++++++---------- .../ffront_tests/test_gt4py_builtins.py | 21 ++++++--- .../test_math_builtin_execution.py | 2 +- .../ffront_tests/test_program.py | 10 ++-- .../ffront_tests/test_scalar_if.py | 8 ++-- .../iterator_tests/test_builtins.py | 12 ++--- .../iterator_tests/test_conditional.py | 4 +- .../iterator_tests/test_constant.py | 4 +- .../test_horizontal_indirection.py | 4 +- .../iterator_tests/test_implicit_fencil.py | 6 +-- .../feature_tests/iterator_tests/test_scan.py | 2 +- .../test_strided_offset_provider.py | 4 +- .../iterator_tests/test_trivial.py | 8 ++-- .../iterator_tests/test_tuple.py | 26 +++++------ .../feature_tests/test_util_cases.py | 18 ++++---- .../ffront_tests/test_icon_like_scan.py | 10 ++-- .../iterator_tests/test_anton_toy.py | 2 +- .../iterator_tests/test_column_stencil.py | 19 ++++---- .../iterator_tests/test_fvm_nabla.py | 40 ++++++++-------- .../iterator_tests/test_hdiff.py | 2 +- .../iterator_tests/test_vertical_advection.py | 2 +- .../test_with_toy_connectivity.py | 26 +++++------ .../otf_tests/test_gtfn_workflow.py | 2 +- .../embedded_tests/test_nd_array_field.py | 4 +- 30 files changed, 209 insertions(+), 156 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 66766be76b..51ad14f22d 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -467,6 +467,10 @@ def ndarray(self) -> core_defs.NDArrayObject: def __str__(self) -> str: return f"⟨{self.domain!s} → {self.dtype}⟩" + @abc.abstractmethod + def asnumpy(self) -> np.ndarray: + ... + @abc.abstractmethod def remap(self, index_field: Field) -> Field: ... diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9357570b05..a843772a20 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -120,11 +120,11 @@ def __gt_origin__(self) -> tuple[int, ...]: def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray - def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: + def asnumpy(self) -> np.ndarray: if self.array_ns == cp: - return np.asarray(cp.asnumpy(self._ndarray), dtype) + return cp.asnumpy(self._ndarray) else: - return np.asarray(self._ndarray, dtype) + return np.asarray(self._ndarray) @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 44294a3a71..9000b00d8f 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1068,6 +1068,9 @@ def dtype(self) -> core_defs.Int32DType: def ndarray(self) -> core_defs.NDArrayObject: raise AttributeError("Cannot get `ndarray` of an infinite Field.") + def asnumpy(self) -> np.ndarray: + raise NotImplementedError() + def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1180,6 +1183,9 @@ def dtype(self) -> core_defs.DType[core_defs.ScalarT]: def ndarray(self) -> core_defs.NDArrayObject: raise AttributeError("Cannot get `ndarray` of an infinite Field.") + def asnumpy(self) -> np.ndarray: + raise NotImplementedError() + def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 006b3057b0..baae8361c5 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar, TypeGuard, TypeVar +import functools +from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast + +import numpy as np + +from gt4py.next import common class RecursionGuard: @@ -53,6 +58,39 @@ def __exit__(self, *exc): _T = TypeVar("_T") +_P = ParamSpec("_P") +_R = TypeVar("_R") + def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) + + +def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: + """Apply `fun` to each entry of (possibly nested) tuples. + + Examples: + >>> tree_map(lambda x: x + 1)(((1, 2), 3)) + ((2, 3), 4) + + >>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6)) + ((5, 7), 9) + """ + + @functools.wraps(fun) + def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: + if isinstance(args[0], tuple): + assert all(isinstance(arg, tuple) and len(args[0]) == len(arg) for arg in args) + return tuple(impl(*arg) for arg in zip(*args)) + + return fun( + *cast(_P.args, args) + ) # mypy doesn't understand that `args` at this point is of type `_P.args` + + return impl + + +# TODO(havogt): consider moving to module like `field_utils` +@tree_map +def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: + return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 730ce18fd5..7ef724ee2f 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,7 +28,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common, constructors +from gt4py.next import common, constructors, utils from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -435,14 +435,13 @@ def verify( run(case, fieldview_prog, *args, offset_provider=offset_provider) out_comp = out or inout - out_comp_str = str(out_comp) assert out_comp is not None - if hasattr(out_comp, "ndarray"): - out_comp_str = str(out_comp.ndarray) - assert comparison(ref, out_comp), ( + out_comp_ndarray = utils.asnumpy(out_comp) + ref_ndarray = utils.asnumpy(ref) + assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" - f"\tref = {ref}\n\tout = {out_comp_str}" + f"\tref = {ref_ndarray}\n\tout = {str(out_comp_ndarray)}" ) @@ -468,7 +467,7 @@ def verify_with_default_data( ``comparison(ref, )`` and should return a boolean. """ inps, kwfields = get_default_data(case, fieldop) - ref_args = tuple(i.__array__() if common.is_field(i) else i for i in inps) + ref_args = tuple(i.asnumpy() if common.is_field(i) else i for i in inps) verify( case, fieldop, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 6957e628bb..6293ff76bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -63,9 +63,9 @@ def testee(a: IField, b: IField, c: IField) -> IField: *pos_args, **kw_args, out=out, offset_provider=cartesian_case.offset_provider ) - expected = np.asarray(args["a"]) * 2 * np.asarray(args["b"]) - np.asarray(args["c"]) + expected = args["a"] * 2 * args["b"] - args["c"] - assert np.allclose(out, expected) + assert np.allclose(out.asnumpy(), expected.asnumpy()) @pytest.mark.parametrize("arg_spec", _generate_arg_permutations(("a", "b", "out"))) @@ -89,9 +89,9 @@ def testee(a: IField, b: IField, out: IField): *pos_args, **kw_args, offset_provider=cartesian_case.offset_provider ) - expected = np.asarray(args["a"]) + 2 * np.asarray(args["b"]) + expected = args["a"] + 2 * args["b"] - assert np.allclose(args["out"], expected) + assert np.allclose(args["out"].asnumpy(), expected.asnumpy()) def test_call_field_operator_from_field_operator(cartesian_case): @@ -177,9 +177,7 @@ def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField: a, b, out = ( cases.allocate(cartesian_case, testee, name)() for name in ("a", "b", cases.RETURN) ) - expected = (1.0 + 3.0 + 5.0 + 7.0) * np.add.accumulate( - np.asarray(a) + 2.0 * np.asarray(b), axis=2 - ) + expected = (1.0 + 3.0 + 5.0 + 7.0) * np.add.accumulate(a.asnumpy() + 2.0 * b.asnumpy(), axis=2) cases.verify(cartesian_case, testee, a, b, out=out, ref=expected) @@ -210,7 +208,7 @@ def testee( for name in ("out1", "out2", "out3", "out4") ) - ref = np.add.accumulate(np.asarray(a) + 2 * np.asarray(b), axis=2) + ref = np.add.accumulate(a.asnumpy() + 2 * b.asnumpy(), axis=2) cases.verify( cartesian_case, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index fe18bda9e3..1f3b54d6f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -371,8 +371,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_asint = cartesian_case.as_field([IDim], np.asarray(a).astype(int32)) - b_asint = cartesian_case.as_field([IDim], np.asarray(b).astype(int32)) + a_asint = cartesian_case.as_field([IDim], a.asnumpy().astype(int32)) + b_asint = cartesian_case.as_field([IDim], b.asnumpy().astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -384,7 +384,10 @@ def cast_nested_tuple( a_asint, b_asint, out=out_tuple, - ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), + ref=( + np.full_like(a.asnumpy(), True, dtype=bool), + np.full_like(b.asnumpy(), True, dtype=bool), + ), ) cases.verify( @@ -396,9 +399,9 @@ def cast_nested_tuple( b_asint, out=out_nested_tuple, ref=( - np.full_like(a, True, dtype=bool), - np.full_like(a, True, dtype=bool), - np.full_like(b, True, dtype=bool), + np.full_like(a.asnumpy(), True, dtype=bool), + np.full_like(a.asnumpy(), True, dtype=bool), + np.full_like(b.asnumpy(), True, dtype=bool), ), ) @@ -473,7 +476,7 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD comparison=lambda out, ref: np.all(out == ref), ) - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) def test_nested_tuple_return(cartesian_case): @@ -846,8 +849,8 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = np.asarray(out).copy() # ensure we are not overwriting `out` outside of the domain - ref[1:9] = a[1:9] * 2 + ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain + ref[1:9] = a.asnumpy()[1:9] * 2 cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) @@ -881,8 +884,8 @@ def program_domain( inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() - ref = np.asarray(out).copy() - ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2 + ref = out.asnumpy().copy() + ref[lower_i : int(upper_i / 2)] = inp.asnumpy()[lower_i : int(upper_i / 2)] * 2 cases.verify( cartesian_case, @@ -924,9 +927,9 @@ def program_domain( a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = np.asarray(out).copy() + ref = out.asnumpy().copy() ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = ( - a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 + a.asnumpy()[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 ) cases.verify( @@ -964,10 +967,10 @@ def program_domain_tuple( out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() - ref0 = np.asarray(out0).copy() - ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6] - ref1 = np.asarray(out1).copy() - ref1[1:9, 4:6] = inp1[1:9, 4:6] + ref0 = out0.asnumpy().copy() + ref0[1:9, 4:6] = inp0.asnumpy()[1:9, 4:6] + inp1.asnumpy()[1:9, 4:6] + ref1 = out1.asnumpy().copy() + ref1[1:9, 4:6] = inp1.asnumpy()[1:9, 4:6] cases.verify( cartesian_case, @@ -995,7 +998,7 @@ def fieldop_where_k_offset( )() out = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() - ref = np.where(np.asarray(k_index) > 0, np.roll(inp, 1, axis=1), 2) + ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), 2) cases.verify(cartesian_case, fieldop_where_k_offset, inp, k_index, out=out, ref=ref) @@ -1119,13 +1122,6 @@ def _invalid_unpack() -> tuple[int32, float64, int32]: def test_constant_closure_vars(cartesian_case): - if cartesian_case.backend is None: - # >>> field = gtx.zeros(domain) - # >>> np.int32(1)*field # steals the buffer from the field - # array([0.]) - - # TODO(havogt): remove `__array__`` from `NdArrayField` - pytest.xfail("Bug: Binary operation between np datatype and Field returns ndarray.") from gt4py.eve.utils import FrozenNamespace constants = FrozenNamespace( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 8bc325d276..e2434d860a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -155,8 +155,8 @@ def conditional_nested_tuple( b, out=cases.allocate(cartesian_case, conditional_nested_tuple, cases.RETURN)(), ref=np.where( - mask, - ((a, b), (b, a)), + mask.asnumpy(), + ((a.asnumpy(), b.asnumpy()), (b.asnumpy(), a.asnumpy())), ((np.full(size, 5.0), np.full(size, 7.0)), (np.full(size, 7.0), np.full(size, 5.0))), ), ) @@ -219,7 +219,15 @@ def conditional( b = cases.allocate(cartesian_case, conditional, "b")() out = cases.allocate(cartesian_case, conditional, cases.RETURN)() - cases.verify(cartesian_case, conditional, mask, a, b, out=out, ref=np.where(mask, a, b)) + cases.verify( + cartesian_case, + conditional, + mask, + a, + b, + out=out, + ref=np.where(mask.asnumpy(), a.asnumpy(), b.asnumpy()), + ) def test_conditional_promotion(cartesian_case): @@ -231,10 +239,9 @@ def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_promotion, "a")() out = cases.allocate(cartesian_case, conditional_promotion, cases.RETURN)() + ref = np.where(mask.asnumpy(), a.asnumpy(), 10.0) - cases.verify( - cartesian_case, conditional_promotion, mask, a, out=out, ref=np.where(mask, a, 10.0) - ) + cases.verify(cartesian_case, conditional_promotion, mask, a, out=out, ref=ref) def test_conditional_compareop(cartesian_case): @@ -279,7 +286,7 @@ def conditional_program( b, out, inout=out, - ref=np.where(mask, a, b)[1:], + ref=np.where(mask.asnumpy(), a.asnumpy(), b.asnumpy())[1:], ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 937b05e087..8cfcff160c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -135,4 +135,4 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp builtin_field_op(*inps, out=out, offset_provider={}) - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index b82cae25a8..a0f69f332c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -152,7 +152,7 @@ def prog( cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={}) - assert np.allclose((a, b), (out_a, out_b)) + assert np.allclose((a.asnumpy(), b.asnumpy()), (out_a.asnumpy(), out_b.asnumpy())) def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): @@ -178,7 +178,9 @@ def prog( cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={}) - assert np.allclose((a[1:], b[1:]), (out_a[1:], out_b[1:])) + assert np.allclose( + (a[1:].asnumpy(), b[1:].asnumpy()), (out_a[1:].asnumpy(), out_b[1:].asnumpy()) + ) assert out_a[0] == 0 and out_b[0] == 0 @@ -209,7 +211,9 @@ def prog( cases.run(cartesian_case, prog, a, b, c, out_a, out_b, out_c, offset_provider={}) - assert np.allclose((a, b, c), (out_a, out_b, out_c)) + assert np.allclose( + (a.asnumpy(), b.asnumpy(), c.asnumpy()), (out_a.asnumpy(), out_b.asnumpy(), out_c.asnumpy()) + ) def test_wrong_argument_type(cartesian_case, copy_program_def): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index e9c3ac8d19..84b480a23d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -315,10 +315,10 @@ def if_without_else( out = cases.allocate(cartesian_case, if_without_else, cases.RETURN)() ref = { - (True, True): np.asarray(a) + 2, - (True, False): np.asarray(a), - (False, True): np.asarray(b) + 1, - (False, False): np.asarray(b) + 1, + (True, True): a.asnumpy() + 2, + (True, False): a.asnumpy(), + (False, True): b.asnumpy() + 1, + (False, False): b.asnumpy() + 1, } cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 2bcd0f8367..c0d565bbf4 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -178,7 +178,7 @@ def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, exp fencil(builtin, out, *inps, processor=program_processor, as_column=as_column) if validate: - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) @@ -199,7 +199,7 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): ) # avoid inlining the function fencil(builtin, out, *inps, processor=gtfn_without_transforms) - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) @pytest.mark.parametrize("as_column", [False, True]) @@ -228,7 +228,7 @@ def test_math_function_builtins(program_processor, builtin_name, inputs, as_colu ) if validate: - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) Neighbor = offset("Neighbor") @@ -268,7 +268,7 @@ def test_can_deref(program_processor, stencil): ) if validate: - assert np.allclose(np.asarray(out), -1.0) + assert np.allclose(out.asnumpy(), -1.0) a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) run_processor( @@ -280,7 +280,7 @@ def test_can_deref(program_processor, stencil): ) if validate: - assert np.allclose(np.asarray(out), 1.0) + assert np.allclose(out.asnumpy(), 1.0) # def test_can_deref_lifted(program_processor): @@ -336,7 +336,7 @@ def test_cast(program_processor, as_column, input_value, dtype, np_dtype): def sten_cast(it, casted_valued): return eq(cast_(deref(it), dtype), deref(casted_valued)) - out = field_maker(np.zeros_like(inp, dtype=builtins.bool))[0] + out = field_maker(np.zeros_like(inp.asnumpy(), dtype=builtins.bool))[0] run_processor( sten_cast[{IDim: range(1)}], program_processor, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index de7ebf2869..8536dbea90 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -51,5 +51,5 @@ def test_conditional_w_tuple(program_processor): offset_provider={}, ) if validate: - assert np.all(out.ndarray[np.asarray(inp) == 0] == 3.0) - assert np.all(out.ndarray[np.asarray(inp) == 1] == 7.0) + assert np.all(out.asnumpy()[inp.asnumpy() == 0] == 3.0) + assert np.all(out.asnumpy()[inp.asnumpy() == 1] == 7.0) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py index 83a86319b4..faae549086 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py @@ -31,8 +31,8 @@ def constant_stencil(): # this is traced as a lambda, TODO directly feed iterat return deref(inp) + deref(lift(constant_stencil)()) inp = gtx.as_field([IDim], np.asarray([0, 42], dtype=np.int32)) - res = gtx.as_field([IDim], np.zeros_like(inp)) + res = gtx.as_field([IDim], np.zeros_like(inp.asnumpy())) add_constant[{IDim: range(2)}](inp, out=res, offset_provider={}, backend=roundtrip.executor) - assert np.allclose(res, np.asarray([1, 43])) + assert np.allclose(res.asnumpy(), np.asarray([1, 43])) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index f9bd2cc33b..69f594a2bc 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -82,7 +82,7 @@ def test_simple_indirection(program_processor): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) @fundef @@ -113,4 +113,4 @@ def test_direct_offset_for_indirection(program_processor): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 2df7691f9e..6f600414db 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -53,7 +53,7 @@ def test_single_argument(program_processor, dom): run_processor(copy_stencil[dom], program_processor, inp, out=out, offset_provider={}) if validate: - assert np.allclose(inp, out) + assert np.allclose(inp.asnumpy(), out.asnumpy()) def test_2_arguments(program_processor, dom): @@ -70,7 +70,7 @@ def fun(inp0, inp1): run_processor(fun[dom], program_processor, inp0, inp1, out=out, offset_provider={}) if validate: - assert np.allclose(inp0 + inp1, out) + assert np.allclose(inp0.asnumpy() + inp1.asnumpy(), out.asnumpy()) def test_lambda_domain(program_processor): @@ -82,4 +82,4 @@ def test_lambda_domain(program_processor): run_processor(copy_stencil[dom], program_processor, inp, out=out, offset_provider={}) if validate: - assert np.allclose(inp, out) + assert np.allclose(inp.asnumpy(), out.asnumpy()) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index 3af0440c27..fce1aa3960 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -60,4 +60,4 @@ def wrapped(inp): ) if validate: - assert np.allclose(out[:, :-1], reference) + assert np.allclose(out[:, :-1].asnumpy(), reference) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index abdfffd74e..dd603fa3be 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -63,9 +63,9 @@ def test_strided_offset_provider(program_processor): ), ) out = gtx.as_field([LocA], np.zeros((LocA_size,))) - ref = np.sum(np.asarray(inp).reshape(LocA_size, max_neighbors), axis=-1) + ref = np.sum(inp.asnumpy().reshape(LocA_size, max_neighbors), axis=-1) run_processor(fencil, program_processor, LocA_size, out, inp) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 8c59f994ee..8e12647c1b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -65,7 +65,7 @@ def test_trivial(program_processor, lift_mode): ) if validate: - assert np.allclose(out[:, :, 0], out_s) + assert np.allclose(out[:, :, 0], out_s.asnumpy()) @fundef @@ -100,7 +100,7 @@ def test_shifted_arg_to_lift(program_processor, lift_mode): ) if validate: - assert np.allclose(out, out_s) + assert np.allclose(out, out_s.asnumpy()) @fendef @@ -137,7 +137,7 @@ def test_direct_deref(program_processor, lift_mode): ) if validate: - assert np.allclose(out, out_s) + assert np.allclose(out, out_s.asnumpy()) @fundef @@ -167,4 +167,4 @@ def test_vertical_shift_unstructured(program_processor): ) if validate: - assert np.allclose(inp_s[:, 1:], np.asarray(out_s)[:, :-1]) + assert np.allclose(inp_s[:, 1:].asnumpy(), out_s[:, :-1].asnumpy()) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 97a51508f5..add772e7ef 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -76,8 +76,8 @@ def test_tuple_output(program_processor, stencil): } run_processor(stencil[dom], program_processor, inp1, inp2, out=out, offset_provider={}) if validate: - assert np.allclose(inp1, out[0]) - assert np.allclose(inp2, out[1]) + assert np.allclose(inp1.asnumpy(), out[0].asnumpy()) + assert np.allclose(inp2.asnumpy(), out[1].asnumpy()) @fundef @@ -144,10 +144,10 @@ def stencil(inp1, inp2, inp3, inp4): offset_provider={}, ) if validate: - assert np.allclose(inp1, out[0][0]) - assert np.allclose(inp2, out[0][1]) - assert np.allclose(inp3, out[1][0]) - assert np.allclose(inp4, out[1][1]) + assert np.allclose(inp1.asnumpy(), out[0][0].asnumpy()) + assert np.allclose(inp2.asnumpy(), out[0][1].asnumpy()) + assert np.allclose(inp3.asnumpy(), out[1][0].asnumpy()) + assert np.allclose(inp4.asnumpy(), out[1][1].asnumpy()) @pytest.mark.parametrize( @@ -197,8 +197,8 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): offset_provider={}, ) if validate: - assert np.allclose(inp1, out1) - assert np.allclose(inp2, out2) + assert np.allclose(inp1.asnumpy(), out1.asnumpy()) + assert np.allclose(inp2.asnumpy(), out2.asnumpy()) def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): @@ -255,9 +255,9 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): offset_provider={}, ) if validate: - assert np.allclose(inp1, out1) - assert np.allclose(inp2, out2) - assert np.allclose(inp3, out3) + assert np.allclose(inp1.asnumpy(), out1.asnumpy()) + assert np.allclose(inp2.asnumpy(), out2.asnumpy()) + assert np.allclose(inp3.asnumpy(), out3.asnumpy()) @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") @@ -323,7 +323,7 @@ def test_tuple_field_input(program_processor): } run_processor(tuple_input[dom], program_processor, (inp1, inp2), out=out, offset_provider={}) if validate: - assert np.allclose(np.asarray(inp1) + np.asarray(inp2), out) + assert np.allclose(inp1.asnumpy() + inp2.asnumpy(), out.asnumpy()) @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") @@ -389,7 +389,7 @@ def test_tuple_of_tuple_of_field_input(program_processor): ) if validate: assert np.allclose( - (np.asarray(inp1) + np.asarray(inp2) + np.asarray(inp3) + np.asarray(inp4)), out + (inp1.asnumpy() + inp2.asnumpy() + inp3.asnumpy() + inp4.asnumpy()), out.asnumpy() ) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 3f229ef389..579dec11f8 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -41,30 +41,30 @@ def mixed_args( def test_allocate_default_unique(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, mixed_args, "a")() - assert np.min(a) == 0 - assert np.max(a) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 + assert np.min(a.asnumpy()) == 0 + assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 b = cases.allocate(cartesian_case, mixed_args, "b")() - assert b == np.max(a) + 1 + assert b == np.max(a.asnumpy()) + 1 c = cases.allocate(cartesian_case, mixed_args, "c")() - assert np.min(c) == b + 1 - assert np.max(c) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + assert np.min(c.asnumpy()) == b + 1 + assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 def test_allocate_return_default_zeros(cartesian_case): # noqa: F811 # fixtures a, (b, c) = cases.allocate(cartesian_case, mixed_args, cases.RETURN)() - assert np.all(np.asarray(a) == 0) - assert np.all(np.asarray(a) == b) - assert np.all(np.asarray(b) == c) + assert np.all(a.asnumpy() == 0) + assert np.all(b.asnumpy() == 0) + assert np.all(c.asnumpy() == 0) def test_allocate_const(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, mixed_args, "a").strategy(cases.ConstInitializer(42))() - assert np.all(np.asarray(a) == 42) + assert np.all(a.asnumpy() == 42) b = cases.allocate(cartesian_case, mixed_args, "b").strategy(cases.ConstInitializer(42))() assert b == 42.0 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index cd948ffa02..8b4cedd98b 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -244,7 +244,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) - assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) + assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:].asnumpy()) @pytest.mark.uses_tuple_returns @@ -286,8 +286,8 @@ def test_solve_nonhydro_stencil_52_like(test_setup): test_setup.dummy, ) - assert np.allclose(test_setup.z_q_ref, test_setup.z_q) - assert np.allclose(test_setup.w_ref, test_setup.w) + assert np.allclose(test_setup.z_q_ref, test_setup.z_q.asnumpy()) + assert np.allclose(test_setup.w_ref, test_setup.w.asnumpy()) @pytest.mark.uses_tuple_returns @@ -306,5 +306,5 @@ def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): test_setup.w, ) - assert np.allclose(test_setup.z_q_ref, test_setup.z_q) - assert np.allclose(test_setup.w_ref, test_setup.w) + assert np.allclose(test_setup.z_q_ref, test_setup.z_q.asnumpy()) + assert np.allclose(test_setup.w_ref, test_setup.w.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 829bc497cb..806ab7eb9a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -103,4 +103,4 @@ def test_anton_toy(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index d05b14d73d..fd571514ac 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -16,6 +16,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import utils from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset @@ -89,7 +90,7 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ) out = gtx.as_field([IDim, KDim], np.zeros(shape)) - ref = ref_fun(inp) + ref = ref_fun(inp.asnumpy()) run_processor( stencil[{IDim: range(0, shape[0]), KDim: range(0, shape[1])}], @@ -102,7 +103,7 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) @fundef @@ -157,7 +158,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct k_size = 5 inp = inp_function(k_size) - ref = ref_function(inp) + ref = ref_function(utils.asnumpy(inp)) out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32)) @@ -173,7 +174,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct ) if validate: - np.allclose(ref, out) + np.allclose(ref, out.asnumpy()) @fundef @@ -222,7 +223,7 @@ def test_ksum_scan(program_processor, lift_mode, kstart, reference): ) if validate: - assert np.allclose(reference, np.asarray(out)) + assert np.allclose(reference, out.asnumpy()) @fundef @@ -260,7 +261,7 @@ def test_ksum_back_scan(program_processor, lift_mode): ) if validate: - assert np.allclose(ref, np.asarray(out)) + assert np.allclose(ref, out.asnumpy()) @fundef @@ -366,7 +367,7 @@ def test_different_vertical_sizes(program_processor): ) if validate: - assert np.allclose(ref[1:], out[1:]) + assert np.allclose(ref[1:], out.asnumpy()[1:]) @fundef @@ -392,7 +393,7 @@ def test_different_vertical_sizes_with_origin(program_processor): inp0 = gtx.as_field([KDim], np.arange(0, k_size)) inp1 = gtx.as_field([KDim], np.arange(0, k_size + 1), origin={KDim: 1}) out = gtx.as_field([KDim], np.zeros(k_size, dtype=np.int64)) - ref = np.asarray(inp0) + np.asarray(inp1)[:-1] + ref = inp0.asnumpy() + inp1.asnumpy()[:-1] run_processor( sum_fencil, @@ -405,7 +406,7 @@ def test_different_vertical_sizes_with_origin(program_processor): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) # TODO(havogt) test tuple_get builtin on a Column diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 47867b9a64..e1d959aba9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -159,8 +159,8 @@ def test_compute_zavgS(program_processor, lift_mode): ) if validate: - assert_close(-199755464.25741270, np.min(zavgS)) - assert_close(388241977.58389181, np.max(zavgS)) + assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) + assert_close(388241977.58389181, np.max(zavgS.asnumpy())) run_processor( compute_zavgS_fencil, @@ -173,8 +173,8 @@ def test_compute_zavgS(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert_close(-1000788897.3202186, np.min(zavgS)) - assert_close(1000788897.3202186, np.max(zavgS)) + assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) + assert_close(1000788897.3202186, np.max(zavgS.asnumpy())) @fendef @@ -222,11 +222,11 @@ def test_compute_zavgS2(program_processor, lift_mode): ) if validate: - assert_close(-199755464.25741270, np.min(zavgS[0])) - assert_close(388241977.58389181, np.max(zavgS[0])) + assert_close(-199755464.25741270, np.min(zavgS[0].asnumpy())) + assert_close(388241977.58389181, np.max(zavgS[0].asnumpy())) - assert_close(-1000788897.3202186, np.min(zavgS[1])) - assert_close(1000788897.3202186, np.max(zavgS[1])) + assert_close(-1000788897.3202186, np.min(zavgS[1].asnumpy())) + assert_close(1000788897.3202186, np.max(zavgS[1].asnumpy())) @pytest.mark.requires_atlas @@ -266,10 +266,10 @@ def test_nabla(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY.asnumpy())) @fendef @@ -322,10 +322,10 @@ def test_nabla2(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY.asnumpy())) @fundef @@ -407,7 +407,7 @@ def test_nabla_sign(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY.asnumpy())) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 8aabd18267..9bba1ab89c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -97,4 +97,4 @@ def test_hdiff(hdiff_reference, program_processor, lift_mode): ) if validate: - assert np.allclose(out[:, :, 0], out_s) + assert np.allclose(out[:, :, 0], out_s.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 29c82442ea..f2a6505a7e 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -158,4 +158,4 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): ) if validate: - assert np.allclose(x, x_s) + assert np.allclose(x, x_s.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6354e45451..000d3c4822 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -99,7 +99,7 @@ def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -122,7 +122,7 @@ def test_map_neighbors(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -146,7 +146,7 @@ def test_map_make_const_list(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -172,7 +172,7 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -200,7 +200,7 @@ def test_sparse_input_field(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) def test_sparse_input_field_v2v(program_processor, lift_mode): @@ -226,7 +226,7 @@ def test_sparse_input_field_v2v(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -254,7 +254,7 @@ def test_slice_sparse(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -309,7 +309,7 @@ def test_shift_sliced_sparse(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -337,7 +337,7 @@ def test_slice_shifted_sparse(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -365,7 +365,7 @@ def test_lift(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -390,7 +390,7 @@ def test_shift_sparse_input_field(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -443,7 +443,7 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): ) if validate: - assert np.allclose(out1, out2) + assert np.allclose(out1.asnumpy(), out2.asnumpy()) @fundef @@ -484,4 +484,4 @@ def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): ) if validate: - assert np.allclose(np.asarray(out), ref) + assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py index d851c5560a..c91be04999 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py @@ -41,4 +41,4 @@ def copy(inp: gtx.Field[[IDim, JDim], gtx.int32]) -> gtx.Field[[IDim, JDim], gtx copy(inp, out=out, offset_provider={}) - assert np.allclose(inp[:out_nx, :out_ny], out) + assert np.allclose(inp[:out_nx, :out_ny].asnumpy(), out.asnumpy()) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 00dbf68274..436e672cc5 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -555,8 +555,8 @@ def test_setitem(index, value): domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) - expected = np.copy(field.ndarray) - expected[index] = value + expected = np.copy(field.asnumpy()) + expected[index] = value.asnumpy() if common.is_field(value) else value field[index] = value From 67e5270729a5951f3902942052afa423d778b965 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 21 Nov 2023 22:17:54 +0100 Subject: [PATCH 40/67] feature[next]: remap and connectivity implementations for embedded (#1309) Adds - ConnectivityField protocol - NdArrayFieldConnectivity for unstructured remap, CartesianConnectivity for Cartesian remap - implements neighbor_sum, max_over, min_over for field TODOs for next PR: - support for remap with connectivities with has_skip_values=True --------- Co-authored-by: Enrique Gonzalez Paredes Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/_core/definitions.py | 27 ++ src/gt4py/next/__init__.py | 9 +- src/gt4py/next/common.py | 297 +++++++++++++++++- src/gt4py/next/constructors.py | 70 ++++- src/gt4py/next/embedded/__init__.py | 10 + src/gt4py/next/embedded/common.py | 3 +- src/gt4py/next/embedded/context.py | 64 ++++ src/gt4py/next/embedded/nd_array_field.py | 283 +++++++++++++++-- src/gt4py/next/ffront/decorator.py | 25 +- src/gt4py/next/ffront/fbuiltins.py | 82 +++-- src/gt4py/next/iterator/embedded.py | 25 +- tests/next_tests/exclusion_matrices.py | 2 - tests/next_tests/integration_tests/cases.py | 2 +- .../ffront_tests/ffront_test_utils.py | 2 +- .../ffront_tests/test_execution.py | 6 +- .../ffront_tests/test_icon_like_scan.py | 2 +- .../embedded_tests/test_basic_program.py | 47 +++ .../embedded_tests/test_nd_array_field.py | 66 +++- tests/next_tests/unit_tests/test_common.py | 129 ++++++++ 19 files changed, 1057 insertions(+), 94 deletions(-) create mode 100644 src/gt4py/next/embedded/context.py create mode 100644 tests/next_tests/unit_tests/embedded_tests/test_basic_program.py diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 79543a1849..0e6301ae0f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -490,3 +490,30 @@ def __rtruediv__(self, other: Any) -> NDArrayObject: def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + + def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + ... + + def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + ... + + def __gt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __ge__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __lt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __le__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... + + def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... + + def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 696c4f174c..cbd5735949 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -24,8 +24,8 @@ """ from . import common, ffront, iterator, program_processors, type_inference -from .common import Dimension, DimensionKind, Field, GridType -from .constructors import as_field, empty, full, ones, zeros +from .common import Dimension, DimensionKind, Domain, Field, GridType, UnitRange, domain, unit_range +from .constructors import as_connectivity, as_field, empty, full, ones, zeros from .embedded import ( # Just for registering field implementations nd_array_field as _nd_array_field, ) @@ -53,12 +53,17 @@ "DimensionKind", "Field", "GridType", + "domain", + "Domain", + "unit_range", + "UnitRange", # from constructors "empty", "zeros", "ones", "full", "as_field", + "as_connectivity", # from iterator "NeighborTableOffsetProvider", "StridedNeighborOffsetProvider", diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 51ad14f22d..7f1ad8c0bb 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -19,10 +19,10 @@ import dataclasses import enum import functools +import numbers import sys import types from collections.abc import Mapping, Sequence, Set -from typing import overload import numpy as np import numpy.typing as npt @@ -33,6 +33,7 @@ Any, Callable, ClassVar, + Never, Optional, ParamSpec, Protocol, @@ -41,14 +42,14 @@ TypeVar, cast, extended_runtime_checkable, + overload, runtime_checkable, ) from gt4py.eve.type_definitions import StrEnum -DimsT = TypeVar( - "DimsT", covariant=True -) # bound to `Sequence[Dimension]` if instance of Dimension would be a type +DimT = TypeVar("DimT", bound="Dimension") # , covariant=True) +DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) class Infinity(int): @@ -61,6 +62,9 @@ def negative(cls) -> Infinity: return cls(-sys.maxsize) +Tag: TypeAlias = str + + @enum.unique class DimensionKind(StrEnum): HORIZONTAL = "horizontal" @@ -96,6 +100,7 @@ def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScal object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) + # TODO: the whole infinity idea and implementation is broken and should be replaced @classmethod def infinity(cls) -> UnitRange: return cls(Infinity.negative(), Infinity.positive()) @@ -113,10 +118,10 @@ def __getitem__(self, index: int) -> int: ... @overload - def __getitem__(self, index: slice) -> UnitRange: + def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unused ... - def __getitem__(self, index: int | slice) -> int | UnitRange: + def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: @@ -149,6 +154,32 @@ def __le__(self, other: Set[int]): else: return Set.__le__(self, other) + def __add__(self, other: int | Set[int]) -> UnitRange: + if isinstance(other, int): + if other == Infinity.positive(): + return UnitRange.infinity() + elif other == Infinity.negative(): + return UnitRange(0, 0) + return UnitRange( + *( + s if s in [Infinity.negative(), Infinity.positive()] else s + other + for s in (self.start, self.stop) + ) + ) + else: + raise NotImplementedError("Can only compute union with int instances.") + + def __sub__(self, other: int | Set[int]) -> UnitRange: + if isinstance(other, int): + if other == Infinity.negative(): + return self + Infinity.positive() + elif other == Infinity.positive(): + return self + Infinity.negative() + else: + return self + (-other) + else: + raise NotImplementedError("Can only compute substraction with int instances.") + __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented def __str__(self) -> str: @@ -184,8 +215,8 @@ def unit_range(r: RangeLike) -> UnitRange: IntIndex: TypeAlias = int | core_defs.IntegralScalar -NamedIndex: TypeAlias = tuple[Dimension, IntIndex] -NamedRange: TypeAlias = tuple[Dimension, UnitRange] +NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple +NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement @@ -260,8 +291,8 @@ class Domain(Sequence[NamedRange]): def __init__( self, *args: NamedRange, - dims: Optional[tuple[Dimension, ...]] = None, - ranges: Optional[tuple[UnitRange, ...]] = None, + dims: Optional[Sequence[Dimension]] = None, + ranges: Optional[Sequence[UnitRange]] = None, ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: @@ -285,8 +316,8 @@ def __init__( f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." ) - object.__setattr__(self, "dims", dims) - object.__setattr__(self, "ranges", ranges) + object.__setattr__(self, "dims", tuple(dims)) + object.__setattr__(self, "ranges", tuple(ranges)) else: if not all(is_named_range(arg) for arg in args): raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") @@ -300,6 +331,10 @@ def __init__( def __len__(self) -> int: return len(self.ranges) + @property + def ndim(self) -> int: + return len(self.dims) + @property def shape(self) -> tuple[int, ...]: return tuple(len(r) for r in self.ranges) @@ -309,14 +344,16 @@ def __getitem__(self, index: int) -> NamedRange: ... @overload - def __getitem__(self, index: slice) -> Domain: + def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused ... @overload - def __getitem__(self, index: Dimension) -> NamedRange: + def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused ... - def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: + def __getitem__( # noqa: F811 # redefine unused + self, index: int | slice | Dimension + ) -> NamedRange | Domain: # noqa: F811 # redefine unused if isinstance(index, int): return self.dims[index], self.ranges[index] elif isinstance(index, slice): @@ -360,6 +397,36 @@ def __and__(self, other: Domain) -> Domain: def __str__(self) -> str: return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + def dim_index(self, dim: Dimension) -> Optional[int]: + return self.dims.index(dim) if dim in self.dims else None + + def pop(self, index: int | Dimension = -1) -> Domain: + return self.replace(index) + + def insert(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: + if isinstance(index, int) and index == len(self.dims): + new_dims, new_ranges = zip(*named_ranges) + return Domain(dims=self.dims + new_dims, ranges=self.ranges + new_ranges) + else: + return self.replace(index, *named_ranges) + + def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: + assert all(is_named_range(nr) for nr in named_ranges) + if isinstance(index, Dimension): + dim_index = self.dim_index(index) + if dim_index is None: + raise ValueError(f"Dimension {index} not found in Domain.") + index = dim_index + if not (-len(self.dims) <= index < len(self.dims)): + raise IndexError(f"Index {index} out of bounds for Domain of length {len(self.dims)}.") + if index < 0: + index += len(self.dims) + new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) + dims = self.dims[:index] + new_dims + self.dims[index + 1 :] + ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :] + + return Domain(dims=dims, ranges=ranges) + DomainLike: TypeAlias = ( Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] @@ -456,6 +523,10 @@ class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, co def domain(self) -> Domain: ... + @property + def codomain(self) -> type[core_defs.ScalarT] | Dimension: + ... + @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... @@ -472,7 +543,7 @@ def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def remap(self, index_field: Field) -> Field: + def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod @@ -481,7 +552,7 @@ def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: # Operators @abc.abstractmethod - def __call__(self, index_field: Field) -> Field: + def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod @@ -592,6 +663,100 @@ def is_mutable_field( return isinstance(v, MutableField) # type: ignore[misc] # we use extended_runtime_checkable +class ConnectivityKind(enum.Flag): + MODIFY_DIMS = enum.auto() + MODIFY_RANK = enum.auto() + MODIFY_STRUCTURE = enum.auto() + + +@extended_runtime_checkable +class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): # type: ignore[misc] # DimT should be covariant, but break in another place + @property + @abc.abstractmethod + def codomain(self) -> DimT: + ... + + @property + def kind(self) -> ConnectivityKind: + return ( + ConnectivityKind.MODIFY_DIMS + | ConnectivityKind.MODIFY_RANK + | ConnectivityKind.MODIFY_STRUCTURE + ) + + @abc.abstractmethod + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: + ... + + # Operators + def __abs__(self) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __neg__(self) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __invert__(self) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __eq__(self, other: Any) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __ne__(self, other: Any) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + +def is_connectivity_field( + v: Any, +) -> TypeGuard[ConnectivityField]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed + return isinstance(v, ConnectivityField) # type: ignore[misc] # we use extended_runtime_checkable + + @functools.singledispatch def field( definition: Any, @@ -603,6 +768,18 @@ def field( raise NotImplementedError +@functools.singledispatch +def connectivity( + definition: Any, + /, + codomain: Dimension, + *, + domain: Optional[DomainLike] = None, + dtype: Optional[core_defs.DType] = None, +) -> ConnectivityField: + raise NotImplementedError + + @dataclasses.dataclass(frozen=True) class GTInfo: definition: Any @@ -638,6 +815,92 @@ class NeighborTable(Connectivity, Protocol): table: npt.NDArray +OffsetProviderElem: TypeAlias = Dimension | Connectivity +OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] + + +@dataclasses.dataclass(frozen=True, eq=False) +class CartesianConnectivity(ConnectivityField[DimsT, DimT]): + dimension: DimT + offset: int = 0 + + @classmethod + def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override] + raise NotImplementedError() + + @property + def ndarray(self) -> Never: + raise NotImplementedError() + + def asnumpy(self) -> Never: + raise NotImplementedError() + + @functools.cached_property + def domain(self) -> Domain: + return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),)) + + @property + def __gt_dims__(self) -> tuple[Dimension, ...]: + return self.domain.dims + + @property + def __gt_origin__(self) -> Never: + raise TypeError("CartesianConnectivity does not support this operation") + + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + + @functools.cached_property + def codomain(self) -> DimT: + return self.dimension + + @functools.cached_property + def kind(self) -> ConnectivityKind: + return ConnectivityKind(0) + + @classmethod + def from_offset( + cls, + definition: int, + /, + codomain: DimT, + *, + domain: Optional[DomainLike] = None, + dtype: Optional[core_defs.DTypeLike] = None, + ) -> CartesianConnectivity: + assert domain is None + assert dtype is None + return cls(codomain, definition) + + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: + if not isinstance(image_range, UnitRange): + if image_range[0] != self.codomain: + raise ValueError( + f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + ) + + image_range = image_range[1] + + assert isinstance(image_range, UnitRange) + return ((self.codomain, image_range - self.offset),) + + def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField: + raise NotImplementedError() + + __call__ = remap + + def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar: + if is_int_index(index): + return index + self.offset + raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case + + __getitem__ = restrict + + +connectivity.register(numbers.Integral, CartesianConnectivity.from_offset) + + @enum.unique class GridType(StrEnum): CARTESIAN = "cartesian" diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 42b0bcda90..63fde1cfde 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -212,7 +212,7 @@ def as_field( This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. See :func:`empty` for further details about the meaning of the extra keyword arguments. - Parameters: + Arguments: domain: Definition of the domain of the field (and consequently of the shape of the allocated field buffer). In addition to the values allowed in `empty`, it can also just be a sequence of dimensions, in which case the sizes of each dimension will then be taken from the shape of `data`. @@ -283,7 +283,7 @@ def as_field( dtype = core_defs.dtype(dtype) assert dtype.tensor_shape == () # TODO - if allocator is device is None and xtyping.supports_dlpack(data): + if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) field = empty( @@ -297,3 +297,69 @@ def as_field( field[...] = field.array_ns.asarray(data) return field + + +@eve.utils.with_fluid_partial +def as_connectivity( + domain: common.DomainLike | Sequence[common.Dimension], + codomain: common.Dimension, + data: core_defs.NDArrayObject, + dtype: Optional[core_defs.DType] = None, + *, + allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, + device: Optional[core_defs.Device] = None, + # copy=False, TODO +) -> common.ConnectivityField: + """ + Construct a connectivity field from the given domain, codomain, and data. + + Arguments: + domain: The domain of the connectivity field. It can be either a `common.DomainLike` object or a + sequence of `common.Dimension` objects. + codomain: The codomain dimension of the connectivity field. + data: The data used to construct the connectivity field. + dtype: The data type of the connectivity field. If not provided, it will be inferred from the data. + allocator: The allocator used to allocate the buffer for the connectivity field. If not provided, + a default allocator will be used. + device: The device on which the connectivity field will be allocated. If not provided, the default + device will be used. + + Returns: + The constructed connectivity field. + + Raises: + ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. + """ + if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): + domain = cast(Sequence[common.Dimension], domain) + if len(domain) != data.ndim: + raise ValueError( + f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + ) + actual_domain = common.domain([(d, (0, s)) for d, s in zip(domain, data.shape)]) + else: + actual_domain = common.domain(cast(common.DomainLike, domain)) + + if not isinstance(codomain, common.Dimension): + raise ValueError(f"Invalid codomain dimension `{codomain}`") + + # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has + # already the correct layout and device. + shape = storage_utils.asarray(data).shape + if shape != actual_domain.shape: + raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + if dtype is None: + dtype = storage_utils.asarray(data).dtype + dtype = core_defs.dtype(dtype) + assert dtype.tensor_shape == () # TODO + + if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): + device = core_defs.Device(*data.__dlpack_device__()) + buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) + buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] # TODO(havogt): consider addin MutableNDArrayObject + connectivity_field = common.connectivity( + buffer.ndarray, codomain=codomain, domain=actual_domain + ) + assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField) + + return connectivity_field diff --git a/src/gt4py/next/embedded/__init__.py b/src/gt4py/next/embedded/__init__.py index 6c43e2f12a..e0cb114148 100644 --- a/src/gt4py/next/embedded/__init__.py +++ b/src/gt4py/next/embedded/__init__.py @@ -11,3 +11,13 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + +from . import common, context, exceptions, nd_array_field + + +__all__ = [ + "common", + "context", + "exceptions", + "nd_array_field", +] diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 37ba4954f3..d796189ab3 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -12,8 +12,9 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Optional, Sequence, cast +from __future__ import annotations +from gt4py.eve.extended_typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py new file mode 100644 index 0000000000..5fbdbc6f25 --- /dev/null +++ b/src/gt4py/next/embedded/context.py @@ -0,0 +1,64 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +import contextlib +import contextvars as cvars +from typing import Any + +import gt4py.eve as eve +import gt4py.next.common as common + + +#: Column range used in column mode (`column_axis != None`) in the current embedded iterator +#: closure execution context. +closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range") + +_undefined_offset_provider: common.OffsetProvider = {} + +#: Offset provider dict in the current embedded execution context. +offset_provider: cvars.ContextVar[common.OffsetProvider] = cvars.ContextVar( + "offset_provider", default=_undefined_offset_provider +) + + +@contextlib.contextmanager +def new_context( + *, + closure_column_range: range | eve.NothingType = eve.NOTHING, + offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING, +): + import gt4py.next.embedded.context as this_module + + updates: list[tuple[cvars.ContextVar[Any], Any]] = [] + if closure_column_range is not eve.NOTHING: + updates.append((this_module.closure_column_range, closure_column_range)) + if offset_provider is not eve.NOTHING: + updates.append((this_module.offset_provider, offset_provider)) + + # Create new context with provided values + ctx = cvars.copy_context() + + def ctx_updater(*args): + for cvar, value in args: + cvar.set(value) + + ctx.run(ctx_updater, *updates) + + yield ctx + + +def within_context() -> bool: + return offset_provider.get() is not _undefined_offset_provider diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index a843772a20..ff6a2ceac7 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -19,12 +19,13 @@ import operator from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar +from typing import ClassVar import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs +from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common from gt4py.next.embedded import common as embedded_common from gt4py.next.ffront import fbuiltins @@ -126,6 +127,10 @@ def asnumpy(self) -> np.ndarray: else: return np.asarray(self._ndarray) + @property + def codomain(self) -> type[core_defs.ScalarT]: + return self.dtype.scalar_type + @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: return core_defs.dtype(self._ndarray.dtype.type) @@ -153,12 +158,53 @@ def from_array( assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim - assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape)) + assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) return cls(domain, array) - def remap(self: NdArrayField, connectivity) -> NdArrayField: - raise NotImplementedError() + def remap( + self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset + ) -> NdArrayField: + # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField + if not common.is_connectivity_field(connectivity): + assert isinstance(connectivity, fbuiltins.FieldOffset) + connectivity = connectivity.as_connectivity_field() + + assert common.is_connectivity_field(connectivity) + + # Compute the new domain + dim = connectivity.codomain + dim_idx = self.domain.dim_index(dim) + if dim_idx is None: + raise ValueError(f"Incompatible index field, expected a field with dimension {dim}.") + + current_range: common.UnitRange = self.domain[dim_idx][1] + new_ranges = connectivity.inverse_image(current_range) + new_domain = self.domain.replace(dim_idx, *new_ranges) + + # perform contramap + if not (connectivity.kind & common.ConnectivityKind.MODIFY_STRUCTURE): + # shortcut for compact remap: don't change the array, only the domain + new_buffer = self._ndarray + else: + # general case: first restrict the connectivity to the new domain + restricted_connectivity_domain = common.Domain(*new_ranges) + restricted_connectivity = ( + connectivity.restrict(restricted_connectivity_domain) + if restricted_connectivity_domain != connectivity.domain + else connectivity + ) + assert common.is_connectivity_field(restricted_connectivity) + + # then compute the index array + xp = self.array_ns + new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start + # finally, take the new array + new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx) + + return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype) + + __call__ = remap # type: ignore[assignment] def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: new_domain, buffer_slice = self._slice(index) @@ -172,7 +218,22 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __getitem__ = restrict - __call__ = None # type: ignore[assignment] # TODO: remap + def __setitem__( + self: NdArrayField[common.DimsT, core_defs.ScalarT], + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, + ) -> None: + target_domain, target_slice = self._slice(index) + + if common.is_field(value): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + ) + value = value.ndarray + + assert hasattr(self.ndarray, "__setitem__") + self._ndarray[target_slice] = value # type: ignore[index] # np and cp allow index assignment, jax overrides __abs__ = _make_builtin("abs", "abs") @@ -194,9 +255,17 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __mod__ = __rmod__ = _make_builtin("mod", "mod") - __ne__ = _make_builtin("not_equal", "not_equal") # type: ignore[assignment] # mypy wants return `bool` + __ne__ = _make_builtin("not_equal", "not_equal") # type: ignore # mypy wants return `bool` + + __eq__ = _make_builtin("equal", "equal") # type: ignore # mypy wants return `bool` + + __gt__ = _make_builtin("greater", "greater") + + __ge__ = _make_builtin("greater_equal", "greater_equal") + + __lt__ = _make_builtin("less", "less") - __eq__ = _make_builtin("equal", "equal") # type: ignore[assignment] # mypy wants return `bool` + __le__ = _make_builtin("less_equal", "less_equal") def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): @@ -239,6 +308,144 @@ def _slice( return new_domain, slice_ +@dataclasses.dataclass(frozen=True) +class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ + common.ConnectivityField[common.DimsT, common.DimT], + NdArrayField[common.DimsT, core_defs.IntegralScalar], +): + _codomain: common.DimT + + @functools.cached_property + def _cache(self) -> dict: + return {} + + @classmethod + def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override] + raise NotImplementedError() + + @property + def codomain(self) -> common.DimT: # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + return self._codomain + + @functools.cached_property + def kind(self) -> common.ConnectivityKind: + kind = common.ConnectivityKind.MODIFY_STRUCTURE + if self.domain.ndim > 1: + kind |= common.ConnectivityKind.MODIFY_RANK + kind |= common.ConnectivityKind.MODIFY_DIMS + if self.domain.dim_index(self.codomain) is None: + kind |= common.ConnectivityKind.MODIFY_DIMS + + return kind + + @classmethod + def from_array( # type: ignore[override] + cls, + data: npt.ArrayLike | core_defs.NDArrayObject, + /, + codomain: common.DimT, + *, + domain: common.DomainLike, + dtype: Optional[core_defs.DTypeLike] = None, + ) -> NdArrayConnectivityField: + domain = common.domain(domain) + xp = cls.array_ns + + xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) + array = xp.asarray(data, dtype=xp_dtype) + + if dtype is not None: + assert array.dtype.type == core_defs.dtype(dtype).scalar_type + + assert issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES) + + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain + assert len(domain) == array.ndim + assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape)) + + assert isinstance(codomain, common.Dimension) + + return cls(domain, array, codomain) + + def inverse_image( + self, image_range: common.UnitRange | common.NamedRange + ) -> Sequence[common.NamedRange]: + cache_key = hash((id(self.ndarray), self.domain, image_range)) + + if (new_dims := self._cache.get(cache_key, None)) is None: + xp = self.array_ns + + if not isinstance( + image_range, common.UnitRange + ): # TODO(havogt): cleanup duplication with CartesianConnectivity + if image_range[0] != self.codomain: + raise ValueError( + f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + ) + + image_range = image_range[1] + + assert isinstance(image_range, common.UnitRange) + + restricted_mask = (self._ndarray >= image_range.start) & ( + self._ndarray < image_range.stop + ) + # indices of non-zero elements in each dimension + nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask) + + new_dims = [] + non_contiguous_dims = [] + + for i, dim_nnz_indices in enumerate(nnz): + # Check if the indices are contiguous + first_data_index = dim_nnz_indices[0] + assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES) + last_data_index = dim_nnz_indices[-1] + assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES) + indices, counts = xp.unique(dim_nnz_indices, return_counts=True) + if len(xp.unique(counts)) == 1 and ( + len(indices) == last_data_index - first_data_index + 1 + ): + dim_range = self._domain[i] + idx_offset = dim_range[1].start + start = idx_offset + first_data_index + assert common.is_int_index(start) + stop = idx_offset + last_data_index + 1 + assert common.is_int_index(stop) + new_dims.append( + common.named_range( + ( + dim_range[0], + (start, stop), + ) + ) + ) + else: + non_contiguous_dims.append(dim_range[0]) + + if non_contiguous_dims: + raise ValueError( + f"Restriction generates non-contiguous dimensions {non_contiguous_dims}" + ) + + return new_dims + + def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar: + cache_key = (id(self.ndarray), self.domain, index) + + if (restricted_connectivity := self._cache.get(cache_key, None)) is None: + cls = self.__class__ + xp = cls.array_ns + new_domain, buffer_slice = self._slice(index) + new_buffer = xp.asarray(self.ndarray[buffer_slice]) + restricted_connectivity = cls(new_domain, new_buffer, self.codomain) + self._cache[cache_key] = restricted_connectivity + + return restricted_connectivity + + __getitem__ = restrict + + # -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] @@ -266,22 +473,30 @@ def _slice( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _np_cp_setitem( - self: NdArrayField[common.DimsT, core_defs.ScalarT], - index: common.AnyIndexSpec, - value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, -) -> None: - target_domain, target_slice = self._slice(index) +def _make_reduction( + builtin_name: str, array_builtin_name: str +) -> Callable[..., NdArrayField[common.DimsT, core_defs.ScalarT],]: + def _builtin_op( + field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension + ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: + if not axis.kind == common.DimensionKind.LOCAL: + raise ValueError("Can only reduce local dimensions.") + if axis not in field.domain.dims: + raise ValueError(f"Field doesn't have dimension {axis}. Cannot reduce.") + reduce_dim_index = field.domain.dims.index(axis) + new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) + return field.__class__.from_array( + getattr(field.array_ns, array_builtin_name)(field.ndarray, axis=reduce_dim_index), + domain=new_domain, + ) + + _builtin_op.__name__ = builtin_name + return _builtin_op - if common.is_field(value): - if not value.domain == target_domain: - raise ValueError( - f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." - ) - value = value.ndarray - assert hasattr(self.ndarray, "__setitem__") - self.ndarray[target_slice] = value +NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum")) +NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max")) +NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min")) # -- Concrete array implementations -- @@ -293,11 +508,17 @@ def _np_cp_setitem( class NumPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = np - __setitem__ = _np_cp_setitem - common.field.register(np.ndarray, NumPyArrayField.from_array) + +@dataclasses.dataclass(frozen=True, eq=False) +class NumPyArrayConnectivityField(NdArrayConnectivityField): + array_ns: ClassVar[ModuleType] = np + + +common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array) + # CuPy if cp: _nd_array_implementations.append(cp) @@ -306,10 +527,14 @@ class NumPyArrayField(NdArrayField): class CuPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = cp - __setitem__ = _np_cp_setitem - common.field.register(cp.ndarray, CuPyArrayField.from_array) + @dataclasses.dataclass(frozen=True, eq=False) + class CuPyArrayConnectivityField(NdArrayConnectivityField): + array_ns: ClassVar[ModuleType] = cp + + common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array) + # JAX if jnp: _nd_array_implementations.append(jnp) @@ -355,11 +580,13 @@ def _builtins_broadcast( NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) -def _astype(field: NdArrayField, type_: type) -> NdArrayField: - return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) +def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdArrayField: + if isinstance(field, NdArrayField): + return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) + raise AssertionError("This is the NdArrayField implementation of `fbuiltins.astype`.") -NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field +NdArrayField.register_builtin_func(fbuiltins.astype, _astype) def _get_slices_from_domain_slice( diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 107415eb06..7572040e13 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,7 +32,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, common +from gt4py.next import allocators as next_allocators, common, embedded as next_embedded from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import ( dialect_ast_enums, @@ -290,8 +290,8 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend." ) ) - - self.definition(*rewritten_args, **kwargs) + with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: + ctx.run(self.definition, *rewritten_args, **kwargs) return ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) @@ -686,6 +686,9 @@ def __call__( offset_provider = kwargs.pop("offset_provider", None) if self.backend is not None: # "out" and "offset_provider" -> field_operator as program + # When backend is None, we are in embedded execution and for now + # we disable the program generation since it would involve generating + # Python source code from a PAST node. args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given # deduce argument types @@ -700,10 +703,20 @@ def __call__( *args, out, offset_provider=offset_provider, **kwargs ) else: - # "out" -> field_operator called from program in embedded execution - # TODO(egparedes): put offset_provider in ctxt var here when implementing remap + # "out" -> field_operator called from program in embedded execution or + # field_operator called directly from Python in embedded execution domain = kwargs.pop("domain", None) - res = self.definition(*args, **kwargs) + if not next_embedded.context.within_context(): + # field_operator from Python in embedded execution + with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: + res = ctx.run(self.definition, *args, **kwargs) + else: + # field_operator from program in embedded execution (offset_provicer is already set) + assert ( + offset_provider is None + or next_embedded.context.offset_provider.get() is offset_provider + ) + res = self.definition(*args, **kwargs) _tuple_assign_field( out, res, domain=None if domain is None else common.domain(domain) ) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 706b6a4606..8230e35a35 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -13,28 +13,19 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses +import functools import inspect from builtins import bool, float, int, tuple -from typing import ( - Any, - Callable, - Generic, - Optional, - ParamSpec, - Tuple, - TypeAlias, - TypeVar, - Union, - cast, -) +from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np from numpy import float32, float64, int32, int64 +import gt4py.next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import common -from gt4py.next.common import Dimension, Field # direct import for TYPE_BUILTINS -from gt4py.next.ffront.experimental import as_offset # noqa F401 +from gt4py.next import common, embedded +from gt4py.next.common import Dimension, Field # noqa: F401 # direct import for TYPE_BUILTINS +from gt4py.next.ffront.experimental import as_offset # noqa: F401 from gt4py.next.iterator import runtime from gt4py.next.type_system import type_specifications as ts @@ -43,8 +34,8 @@ PYTHON_TYPE_BUILTIN_NAMES = [t.__name__ for t in PYTHON_TYPE_BUILTINS] TYPE_BUILTINS = [ - Field, - Dimension, + common.Field, + common.Dimension, int32, int64, float32, @@ -214,10 +205,10 @@ def where( @BuiltInFunction def astype( - value: Field | core_defs.ScalarT | Tuple, + value: common.Field | core_defs.ScalarT | Tuple, type_: type, /, -) -> Field | core_defs.ScalarT | Tuple: +) -> common.Field | core_defs.ScalarT | Tuple: if isinstance(value, tuple): return tuple(astype(v, type_) for v in value) # default implementation for scalars, Fields are handled via dispatch @@ -324,7 +315,10 @@ def impl( class FieldOffset(runtime.Offset): source: common.Dimension target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] - connectivity: Optional[Any] = None # TODO + + @functools.cached_property + def _cache(self) -> dict: + return {} def __post_init__(self): if len(self.target) == 2 and self.target[1].kind != common.DimensionKind.LOCAL: @@ -332,3 +326,51 @@ def __post_init__(self): def __gt_type__(self): return ts.OffsetType(source=self.source, target=self.target) + + def __getitem__(self, offset: int) -> common.ConnectivityField: + """Serve as a connectivity factory.""" + assert isinstance(self.value, str) + current_offset_provider = embedded.context.offset_provider.get(None) + assert current_offset_provider is not None + offset_definition = current_offset_provider[self.value] + + connectivity: common.ConnectivityField + if isinstance(offset_definition, common.Dimension): + connectivity = common.CartesianConnectivity(offset_definition, offset) + elif isinstance( + offset_definition, gtx.NeighborTableOffsetProvider + ) or common.is_connectivity_field(offset_definition): + unrestricted_connectivity = self.as_connectivity_field() + assert unrestricted_connectivity.domain.ndim > 1 + named_index = (self.target[-1], offset) + connectivity = unrestricted_connectivity[named_index] + else: + raise NotImplementedError() + + return connectivity + + def as_connectivity_field(self): + """Convert to connectivity field using the offset providers in current embedded execution context.""" + assert isinstance(self.value, str) + current_offset_provider = embedded.context.offset_provider.get(None) + assert current_offset_provider is not None + offset_definition = current_offset_provider[self.value] + + cache_key = id(offset_definition) + if (connectivity := self._cache.get(cache_key, None)) is None: + if common.is_connectivity_field(offset_definition): + connectivity = offset_definition + elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): + assert not offset_definition.has_skip_values + connectivity = gtx.as_connectivity( + domain=self.target, + codomain=self.source, + data=offset_definition.table, + dtype=offset_definition.index_type, + ) + else: + raise NotImplementedError() + + self._cache[cache_key] = connectivity + + return connectivity diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 9000b00d8f..b02d6c8d72 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -51,8 +51,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping -from gt4py.next import common +from gt4py.next import common, embedded as next_embedded from gt4py.next.embedded import exceptions as embedded_exceptions +from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime @@ -60,7 +61,7 @@ # Atoms -Tag: TypeAlias = str +Tag: TypeAlias = common.Tag ArrayIndex: TypeAlias = slice | common.IntIndex ArrayIndexOrIndices: TypeAlias = ArrayIndex | tuple[ArrayIndex, ...] @@ -129,8 +130,8 @@ def mapped_index( # Offsets OffsetPart: TypeAlias = Tag | common.IntIndex CompleteOffset: TypeAlias = tuple[Tag, common.IntIndex] -OffsetProviderElem: TypeAlias = common.Dimension | common.Connectivity -OffsetProvider: TypeAlias = dict[Tag, OffsetProviderElem] +OffsetProviderElem: TypeAlias = common.OffsetProviderElem +OffsetProvider: TypeAlias = common.OffsetProvider # Positions SparsePositionEntry = list[int] @@ -195,9 +196,9 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: #: Column range used in column mode (`column_axis != None`) in the current closure execution context. -column_range_cvar: cvars.ContextVar[range] = cvars.ContextVar("column_range") +column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range #: Offset provider dict in the current closure execution context. -offset_provider_cvar: cvars.ContextVar[OffsetProvider] = cvars.ContextVar("offset_provider") +offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider class Column(np.lib.mixins.NDArrayOperatorsMixin): @@ -1060,6 +1061,10 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override def domain(self) -> common.Domain: return common.Domain((self._dimension, common.UnitRange.infinity())) + @property + def codomain(self) -> type[core_defs.int32]: + return core_defs.int32 + @property def dtype(self) -> core_defs.Int32DType: return core_defs.Int32DType() @@ -1071,7 +1076,7 @@ def ndarray(self) -> core_defs.NDArrayObject: def asnumpy(self) -> np.ndarray: raise NotImplementedError() - def remap(self, index_field: common.Field) -> common.Field: + def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1179,6 +1184,10 @@ def domain(self) -> common.Domain: def dtype(self) -> core_defs.DType[core_defs.ScalarT]: return core_defs.dtype(type(self._value)) + @property + def codomain(self) -> type[core_defs.ScalarT]: + return self.dtype.scalar_type + @property def ndarray(self) -> core_defs.NDArrayObject: raise AttributeError("Cannot get `ndarray` of an infinite Field.") @@ -1186,7 +1195,7 @@ def ndarray(self) -> core_defs.NDArrayObject: def asnumpy(self) -> np.ndarray: raise NotImplementedError() - def remap(self, index_field: common.Field) -> common.Field: + def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index ef30a61687..a8a508b2fb 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -120,8 +120,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ - (USES_CARTESIAN_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_UNSTRUCTURED_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 7ef724ee2f..81f216397b 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -72,7 +72,7 @@ C2EDim = gtx.Dimension("C2E", kind=common.DimensionKind.LOCAL) V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) -C2E = gtx.FieldOffset("E2V", source=Edge, target=(Cell, C2EDim)) +C2E = gtx.FieldOffset("C2E", source=Edge, target=(Cell, C2EDim)) ScalarValue: TypeAlias = core_defs.Scalar FieldValue: TypeAlias = gtx.Field diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 01c78cf950..1537c01642 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -209,7 +209,7 @@ def reduction_setup(): inp=gtx.as_field([Edge], np.arange(num_edges, dtype=np.int32)), out=gtx.as_field([Vertex], np.zeros([num_vertices], dtype=np.int32)), offset_provider={ - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4, has_skip_values=False), "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False), "C2V": gtx.NeighborTableOffsetProvider(c2v_arr, Cell, Vertex, 4, has_skip_values=False), "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4, has_skip_values=False), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1f3b54d6f0..cf273a4524 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -992,6 +992,10 @@ def fieldop_where_k_offset( ) -> cases.IKField: return where(k_index > 0, inp(Koff[-1]), 2) + @gtx.program + def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): + fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) + inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() k_index = cases.allocate( cartesian_case, fieldop_where_k_offset, "k_index", strategy=cases.IndexInitializer() @@ -1000,7 +1004,7 @@ def fieldop_where_k_offset( ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), 2) - cases.verify(cartesian_case, fieldop_where_k_offset, inp, k_index, out=out, ref=ref) + cases.verify(cartesian_case, prog, inp, k_index, out=out[:, 1:], ref=ref[:, 1:]) def test_undefined_symbols(cartesian_case): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 8b4cedd98b..130f6bd29c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -28,7 +28,7 @@ ) -pytestmark = pytest.mark.uses_unstructured_shift +pytestmark = [pytest.mark.uses_unstructured_shift, pytest.mark.uses_scan] Cell = gtx.Dimension("Cell") diff --git a/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py b/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py new file mode 100644 index 0000000000..335a08571f --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +import gt4py.next as gtx + + +IDim = gtx.Dimension("IDim") +IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) + + +@gtx.field_operator +def fop( + a: gtx.Field[[IDim], gtx.float64], b: gtx.Field[[IDim], gtx.float64] +) -> gtx.Field[[IDim], gtx.float64]: + return a(IOff[1]) + b + + +@gtx.program +def prog( + a: gtx.Field[[IDim], gtx.float64], + b: gtx.Field[[IDim], gtx.float64], + out: gtx.Field[[IDim], gtx.float64], +): + fop(a, b, out=out) + + +def test_basic(): + a = gtx.as_field([(IDim, gtx.common.UnitRange(1, 5))], np.asarray([0.0, 1.0, 2.0, 3.0])) + b = gtx.as_field([(IDim, gtx.common.UnitRange(0, 4))], np.asarray([0.0, 1.0, 2.0, 3.0])) + out = gtx.as_field([(IDim, gtx.common.UnitRange(0, 4))], np.asarray([0.0, 0.0, 0.0, 0.0])) + + prog(a, b, out, offset_provider={"IOff": IDim}) + assert out.domain == b.domain + assert np.allclose(out.ndarray, a.ndarray + b.ndarray) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 436e672cc5..2b78eb9114 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -20,7 +20,7 @@ import numpy as np import pytest -from gt4py.next import common, constructors +from gt4py.next import common, embedded from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice @@ -70,12 +70,17 @@ def unary_logical_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): +def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=None): if not dtype: dtype = nd_array_implementation.float32 + buffer = nd_array_implementation.asarray(lst, dtype=dtype) + if domain is None: + domain = tuple( + (common.Dimension(f"D{i}"), common.UnitRange(0, s)) for i, s in enumerate(buffer.shape) + ) return common.field( - nd_array_implementation.asarray(lst, dtype=dtype), - domain={common.Dimension("foo"): (0, len(lst))}, + buffer, + domain=domain, ) @@ -277,6 +282,59 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: assert np.allclose(result.ndarray, expected) +def test_remap_implementation(): + V = Dimension("V") + E = Dimension("E") + + V_START, V_STOP = 2, 7 + E_START, E_STOP = 0, 10 + v_field = common.field( + -0.1 * np.arange(V_START, V_STOP), + domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START, V_STOP),)), + ) + e2v_conn = common.connectivity( + np.arange(E_START, E_STOP), + domain=common.Domain( + dims=(E,), + ranges=[ + UnitRange(E_START, E_STOP), + ], + ), + codomain=V, + ) + + result = v_field.remap(e2v_conn) + expected = common.field( + -0.1 * np.arange(V_START, V_STOP), + domain=common.Domain(dims=(E,), ranges=(UnitRange(V_START, V_STOP),)), + ) + + assert result.domain == expected.domain + assert np.all(result.ndarray == expected.ndarray) + + +def test_cartesian_remap_implementation(): + V = Dimension("V") + E = Dimension("E") + + V_START, V_STOP = 2, 7 + OFFSET = 2 + v_field = common.field( + -0.1 * np.arange(V_START, V_STOP), + domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START, V_STOP),)), + ) + v2_conn = common.connectivity(OFFSET, V) + + result = v_field.remap(v2_conn) + expected = common.field( + v_field.ndarray, + domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START - OFFSET, V_STOP - OFFSET),)), + ) + + assert result.domain == expected.domain + assert np.all(result.ndarray == expected.ndarray) + + @pytest.mark.parametrize( "new_dims,field,expected_domain", [ diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 84008eb99c..da63536953 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -21,6 +21,7 @@ DimensionKind, Domain, Infinity, + NamedRange, UnitRange, domain, named_range, @@ -319,6 +320,134 @@ def test_domain_dims_ranges_length_mismatch(): Domain(dims=dims, ranges=ranges) +def test_domain_dim_index(): + dims = [Dimension("X"), Dimension("Y"), Dimension("Z")] + ranges = [UnitRange(0, 1), UnitRange(0, 1), UnitRange(0, 1)] + domain = Domain(dims=dims, ranges=ranges) + + domain.dim_index(Dimension("Y")) == 1 + + domain.dim_index(Dimension("Foo")) == None + + +def test_domain_pop(): + dims = [Dimension("X"), Dimension("Y"), Dimension("Z")] + ranges = [UnitRange(0, 1), UnitRange(0, 1), UnitRange(0, 1)] + domain = Domain(dims=dims, ranges=ranges) + + domain.pop(Dimension("X")) == Domain(dims=dims[1:], ranges=ranges[1:]) + + domain.pop(0) == Domain(dims=dims[1:], ranges=ranges[1:]) + + domain.pop(-1) == Domain(dims=dims[:-1], ranges=ranges[:-1]) + + +@pytest.mark.parametrize( + "index, named_ranges, domain, expected", + [ + # Valid index and named ranges + ( + 0, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("X"), UnitRange(100, 110)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + ), + ( + 1, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("X"), UnitRange(100, 110)), + (Dimension("K"), UnitRange(0, 10)), + ), + ), + ( + -1, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("X"), UnitRange(100, 110)), + ), + ), + ( + Dimension("J"), + [(Dimension("X"), UnitRange(100, 110)), (Dimension("Z"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("X"), UnitRange(100, 110)), + (Dimension("Z"), UnitRange(100, 110)), + (Dimension("K"), UnitRange(0, 10)), + ), + ), + # Invalid indices + ( + 3, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + IndexError, + ), + ( + -4, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + IndexError, + ), + ( + Dimension("Foo"), + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + ValueError, + ), + ], +) +def test_domain_replace(index, named_ranges, domain, expected): + if expected is ValueError: + with pytest.raises(ValueError): + domain.replace(index, *named_ranges) + elif expected is IndexError: + with pytest.raises(IndexError): + domain.replace(index, *named_ranges) + else: + new_domain = domain.replace(index, *named_ranges) + assert new_domain == expected + + def dimension_promotion_cases() -> ( list[tuple[list[list[Dimension]], list[Dimension] | None, None | Pattern]] ): From 5a409c560444b73b8807e25552837b10d33cf8f7 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 23 Nov 2023 13:43:17 +0100 Subject: [PATCH 41/67] feat[next]: DaCe backend - enable GPU test execution (#1360) This commit enables test execution on the DaCe GPU backend: * Small fix in DaCe SDFG generation for GPU execution. The fix is about handling of in/out fields, for which the program argument is copied to a transient array. We need to use for the transient array the same storage as the program argument (i.e. gpu storage), otherwise code generation will throw an error because of mixed storage for inputs to the closure map. * Minor code refactoring (test exclusion matrix, dace backend processor interface) * Cleanup test exclusion matrix (some left-overs after rebase of previous dace PR) Note that 3 testcases are disabled because the fix needs to be delivered to DaCe repo and a new DaCe release should be provided, in order to update the GT4Py dependency list. --- .../ADRs/0015-Test_Exclusion_Matrices.md | 4 +- pyproject.toml | 2 + .../runners/dace_iterator/__init__.py | 48 +++++++++-------- .../runners/dace_iterator/itir_to_sdfg.py | 14 ++++- tests/next_tests/exclusion_matrices.py | 54 +++++++++---------- .../ffront_tests/ffront_test_utils.py | 3 ++ .../ffront_tests/test_external_local_field.py | 10 ++++ .../ffront_tests/test_gt4py_builtins.py | 10 ++++ .../ffront_tests/test_math_unary_builtins.py | 12 +---- tests/next_tests/unit_tests/conftest.py | 6 +++ 10 files changed, 101 insertions(+), 62 deletions(-) diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md index 6c6a043560..b338169d61 100644 --- a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md +++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md @@ -47,10 +47,12 @@ by calling `next_tests.get_processor_id()`, which returns the so-called processo The following backend processors are defined: ```python -DACE = "dace_iterator.run_dace_iterator" +DACE_CPU = "dace_iterator.run_dace_cpu" +DACE_GPU = "dace_iterator.run_dace_gpu" GTFN_CPU = "otf_compile_executor.run_gtfn" GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" +GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ``` Following the previous example, the GTFN backend with temporaries does not support yet dynamic offsets in ITIR: diff --git a/pyproject.toml b/pyproject.toml index 041448e17d..2cf4fb12e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -332,12 +332,14 @@ markers = [ 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', + 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', 'uses_lift_expressions: tests that require backend support for lift expressions', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 'uses_origin: tests that require backend support for domain origin', 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', + 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', 'uses_sparse_fields: tests that require backend support for sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index e3fba87571..40b6d24b0e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import hashlib +import warnings from typing import Any, Mapping, Optional, Sequence import dace @@ -22,11 +23,11 @@ import gt4py.next.allocators as next_allocators import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.otf_compile_executor as otf_exec +import gt4py.next.program_processors.processor_interface as ppi from gt4py.next.common import Dimension, Domain, UnitRange, is_field from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache -from gt4py.next.program_processors.processor_interface import program_executor from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG @@ -94,10 +95,26 @@ def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: return {name.id: convert_arg(arg) for name, arg in zip(params, args)} +def _ensure_is_on_device( + connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType +) -> np.typing.NDArray: + if device == dace.dtypes.DeviceType.GPU: + if not isinstance(connectivity_arg, cp.ndarray): + warnings.warn( + "Copying connectivity to device. For performance make sure connectivity is provided on device." + ) + return cp.asarray(connectivity_arg) + return connectivity_arg + + def get_connectivity_args( - neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]] + neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]], + device: dace.dtypes.DeviceType, ) -> dict[str, Any]: - return {connectivity_identifier(offset): table.table for offset, table in neighbor_tables} + return { + connectivity_identifier(offset): _ensure_is_on_device(table.table, device) + for offset, table in neighbor_tables + } def get_shape_args( @@ -167,7 +184,6 @@ def get_cache_id( return m.hexdigest() -@program_executor def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: # build parameters auto_optimize = kwargs.get("auto_optimize", False) @@ -182,6 +198,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] + device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) @@ -192,26 +209,16 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: else: # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() - # set array storage for GPU execution - if run_on_gpu: - device = dace.DeviceType.GPU - sdfg._name = f"{sdfg.name}_gpu" - for _, _, array in sdfg.arrays_recursive(): - if not array.transient: - array.storage = dace.dtypes.StorageType.GPU_Global - else: - device = dace.DeviceType.CPU - # run DaCe auto-optimization heuristics if auto_optimize: # TODO Investigate how symbol definitions improve autoopt transformations, # in which case the cache table should take the symbols map into account. symbols: dict[str, int] = {} - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols) + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=run_on_gpu) # compile SDFG and retrieve SDFG program sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" @@ -226,7 +233,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_args = get_args(program.params, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = get_connectivity_args(neighbor_tables) + dace_conn_args = get_connectivity_args(neighbor_tables, device) dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) @@ -254,7 +261,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: sdfg_program(**expected_args) -@program_executor def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_iterator( program, @@ -267,13 +273,12 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_cpu = otf_exec.OTFBackend( - executor=_run_dace_cpu, + executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"), allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) if cp: - @program_executor def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_iterator( program, @@ -286,12 +291,11 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: else: - @program_executor def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: raise RuntimeError("Missing `cupy` dependency for GPU execution.") run_dace_gpu = otf_exec.OTFBackend( - executor=_run_dace_gpu, + executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"), allocator=next_allocators.StandardGPUFieldBufferAllocator(), ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 9e9cc4bf29..a7cecf5fad 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -96,17 +96,20 @@ class ItirToSDFG(eve.NodeVisitor): offset_provider: dict[str, Any] node_types: dict[int, next_typing.Type] unique_id: int + use_gpu_storage: bool def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTableOffsetProvider], column_axis: Optional[Dimension] = None, + use_gpu_storage: bool = False, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} + self.use_gpu_storage = use_gpu_storage def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): @@ -118,7 +121,14 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset else None ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) + storage = ( + dace.dtypes.StorageType.GPU_Global + if self.use_gpu_storage + else dace.dtypes.StorageType.Default + ) + sdfg.add_array( + name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage + ) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: @@ -225,6 +235,7 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, + storage=array_table[name].storage, transient=True, ) closure_init_state.add_nedge( @@ -239,6 +250,7 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, + storage=array_table[name].storage, ) else: assert isinstance(self.storage_types[name], ts.ScalarType) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index a8a508b2fb..a6a302e143 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -57,6 +57,7 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): DACE_CPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_cpu" + DACE_GPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_gpu" class ProgramExecutorId(_PythonObjectIdMixin, str, enum.Enum): @@ -83,9 +84,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # Test markers REQUIRES_ATLAS = "requires_atlas" USES_APPLIED_SHIFTS = "uses_applied_shifts" -USES_CAN_DEREF = "uses_can_deref" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" +USES_FLOORDIV = "uses_floordiv" USES_IF_STMTS = "uses_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" USES_LIFT_EXPRESSIONS = "uses_lift_expressions" @@ -111,7 +112,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): "We cannot unroll a reduction on a sparse field only (not clear if it is legal ITIR)" ) # Common list of feature markers to skip -GTFN_SKIP_TEST_LIST = [ +COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), @@ -119,46 +120,45 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] +DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), +] EMBEDDED_SKIP_LIST = [ (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ] +GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), +] #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) BACKEND_SKIP_TEST_MATRIX = { None: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: GTFN_SKIP_TEST_LIST - + [ - (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), - (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - ], - ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST - + [ - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - ], - ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST - + [ - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - ], - ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + [ - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + # awaiting dace fix, see https://github.com/spcl/dace/pull/1442 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST, + ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST, + ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 1537c01642..f8a3f6a975 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -45,6 +45,9 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append(definitions.OptionalProgramBackendId.DACE_CPU) + OPTIONAL_PROCESSORS.append( + pytest.param(definitions.OptionalProgramBackendId.DACE_GPU, marks=pytest.mark.requires_gpu) + ), @pytest.fixture( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 05adc63a45..42938e2f4b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -30,6 +30,16 @@ def test_external_local_field(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e2434d860a..bbbac6c139 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -120,6 +120,16 @@ def fencil(edge_f: cases.EField, out: cases.VField): @pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 8660ecfdbd..c2ab43773f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -37,7 +37,6 @@ tanh, trunc, ) -from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case @@ -67,17 +66,8 @@ def pow(inp1: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, pow, ref=lambda inp1: inp1**2) +@pytest.mark.uses_floordiv def test_floordiv(cartesian_case): - if cartesian_case.backend in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, - gtfn.run_gtfn_gpu, - ]: - pytest.xfail( - "FloorDiv not yet supported." - ) # see https://github.com/GridTools/gt4py/issues/1136 - @gtx.field_operator def floorDiv(inp1: cases.IField) -> cases.IField: return inp1 // 2 diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 372062d08a..6f91557e46 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -50,6 +50,12 @@ def lift_mode(request): OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append((definitions.OptionalProgramBackendId.DACE_CPU, True)) + # TODO(havogt): update tests to use proper allocation + # OPTIONAL_PROCESSORS.append( + # pytest.param( + # (definitions.OptionalProgramBackendId.DACE_GPU, True), marks=pytest.mark.requires_gpu + # ) + # ), @pytest.fixture( From 6bea0074305d2261844746ee4f801a72a8e1c435 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 24 Nov 2023 12:58:20 +0100 Subject: [PATCH 42/67] fix[next]: Fix for GPU test execution (#1368) This commit removes test_gpu_backend.py Spack build of icon4py was broken because dace is an optional module, not installed in the default environment, and the dace backend is not available in test execution. This caused an ImportError exception in test_gpu_backend.py, because this test is bypassing the test exclusion matrix. The initial proposed fix was to use try/except to handle this case. However, all tests in baseline are already executed on the GPU backends (both GTFN and DaCe), therefore this simple test is no longer needed. --- .../ffront_tests/test_gpu_backend.py | 45 ------------------- 1 file changed, 45 deletions(-) delete mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py deleted file mode 100644 index 7054597831..0000000000 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py +++ /dev/null @@ -1,45 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import pytest - -import gt4py.next as gtx -from gt4py.next import common -from gt4py.next.program_processors.runners import dace_iterator, gtfn - -from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import cartesian_case # noqa: F401 -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 - fieldview_backend, -) - - -@pytest.mark.requires_gpu -@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace_gpu, gtfn.run_gtfn_gpu]) -def test_copy(fieldview_backend): # noqa: F811 # fixtures - import cupy as cp - - @gtx.field_operator(backend=fieldview_backend) - def testee(a: cases.IJKField) -> cases.IJKField: - return a - - domain = { - cases.IDim: common.unit_range(3), - cases.JDim: common.unit_range(4), - cases.KDim: common.unit_range(5), - } - inp_field = gtx.full(domain, fill_value=3, allocator=fieldview_backend, dtype=cp.int32) - out_field = gtx.zeros(domain, allocator=fieldview_backend, dtype=cp.int32) - testee(inp_field, out=out_field, offset_provider={}) - assert cp.allclose(inp_field.ndarray, out_field.ndarray) From 1e486f2cd8caebcea5f3cc342a9727fc8a9f1d03 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 27 Nov 2023 16:01:27 +0100 Subject: [PATCH 43/67] test[next]: Fix warnings that cause Spack to crash (#1369) Solves warning about invalid escape sequence in some regex strings. These warnings cause spack build to crash. --- .../feature_tests/ffront_tests/test_program.py | 4 ++-- .../feature_tests/ffront_tests/test_type_deduction.py | 4 ++-- tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py | 2 +- tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index a0f69f332c..4c0613a33c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -228,8 +228,8 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program(inp, out, offset_provider={}) msgs = [ - "- Expected argument `in_field` to be of type `Field\[\[IDim], float64\]`," - " but got `Field\[\[JDim\], float64\]`.", + r"- Expected argument `in_field` to be of type `Field\[\[IDim], float64\]`," + r" but got `Field\[\[JDim\], float64\]`.", ] for msg in msgs: assert re.search(msg, exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index dfa710e038..d1a5f24f79 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -299,7 +299,7 @@ def callable_type_info_cases(): [ts.TupleType(types=[float_type, field_type])], {}, [ - "Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" + r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" ], ts.VoidType(), ), @@ -308,7 +308,7 @@ def callable_type_info_cases(): [int_type], {}, [ - "Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" + r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" ], ts.VoidType(), ), diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 6e617f77a2..1d1a1efad4 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -206,7 +206,7 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): assert exc_info.match("Invalid call to `domain_format_3`") assert ( - re.search("Missing required keyword argument\(s\) `out`.", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) is not None ) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index e56dc85322..c4fe30c596 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -187,6 +187,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): # is not None # ) assert ( - re.search("Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) is not None ) From 5a912cf1d97e3c5b3f555f1c104b8650df282263 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 28 Nov 2023 12:52:16 +0100 Subject: [PATCH 44/67] fix[next]: DaCe backend - symbol propagation in lambda scope (#1367) Full-inlining of ITIR lift operator can result in nested lambda calls, which are translated to nested SDFGs in DaCe backend. The problem was that inner lambda SDFGs were not inheriting the symbols in scope from the parent SDFG, instead the DaCe backend was only mapping the lambda arguments. The solution implemented in this PR is to run a pre-pass to discover all symbols used in the nested lambdas, and propagate the required data containers from outer to inner SDFG. This PR also contains some cleanup of the scan visitor. The overall goal is to rely as much as possible on the visitor for itir.FunCall to generate the scan body. --- .../runners/dace_iterator/itir_to_sdfg.py | 188 +++++------ .../runners/dace_iterator/itir_to_tasklet.py | 295 +++++++++++------- 2 files changed, 263 insertions(+), 220 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a7cecf5fad..94878fd46d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -24,9 +24,10 @@ from .itir_to_tasklet import ( Context, - IteratorExpr, + GatherOutputSymbolsPass, PythonTaskletCodegen, SymbolExpr, + TaskletExpr, ValueExpr, closure_to_tasklet_sdfg, is_scan, @@ -136,8 +137,13 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset self.storage_types[name] = type_ def get_output_nodes( - self, closure: itir.StencilClosure, context: Context + self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState ) -> dict[str, dace.nodes.AccessNode]: + # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes + output_symbols_pass = GatherOutputSymbolsPass(sdfg, state, self.node_types) + output_symbols_pass.visit(closure.output) + # Visit output node again to generate the corresponding tasklet + context = Context(sdfg, state, output_symbols_pass.symbol_refs) translator = PythonTaskletCodegen(self.offset_provider, context, self.node_types) output_nodes = flatten_list(translator.visit(closure.output)) return {node.value.data: node.value for node in output_nodes} @@ -212,19 +218,16 @@ def visit_StencilClosure( closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") - program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} - closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - neighbor_tables = filter_neighbor_tables(self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + neighbor_tables = filter_neighbor_tables(self.offset_provider) + connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_nodes = self.get_output_nodes(node, closure_ctx) + output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) output_names = [k for k, _ in output_nodes.items()] # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. input_transients_mapping = {} - for name in [*input_names, *conn_names, *output_names]: + for name in [*input_names, *connectivity_names, *output_names]: if name in closure_sdfg.arrays: assert name in input_names and name in output_names # In case of closures with in/out fields, there is risk of race condition @@ -268,6 +271,7 @@ def visit_StencilClosure( ) # Update symbol table and get output domain of the closure + program_arg_syms: dict[str, TaskletExpr] = {} for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): dtype = as_dace_type(type_) @@ -285,10 +289,11 @@ def visit_StencilClosure( program_arg_syms[name] = value else: program_arg_syms[name] = SymbolExpr(name, dtype) + closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) closure_domain = self._visit_domain(node.domain, closure_ctx) # Map SDFG tasklet arguments to parameters - input_access_names = [ + input_local_names = [ input_transients_mapping[input_name] if input_name in input_transients_mapping else input_name @@ -297,9 +302,9 @@ def visit_StencilClosure( for input_name in input_names ] input_memlets = [ - create_memlet_full(name, closure_sdfg.arrays[name]) for name in input_access_names + create_memlet_full(name, closure_sdfg.arrays[name]) + for name in [*input_local_names, *connectivity_names] ] - conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] # create and write to transient that is then copied back to actual output array to avoid aliasing of # same memory in nested SDFG with different names @@ -340,18 +345,18 @@ def visit_StencilClosure( for output_name in output_connectors_mapping.values() ] - input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)} - conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlets)} + input_mapping = { + param: arg for param, arg in zip([*input_names, *connectivity_names], input_memlets) + } + output_mapping = {param: memlet for param, memlet in zip(results, output_memlets)} - array_mapping = {**input_mapping, **conn_mapping} - symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) + symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, input_mapping) nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, map_ranges=map_ranges or {"__dummy": "0"}, - inputs=array_mapping, + inputs=input_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, @@ -376,7 +381,7 @@ def visit_StencilClosure( closure_state.remove_edge(edge) access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] - return closure_sdfg, input_field_names + conn_names, output_names + return closure_sdfg, input_field_names + connectivity_names, output_names def _visit_scan_stencil_closure( self, @@ -422,6 +427,23 @@ def _visit_scan_stencil_closure( lambda_state = scan_sdfg.add_state("lambda_compute") end_state = scan_sdfg.add_state("end") + # the carry value of the scan operator exists only in the scope of the scan sdfg + scan_carry_name = unique_var_name() + scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) + + # tasklet for initialization of carry + carry_init_tasklet = start_state.add_tasklet( + "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" + ) + start_state.add_edge( + carry_init_tasklet, + "__result", + start_state.add_access(scan_carry_name), + None, + dace.Memlet.simple(scan_carry_name, "0"), + ) + + # TODO(edopao): replace state machine with dace loop construct scan_sdfg.add_loop( start_state, lambda_state, @@ -434,7 +456,7 @@ def _visit_scan_stencil_closure( increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1", ) - # add access nodes to SDFG for inputs + # add storage to scan SDFG for inputs for name in [*input_names, *connectivity_names]: assert name not in scan_sdfg.arrays if isinstance(self.storage_types[name], ts.FieldType): @@ -448,116 +470,76 @@ def _visit_scan_stencil_closure( scan_sdfg.add_scalar( name, dtype=as_dace_type(cast(ts.ScalarType, self.storage_types[name])) ) + # add storage to scan SDFG for output + scan_sdfg.add_array( + output_name, + shape=(array_table[node.output.id].shape[scan_dim_index],), + strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), + dtype=array_table[node.output.id].dtype, + ) + # implement the lambda function as a nested SDFG that computes a single item in the scan dimension + lambda_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} + input_arrays = [(scan_carry_name, scan_dtype)] + [ + (name, self.storage_types[name]) for name in input_names + ] connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] - - # implement the lambda closure as a nested SDFG that computes a single item of the map domain - lambda_context, lambda_inputs, lambda_outputs = closure_to_tasklet_sdfg( + lambda_context, lambda_outputs = closure_to_tasklet_sdfg( node, self.offset_provider, - {}, - [], + lambda_domain, + input_arrays, connectivity_arrays, self.node_types, ) + lambda_input_names = [name for name, _ in input_arrays] + lambda_output_names = [connector.value.data for connector in lambda_outputs] + + input_memlets = [ + create_memlet_full(name, scan_sdfg.arrays[name]) for name in lambda_input_names + ] connectivity_memlets = [ create_memlet_full(name, scan_sdfg.arrays[name]) for name in connectivity_names ] + input_mapping = {param: arg for param, arg in zip(lambda_input_names, input_memlets)} connectivity_mapping = { param: arg for param, arg in zip(connectivity_names, connectivity_memlets) } - - lambda_input_names = [inner_name for inner_name, _ in lambda_inputs] - symbol_mapping = map_nested_sdfg_symbols( - scan_sdfg, lambda_context.body, connectivity_mapping - ) + array_mapping = {**input_mapping, **connectivity_mapping} + symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) scan_inner_node = lambda_state.add_nested_sdfg( lambda_context.body, parent=scan_sdfg, inputs=set(lambda_input_names) | set(connectivity_names), - outputs={connector.value.label for connector in lambda_outputs}, + outputs=set(lambda_output_names), symbol_mapping=symbol_mapping, ) - # the carry value of the scan operator exists in the scope of the scan sdfg - scan_carry_name = unique_var_name() - lambda_carry_name, _ = lambda_inputs[0] - scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) - - carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" - ) - carry_node1 = start_state.add_access(scan_carry_name) - start_state.add_edge( - carry_init_tasklet, - "__result", - carry_node1, - None, - dace.Memlet.simple(scan_carry_name, "0"), - ) - - carry_node2 = lambda_state.add_access(scan_carry_name) - lambda_state.add_memlet_path( - carry_node2, - scan_inner_node, - memlet=dace.Memlet.simple(scan_carry_name, "0"), - src_conn=None, - dst_conn=lambda_carry_name, - ) - - # connect access nodes to lambda inputs - for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - if isinstance(self.storage_types[data_name], ts.FieldType): - memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain)) - else: - memlet = dace.Memlet.simple(data_name, "0") - lambda_state.add_memlet_path( - lambda_state.add_access(data_name), - scan_inner_node, - memlet=memlet, - src_conn=None, - dst_conn=inner_name, - ) - - for inner_name, memlet in connectivity_mapping.items(): - access_node = lambda_state.add_access(inner_name) - lambda_state.add_memlet_path( - access_node, - scan_inner_node, - memlet=memlet, - src_conn=None, - dst_conn=inner_name, - propagate=True, - ) + # connect scan SDFG to lambda inputs + for name, memlet in array_mapping.items(): + access_node = lambda_state.add_access(name) + lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = [output_name] - assert len(lambda_outputs) == 1 - # connect lambda output to access node - for lambda_connector, data_name in zip(lambda_outputs, output_names): - scan_sdfg.add_array( - data_name, - shape=(array_table[node.output.id].shape[scan_dim_index],), - strides=(array_table[node.output.id].strides[scan_dim_index],), - offset=(array_table[node.output.id].offset[scan_dim_index],), - dtype=array_table[node.output.id].dtype, - ) - lambda_state.add_memlet_path( + assert len(lambda_output_names) == 1 + # connect lambda output to scan SDFG + for name, connector in zip(output_names, lambda_output_names): + lambda_state.add_edge( scan_inner_node, - lambda_state.add_access(data_name), - memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), - src_conn=lambda_connector.value.label, - dst_conn=None, + connector, + lambda_state.add_access(name), + None, + dace.Memlet.simple(name, f"i_{scan_dim}"), ) # add state to scan SDFG to update the carry value at each loop iteration lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") - result_node = lambda_update_state.add_access(output_names[0]) - carry_node3 = lambda_update_state.add_access(scan_carry_name) lambda_update_state.add_memlet_path( - result_node, - carry_node3, + lambda_update_state.add_access(output_name), + lambda_update_state.add_access(scan_carry_name), memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) @@ -586,14 +568,14 @@ def _visit_parallel_stencil_closure( index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} input_arrays = [(name, self.storage_types[name]) for name in input_names] - conn_arrays = [(array_table[name], name) for name in conn_names] + connectivity_arrays = [(array_table[name], name) for name in conn_names] - context, _, results = closure_to_tasklet_sdfg( + context, results = closure_to_tasklet_sdfg( node, self.offset_provider, index_domain, input_arrays, - conn_arrays, + connectivity_arrays, self.node_types, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 5b240ea2b7..da54f9be14 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -11,11 +11,10 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - import dataclasses import itertools from collections.abc import Sequence -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional, TypeAlias, cast import dace import numpy as np @@ -23,6 +22,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen +from gt4py import eve from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider @@ -151,11 +151,15 @@ class IteratorExpr: dimensions: list[str] +# Union of possible expression types +TaskletExpr: TypeAlias = IteratorExpr | SymbolExpr | ValueExpr + + @dataclasses.dataclass class Context: body: dace.SDFG state: dace.SDFGState - symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] + symbol_map: dict[str, TaskletExpr] # if we encounter a reduction node, the reduction state needs to be pushed to child nodes reduce_limit: int reduce_wcr: Optional[str] @@ -164,13 +168,15 @@ def __init__( self, body: dace.SDFG, state: dace.SDFGState, - symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr], + symbol_map: dict[str, TaskletExpr], + reduce_limit: int = 0, + reduce_wcr: Optional[str] = None, ): self.body = body self.state = state self.symbol_map = symbol_map - self.reduce_limit = 0 - self.reduce_wcr = None + self.reduce_limit = reduce_limit + self.reduce_wcr = reduce_wcr def builtin_neighbors( @@ -350,6 +356,104 @@ def builtin_undefined(*args: Any) -> Any: } +class GatherLambdaSymbolsPass(eve.NodeVisitor): + _sdfg: dace.SDFG + _state: dace.SDFGState + _symbol_map: dict[str, TaskletExpr] + _parent_symbol_map: dict[str, TaskletExpr] + + def __init__( + self, + sdfg, + state, + parent_symbol_map, + ): + self._sdfg = sdfg + self._state = state + self._symbol_map = {} + self._parent_symbol_map = parent_symbol_map + + @property + def symbol_refs(self): + """Dictionary of symbols referenced from the lambda expression.""" + return self._symbol_map + + def _add_symbol(self, param, arg): + if isinstance(arg, ValueExpr): + # create storage in lambda sdfg + self._sdfg.add_scalar(param, dtype=arg.dtype) + # update table of lambda symbol + self._symbol_map[param] = ValueExpr(self._state.add_access(param), arg.dtype) + elif isinstance(arg, IteratorExpr): + # create storage in lambda sdfg + ndims = len(arg.dimensions) + shape = tuple( + dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims) + ) + strides = tuple( + dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims) + ) + self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) + index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} + for _, index_name in index_names.items(): + self._sdfg.add_scalar(index_name, dtype=dace.int64) + # update table of lambda symbol + field = self._state.add_access(param) + indices = { + dim: self._state.add_access(index_arg) for dim, index_arg in index_names.items() + } + self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) + else: + assert isinstance(arg, SymbolExpr) + self._symbol_map[param] = arg + + def visit_SymRef(self, node: itir.SymRef): + name = str(node.id) + if name in self._parent_symbol_map and name not in self._symbol_map: + arg = self._parent_symbol_map[name] + self._add_symbol(name, arg) + + def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): + if args is not None: + assert len(node.params) == len(args) + for param, arg in zip(node.params, args): + self._add_symbol(str(param.id), arg) + self.visit(node.expr) + + +class GatherOutputSymbolsPass(eve.NodeVisitor): + _sdfg: dace.SDFG + _state: dace.SDFGState + _node_types: dict[int, next_typing.Type] + _symbol_map: dict[str, TaskletExpr] + + @property + def symbol_refs(self): + """Dictionary of symbols referenced from the output expression.""" + return self._symbol_map + + def __init__( + self, + sdfg, + state, + node_types, + ): + self._sdfg = sdfg + self._state = state + self._node_types = node_types + self._symbol_map = {} + + def visit_SymRef(self, node: itir.SymRef): + param = str(node.id) + if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: + node_type = self._node_types[id(node)] + assert isinstance(node_type, Val) + access_node = self._state.add_access(param) + self._symbol_map[param] = ValueExpr( + access_node, dtype=itir_type_as_dace_type(node_type.dtype) + ) + + class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): offset_provider: dict[str, Any] context: Context @@ -369,7 +473,7 @@ def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): raise NotImplementedError() def visit_Lambda( - self, node: itir.Lambda, args: Sequence[ValueExpr | SymbolExpr] + self, node: itir.Lambda, args: Sequence[TaskletExpr] ) -> tuple[ Context, list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], @@ -377,62 +481,38 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = filter_neighbor_tables(self.offset_provider) - param_names = [str(p.id) for p in node.params] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - - assert len(param_names) == len(args) - symbols = { - **{param: arg for param, arg in zip(param_names, args)}, - } - - # Create the SDFG for the function's body - prev_context = self.context - context_sdfg = dace.SDFG(func_name) - context_state = context_sdfg.add_state(f"{func_name}_entry", True) - symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} - value: ValueExpr | IteratorExpr - for param, arg in symbols.items(): - if isinstance(arg, ValueExpr): - value = ValueExpr(context_state.add_access(param), arg.dtype) - else: - assert isinstance(arg, IteratorExpr) - field = context_state.add_access(param) - indices = { - dim: context_state.add_access(f"__{param}_i_{dim}") - for dim in arg.indices.keys() - } - value = IteratorExpr(field, indices, arg.dtype, arg.dimensions) - symbol_map[param] = value - context = Context(context_sdfg, context_state, symbol_map) - context.reduce_limit = prev_context.reduce_limit - context.reduce_wcr = prev_context.reduce_wcr - self.context = context + connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + + # Create the SDFG for the lambda's body + lambda_sdfg = dace.SDFG(func_name) + lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True) - # Add input parameters as arrays + lambda_symbols_pass = GatherLambdaSymbolsPass( + lambda_sdfg, lambda_state, self.context.symbol_map + ) + lambda_symbols_pass.visit(node, args=args) + + # Add for input nodes for lambda symbols inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = [] - for name, arg in symbols.items(): - if isinstance(arg, ValueExpr): - dtype = arg.dtype - context.body.add_scalar(name, dtype=dtype) - inputs.append((name, arg)) + for sym, input_node in lambda_symbols_pass.symbol_refs.items(): + arg = next((arg for param, arg in zip(node.params, args) if param.id == sym), None) + if arg: + outer_node = arg else: - assert isinstance(arg, IteratorExpr) - ndims = len(arg.dimensions) - shape = tuple( - dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims) - ) - strides = tuple( - dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims) - ) - dtype = arg.dtype - context.body.add_array(name, shape=shape, strides=strides, dtype=dtype) - index_names = {dim: f"__{name}_i_{dim}" for dim in arg.indices.keys()} - for _, index_name in index_names.items(): - context.body.add_scalar(index_name, dtype=dace.int64) - inputs.append(((name, index_names), arg)) + # the symbol is not found among lambda arguments, then it is inherited from parent scope + outer_node = self.context.symbol_map[sym] + if isinstance(input_node, IteratorExpr): + assert isinstance(outer_node, IteratorExpr) + index_params = { + dim: index_node.data for dim, index_node in input_node.indices.items() + } + inputs.append(((sym, index_params), outer_node)) + elif isinstance(input_node, ValueExpr): + assert isinstance(outer_node, ValueExpr) + inputs.append((sym, outer_node)) # Add connectivities as arrays - for name in conn_names: + for name in connectivity_names: shape = ( dace.symbol(unique_var_name() + "__shp", dace.int64), dace.symbol(unique_var_name() + "__shp", dace.int64), @@ -441,50 +521,53 @@ def visit_Lambda( dace.symbol(unique_var_name() + "__strd", dace.int64), dace.symbol(unique_var_name() + "__strd", dace.int64), ) - dtype = prev_context.body.arrays[name].dtype - context.body.add_array(name, shape=shape, strides=strides, dtype=dtype) + dtype = self.context.body.arrays[name].dtype + lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + + # Translate the lambda's body in its own context + lambda_context = Context( + lambda_sdfg, + lambda_state, + lambda_symbols_pass.symbol_refs, + reduce_limit=self.context.reduce_limit, + reduce_wcr=self.context.reduce_wcr, + ) + lambda_taskgen = PythonTaskletCodegen(self.offset_provider, lambda_context, self.node_types) - # Translate the function's body results: list[ValueExpr] = [] # We are flattening the returned list of value expressions because the multiple outputs of a lamda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. - for expr in flatten_list(self.visit(node.expr)): + for expr in flatten_list(lambda_taskgen.visit(node.expr)): if isinstance(expr, ValueExpr): result_name = unique_var_name() - self.context.body.add_scalar(result_name, expr.dtype, transient=True) - result_access = self.context.state.add_access(result_name) - self.context.state.add_edge( + lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) + result_access = lambda_state.add_access(result_name) + lambda_state.add_edge( expr.value, None, result_access, None, # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet.simple(result_access.data, "0", wcr_str=context.reduce_wcr), + dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr), ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] - self.context.body.arrays[result.value.data].transient = False + result = lambda_taskgen.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + lambda_sdfg.arrays[result.value.data].transient = False results.append(result) - self.context = prev_context - for node in context.state.nodes(): - if isinstance(node, dace.nodes.AccessNode): - if context.state.out_degree(node) == 0 and context.state.in_degree(node) == 0: - context.state.remove_node(node) + # remove isolated access nodes for connectivity arrays not consumed by lambda + for sub_node in lambda_state.nodes(): + if isinstance(sub_node, dace.nodes.AccessNode): + if lambda_state.out_degree(sub_node) == 0 and lambda_state.in_degree(sub_node) == 0: + lambda_state.remove_node(sub_node) - return context, inputs, results + return lambda_context, inputs, results def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: - if node.id not in self.context.symbol_map: - acc = self.context.state.add_access(node.id) - node_type = self.node_types[id(node)] - assert isinstance(node_type, Val) - self.context.symbol_map[node.id] = ValueExpr( - value=acc, dtype=itir_type_as_dace_type(node_type.dtype) - ) - value = self.context.symbol_map[node.id] + param = str(node.id) + value = self.context.symbol_map[param] if isinstance(value, (ValueExpr, SymbolExpr)): return [value] return value @@ -952,29 +1035,6 @@ def is_scan(node: itir.Node) -> bool: return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") -def _visit_scan_closure_callable( - node: itir.StencilClosure, - tlet_codegen: PythonTaskletCodegen, -) -> tuple[Context, Sequence[tuple[str, ValueExpr]], Sequence[ValueExpr]]: - stencil = cast(FunCall, node.stencil) - assert isinstance(stencil.args[0], Lambda) - fun_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) - - args = list(itertools.chain(tlet_codegen.visit(node.output), *tlet_codegen.visit(node.inputs))) - return tlet_codegen.visit(fun_node, args=args) - - -def _visit_closure_callable( - node: itir.StencilClosure, - tlet_codegen: PythonTaskletCodegen, - input_names: Sequence[str], -) -> Sequence[ValueExpr]: - args = [itir.SymRef(id=name) for name in input_names] - fun_node = itir.FunCall(fun=node.stencil, args=args) - - return tlet_codegen.visit(fun_node) - - def closure_to_tasklet_sdfg( node: itir.StencilClosure, offset_provider: dict[str, Any], @@ -982,10 +1042,10 @@ def closure_to_tasklet_sdfg( inputs: Sequence[tuple[str, ts.TypeSpec]], connectivities: Sequence[tuple[dace.ndarray, str]], node_types: dict[int, next_typing.Type], -) -> tuple[Context, Sequence[tuple[str, ValueExpr]], Sequence[ValueExpr]]: +) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") state = body.add_state("tasklet_toplevel_entry") - symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + symbol_map: dict[str, TaskletExpr] = {} idx_accesses = {} for dim, idx in domain.items(): @@ -1023,16 +1083,17 @@ def closure_to_tasklet_sdfg( context = Context(body, state, symbol_map) translator = PythonTaskletCodegen(offset_provider, context, node_types) + args = [itir.SymRef(id=name) for name, _ in inputs] if is_scan(node.stencil): - context, inner_inputs, inner_outputs = _visit_scan_closure_callable(node, translator) + stencil = cast(FunCall, node.stencil) + assert isinstance(stencil.args[0], Lambda) + lambda_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) + fun_node = itir.FunCall(fun=lambda_node, args=args) else: - inner_inputs = [] - inner_outputs = _visit_closure_callable( - node, - translator, - [name for name, _ in inputs], - ) - for output in inner_outputs: - context.body.arrays[output.value.data].transient = False + fun_node = itir.FunCall(fun=node.stencil, args=args) + + results = translator.visit(fun_node) + for r in results: + context.body.arrays[r.value.data].transient = False - return context, inner_inputs, inner_outputs + return context, results From 91307b10e2ca1edb76a72cd8a3bebdd66898da60 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Nov 2023 14:49:45 +0100 Subject: [PATCH 45/67] feature[next]: Cache direct field operator call (`as_program`) (#1254) Cache direct calls to field operators by storing the autogenerated programs in a cache. --- src/gt4py/next/ffront/decorator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 7572040e13..67272f88b8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -545,6 +545,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): definition: Optional[types.FunctionType] = None backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None + _program_cache: dict = dataclasses.field(default_factory=dict) @classmethod def from_function( @@ -613,6 +614,13 @@ def as_program( # of arg and kwarg types # TODO(tehrengruber): check foast operator has no out argument that clashes # with the out argument of the program we generate here. + hash_ = eve_utils.content_hash( + (tuple(arg_types), tuple((name, arg) for name, arg in kwarg_types.items())) + ) + try: + return self._program_cache[hash_] + except KeyError: + pass loc = self.foast_node.location param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching @@ -666,12 +674,13 @@ def as_program( untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) - return Program( + self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, backend=self.backend, grid_type=self.grid_type, ) + return self._program_cache[hash_] def __call__( self, From 4c022866d75c4bdbbff2d24775dcbc40e3a9a0db Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 29 Nov 2023 17:59:54 +0100 Subject: [PATCH 46/67] refactor[next]: Move iterator utils to dedicated module (#1371) Move `gt4py.next.iterator.ir_makers` and `gt4py.next.iterator.transforms.common_pattern_matcher` into a new module named `gt4py.next.iterator.ir_utils`. Just a small refactoring in preparation of #1350. --- src/gt4py/next/ffront/decorator.py | 7 ++++++- src/gt4py/next/ffront/foast_to_itir.py | 3 ++- src/gt4py/next/iterator/ir_utils/__init__.py | 13 +++++++++++++ .../common_pattern_matcher.py | 0 src/gt4py/next/iterator/{ => ir_utils}/ir_makers.py | 0 src/gt4py/next/iterator/pretty_parser.py | 3 ++- src/gt4py/next/iterator/tracing.py | 3 ++- .../next/iterator/transforms/constant_folding.py | 3 ++- src/gt4py/next/iterator/transforms/cse.py | 2 +- src/gt4py/next/iterator/transforms/global_tmps.py | 5 +++-- .../next/iterator/transforms/inline_lambdas.py | 2 +- src/gt4py/next/iterator/transforms/inline_lifts.py | 3 ++- .../next/iterator/transforms/symbol_ref_utils.py | 2 +- src/gt4py/next/iterator/transforms/unroll_reduce.py | 2 +- .../unit_tests/ffront_tests/test_foast_to_itir.py | 3 ++- .../iterator_tests/test_type_inference.py | 3 ++- .../transforms_tests/test_collapse_tuple.py | 4 +--- .../transforms_tests/test_constant_folding.py | 2 +- .../iterator_tests/transforms_tests/test_cse.py | 3 ++- .../transforms_tests/test_global_tmps.py | 3 ++- .../transforms_tests/test_inline_lambdas.py | 2 +- .../transforms_tests/test_inline_lifts.py | 2 +- .../transforms_tests/test_propagate_deref.py | 2 +- .../transforms_tests/test_trace_shifts.py | 3 ++- 24 files changed, 51 insertions(+), 24 deletions(-) create mode 100644 src/gt4py/next/iterator/ir_utils/__init__.py rename src/gt4py/next/iterator/{transforms => ir_utils}/common_pattern_matcher.py (100%) rename src/gt4py/next/iterator/{ => ir_utils}/ir_makers.py (100%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 67272f88b8..e06c651b13 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -53,7 +53,12 @@ from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_makers import literal_from_value, promote_to_const_iterator, ref, sym +from gt4py.next.iterator.ir_utils.ir_makers import ( + literal_from_value, + promote_to_const_iterator, + ref, + sym, +) from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 816b8581f1..3030c03fd1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -25,7 +25,8 @@ ) from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts diff --git a/src/gt4py/next/iterator/ir_utils/__init__.py b/src/gt4py/next/iterator/ir_utils/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/src/gt4py/next/iterator/ir_utils/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/next/iterator/transforms/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py similarity index 100% rename from src/gt4py/next/iterator/transforms/common_pattern_matcher.py rename to src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py diff --git a/src/gt4py/next/iterator/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py similarity index 100% rename from src/gt4py/next/iterator/ir_makers.py rename to src/gt4py/next/iterator/ir_utils/ir_makers.py diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index a541e985ad..2b1c8169fb 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -16,7 +16,8 @@ from lark import lark, lexer as lark_lexer, visitors as lark_visitors -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im GRAMMAR = """ diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index fbe6a2ae82..d1f6bba8d6 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -20,7 +20,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import Node from gt4py.next import common, iterator -from gt4py.next.iterator import builtins, ir_makers as im +from gt4py.next.iterator import builtins from gt4py.next.iterator.ir import ( AxisLiteral, Expr, @@ -34,6 +34,7 @@ Sym, SymRef, ) +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications, type_translation diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index cda422f30d..fa326760b0 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -13,7 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.next.iterator import embedded, ir, ir_makers as im +from gt4py.next.iterator import embedded, ir +from gt4py.next.iterator.ir_utils import ir_makers as im class ConstantFolding(NodeTranslator): diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 672e23c5e7..cc70e11413 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -233,7 +233,7 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_makers as im + >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1b697e0bc..d9d3d18213 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,10 +22,11 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im, type_inference +from gt4py.next.iterator import ir, type_inference +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.cse import extract_subexpression from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index fc268f85e3..eac4338345 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -17,7 +17,7 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 8d62450e67..d7d8e5e612 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -19,7 +19,8 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1f604d62b9..1c587fb9d6 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -36,7 +36,7 @@ def apply( Count references to given or all symbols in scope. Examples: - >>> import gt4py.next.iterator.ir_makers as im + >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> CountSymbolRefs.apply(expr) {'x': 2, 'y': 2, 'z': 1} diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index e3084eaba5..60a5db7e96 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift def _is_shifted(arg: itir.Expr) -> TypeGuard[itir.FunCall]: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index dd66beb522..2dd4b91c48 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -29,7 +29,8 @@ from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts, type_translation diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 1526e97d74..cacdb7b070 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -15,7 +15,8 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.iterator import ir, ir_makers as im, type_inference as ti +from gt4py.next.iterator import ir, type_inference as ti +from gt4py.next.iterator.ir_utils import ir_makers as im def test_unsatisfiable_constraints(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 736bf04d64..1444b0a64f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -12,9 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytest - -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 5d052b1989..275412a537 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 5d9e0933a7..065095e1c2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -14,7 +14,8 @@ import textwrap from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.cse import ( CommonSubexpressionElimination as CSE, extract_subexpression, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 88f6ed517b..86c3c98c62 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -15,7 +15,8 @@ import gt4py.next as gtx from gt4py.eve.utils import UIDs -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.global_tmps import ( AUTO_DOMAIN, FencilWithTemporaries, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index b9f2ca16a1..88e554f349 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py index 1da2b8a044..e1d440044d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lifts import InlineLifts diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index ffbf2c2c8e..e2e29cd4db 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 2624a17ebd..47db632a5e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -12,7 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.trace_shifts import Sentinel, TraceShifts From e564cdc14277155cf820c94f3077b9194986a01d Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 30 Nov 2023 11:14:34 +0100 Subject: [PATCH 47/67] feat[next]: Add option to ITIR transformation to inline lambda args (#1370) Full-inlining of unrolled reduce in DaCe backend requires lambda arguments to be inlined, in order to generate the corresponding taskgraph. This PR adds an option to the existing ITIR transformation InlineLambdas to enable this additional transformation pass, disabled by default. --- .../iterator/transforms/inline_lambdas.py | 14 ++++++++++++ .../next/iterator/transforms/pass_manager.py | 7 +++++- .../runners/dace_iterator/__init__.py | 1 + .../transforms_tests/test_inline_lambdas.py | 22 +++++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index eac4338345..a56ad5cb10 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -29,6 +29,7 @@ def inline_lambda( # noqa: C901 # see todo above opcount_preserving=False, force_inline_lift_args=False, force_inline_trivial_lift_args=False, + force_inline_lambda_args=False, eligible_params: Optional[list[bool]] = None, ): assert isinstance(node.fun, ir.Lambda) @@ -59,6 +60,12 @@ def inline_lambda( # noqa: C901 # see todo above if is_applied_lift(arg) and len(arg.args) == 0: eligible_params[i] = True + # inline lambdas passed as arguments + if force_inline_lambda_args: + for i, arg in enumerate(node.args): + if isinstance(arg, ir.Lambda): + eligible_params[i] = True + if node.fun.params and not any(eligible_params): return node @@ -120,6 +127,8 @@ class InlineLambdas(NodeTranslator): opcount_preserving: bool + force_inline_lambda_args: bool + force_inline_lift_args: bool force_inline_trivial_lift_args: bool @@ -129,6 +138,7 @@ def apply( cls, node: ir.Node, opcount_preserving=False, + force_inline_lambda_args=False, force_inline_lift_args=False, force_inline_trivial_lift_args=False, ): @@ -146,6 +156,8 @@ def apply( opcount_preserving: Preserve the number of operations, i.e. only inline lambda call if the resulting call has the same number of operations. + force_inline_lambda_args: Inline all arguments that are lambda calls, i.e. + `(λ(p) → p(a, a))(λ(x, y) → x+y)` force_inline_lift_args: Inline all arguments that are applied lifts, i.e. `lift(λ(...) → ...)(...)`. force_inline_trivial_lift_args: Inline all arguments that are trivial @@ -154,6 +166,7 @@ def apply( """ return cls( opcount_preserving=opcount_preserving, + force_inline_lambda_args=force_inline_lambda_args, force_inline_lift_args=force_inline_lift_args, force_inline_trivial_lift_args=force_inline_trivial_lift_args, ).visit(node) @@ -164,6 +177,7 @@ def visit_FunCall(self, node: ir.FunCall): return inline_lambda( node, opcount_preserving=self.opcount_preserving, + force_inline_lambda_args=self.force_inline_lambda_args, force_inline_lift_args=self.force_inline_lift_args, force_inline_trivial_lift_args=self.force_inline_trivial_lift_args, ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b0db04eb5f..e2feb79c44 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -79,6 +79,7 @@ def apply_common_transforms( offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, + force_inline_lambda_args=False, unconditionally_collapse_tuples=False, ): if lift_mode is None: @@ -160,6 +161,10 @@ def apply_common_transforms( ir = CommonSubexpressionElimination().visit(ir) ir = MergeLet().visit(ir) - ir = InlineLambdas.apply(ir, opcount_preserving=True) + ir = InlineLambdas.apply( + ir, + opcount_preserving=True, + force_inline_lambda_args=force_inline_lambda_args, + ) return ir diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 40b6d24b0e..d77792664e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -84,6 +84,7 @@ def preprocess_program( fencil_definition = apply_common_transforms( program, common_subexpression_elimination=False, + force_inline_lambda_args=True, lift_mode=lift_mode, offset_provider=offset_provider, unroll_reduce=True, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 88e554f349..bf26889882 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -56,3 +56,25 @@ def test(name, opcount_preserving, testee, expected): inlined = InlineLambdas.apply(testee, opcount_preserving=opcount_preserving) assert inlined == expected + + +def test_inline_lambda_args(): + testee = im.let("reduce_step", im.lambda_("x", "y")(im.plus("x", "y")))( + im.lambda_("a")( + im.call("reduce_step")(im.call("reduce_step")(im.call("reduce_step")("a", 1), 2), 3) + ) + ) + expected = im.lambda_("a")( + im.call(im.lambda_("x", "y")(im.plus("x", "y")))( + im.call(im.lambda_("x", "y")(im.plus("x", "y")))( + im.call(im.lambda_("x", "y")(im.plus("x", "y")))("a", 1), 2 + ), + 3, + ) + ) + inlined = InlineLambdas.apply( + testee, + opcount_preserving=True, + force_inline_lambda_args=True, + ) + assert inlined == expected From 6e133543480f46202f5eb94e9906cb4d2356301a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 30 Nov 2023 09:00:06 -0500 Subject: [PATCH 48/67] fix[cartesian]: DaceIR bridge for DaCe v0.15 (#1373) * Adapt state struct codegen for indexing of dace:* stencil backend * Add new "Default" schedule type for dace <> gt4py schedule mapping * Missing key for make template render * Fix typo * Typo of the typo (tm) --- src/gt4py/cartesian/backend/dace_backend.py | 10 +++++++++- src/gt4py/cartesian/gtc/daceir.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 60da2c36ff..11cd1fa895 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -451,7 +451,7 @@ class DaCeComputationCodegen: const int __I = domain[0]; const int __J = domain[1]; const int __K = domain[2]; - ${name}_t dace_handle; + ${name}_${state_suffix} dace_handle; ${backend_specifics} auto allocator = gt::sid::cached_allocator(&${allocator}); ${"\\n".join(tmp_allocs)} @@ -561,6 +561,13 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S else: omp_threads = "" omp_header = "" + + # Backward compatible state struct name change in DaCe >=0.15.x + try: + dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") + except (KeyError, TypeError): + dace_state_suffix = "t" # old structure name + interface = cls.template.definition.render( name=sdfg.name, backend_specifics=omp_threads, @@ -568,6 +575,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S functor_args=self.generate_functor_args(sdfg), tmp_allocs=self.generate_tmp_allocs(sdfg), allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", + state_suffix=dace_state_suffix, ) generated_code = textwrap.dedent( f"""#include diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index dc749a984b..28ebc8cd8e 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -101,6 +101,7 @@ def from_dace_schedule(cls, schedule): dace.ScheduleType.Default: MapSchedule.Default, dace.ScheduleType.Sequential: MapSchedule.Sequential, dace.ScheduleType.CPU_Multicore: MapSchedule.CPU_Multicore, + dace.ScheduleType.GPU_Default: MapSchedule.GPU_Device, dace.ScheduleType.GPU_Device: MapSchedule.GPU_Device, dace.ScheduleType.GPU_ThreadBlock: MapSchedule.GPU_ThreadBlock, }[schedule] From b1f9c9a567e01c14e7236b41fa95f09ae1bb3e2a Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 4 Dec 2023 09:15:16 +0100 Subject: [PATCH 49/67] feat[next][dace]: Support for sparse fields and reductions over lift expressions (#1377) This PR adds support to DaCe backend for sparse fields and reductions over lift expressions. --- .../runners/dace_iterator/itir_to_sdfg.py | 6 +- .../runners/dace_iterator/itir_to_tasklet.py | 181 +++++++++++++----- .../runners/dace_iterator/utility.py | 7 + tests/next_tests/exclusion_matrices.py | 3 - .../ffront_tests/test_execution.py | 1 + 5 files changed, 142 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 94878fd46d..271a79c04b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -150,7 +150,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) - last_state = program_sdfg.add_state("program_entry") + last_state = program_sdfg.add_state("program_entry", True) self.node_types = itir_typing.infer_all(node) # Filter neighbor tables from offset providers. @@ -216,7 +216,7 @@ def visit_StencilClosure( # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") - closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") + closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) input_names = [str(inp.id) for inp in node.inputs] neighbor_tables = filter_neighbor_tables(self.offset_provider) @@ -423,7 +423,7 @@ def _visit_scan_stencil_closure( scan_sdfg = dace.SDFG(name="scan") # create a state machine for lambda call over the scan dimension - start_state = scan_sdfg.add_state("start") + start_state = scan_sdfg.add_state("start", True) lambda_state = scan_sdfg.add_state("lambda_compute") end_state = scan_sdfg.add_state("end") diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index da54f9be14..de18446bbe 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -18,6 +18,7 @@ import dace import numpy as np +from dace import subsets from dace.transformation.dataflow import MapFusion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols @@ -39,6 +40,7 @@ filter_neighbor_tables, flatten_list, map_nested_sdfg_symbols, + new_array_symbols, unique_name, unique_var_name, ) @@ -131,9 +133,13 @@ def get_reduce_identity_value(op_name_: str, type_: Any): } +# Define type of variables used for field indexing +_INDEX_DTYPE = _TYPE_MAPPING["int64"] + + @dataclasses.dataclass class SymbolExpr: - value: str | dace.symbolic.sympy.Basic + value: dace.symbolic.SymbolicType dtype: dace.typeclass @@ -226,7 +232,7 @@ def builtin_neighbors( outputs={"__result"}, ) idx_name = unique_var_name() - sdfg.add_scalar(idx_name, dace.int64, transient=True) + sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True) state.add_memlet_path( state.add_access(table_name), me, @@ -283,10 +289,12 @@ def builtin_can_deref( assert shift_callable.fun.id == "shift" iterator = transformer._visit_shift(can_deref_callable) + # this iterator is accessing a neighbor table, so it should return an index + assert iterator.dtype in dace.dtypes.INTEGER_TYPES # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions] + args = [ValueExpr(access_node, iterator.dtype) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " && ".join([f"{v} >= 0" for v in internals]) + expr_code = " and ".join([f"{v} >= 0" for v in internals]) # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( @@ -309,6 +317,26 @@ def builtin_if( return transformer.add_expr_tasklet(expr_args, expr, type_, "if") +def builtin_list_get( + transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] +) -> list[ValueExpr]: + args = list(itertools.chain(*transformer.visit(node_args))) + assert len(args) == 2 + # index node + assert isinstance(args[0], (SymbolExpr, ValueExpr)) + # 1D-array node + assert isinstance(args[1], ValueExpr) + # source node should be a 1D array + assert len(transformer.context.body.arrays[args[1].value.data].shape) == 1 + + expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] + internals = [ + arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args + ] + expr = f"{internals[1]}[{internals[0]}]" + return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get") + + def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -340,16 +368,13 @@ def builtin_tuple_get( raise ValueError("Tuple can only be subscripted with compile-time constants") -def builtin_undefined(*args: Any) -> Any: - raise NotImplementedError() - - _GENERAL_BUILTIN_MAPPING: dict[ str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]] ] = { "can_deref": builtin_can_deref, "cast_": builtin_cast, "if_": builtin_if, + "list_get": builtin_list_get, "make_tuple": builtin_make_tuple, "neighbors": builtin_neighbors, "tuple_get": builtin_tuple_get, @@ -387,16 +412,11 @@ def _add_symbol(self, param, arg): elif isinstance(arg, IteratorExpr): # create storage in lambda sdfg ndims = len(arg.dimensions) - shape = tuple( - dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims) - ) - strides = tuple( - dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims) - ) + shape, strides = new_array_symbols(param, ndims) self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} for _, index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=dace.int64) + self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) # update table of lambda symbol field = self._state.add_access(param) indices = { @@ -513,14 +533,7 @@ def visit_Lambda( # Add connectivities as arrays for name in connectivity_names: - shape = ( - dace.symbol(unique_var_name() + "__shp", dace.int64), - dace.symbol(unique_var_name() + "__shp", dace.int64), - ) - strides = ( - dace.symbol(unique_var_name() + "__strd", dace.int64), - dace.symbol(unique_var_name() + "__strd", dace.int64), - ) + shape, strides = new_array_symbols(name, ndim=2) dtype = self.context.body.arrays[name].dtype lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -542,11 +555,9 @@ def visit_Lambda( result_name = unique_var_name() lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) result_access = lambda_state.add_access(result_name) - lambda_state.add_edge( + lambda_state.add_nedge( expr.value, - None, result_access, - None, # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr), ) @@ -587,12 +598,13 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: return self._visit_reduce(node) if isinstance(node.fun, itir.SymRef): - if str(node.fun.id) in _MATH_BUILTINS_MAPPING: + builtin_name = str(node.fun.id) + if builtin_name in _MATH_BUILTINS_MAPPING: return self._visit_numeric_builtin(node) - elif str(node.fun.id) in _GENERAL_BUILTIN_MAPPING: + elif builtin_name in _GENERAL_BUILTIN_MAPPING: return self._visit_general_builtin(node) else: - raise NotImplementedError() + raise NotImplementedError(f"{builtin_name} not implemented") return self._visit_call(node) def _visit_call(self, node: itir.FunCall): @@ -697,7 +709,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: for dim in sorted_dims ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices + ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in iterator.indices ] internals = [f"{arg.value.data}_v" for arg in args] @@ -726,14 +738,88 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] - else: + elif all([dim in iterator.indices for dim in iterator.dimensions]): + # The deref iterator has index values on all dimensions: the result will be a scalar args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims + ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + else: + # Not all dimensions are included in the deref index list: + # this means the ND-field will be sliced along one or more dimensions and the result will be an array + field_array = self.context.body.arrays[iterator.field.data] + result_shape = tuple( + dim_size + for dim, dim_size in zip(sorted_dims, field_array.shape) + if dim not in iterator.indices + ) + result_name = unique_var_name() + self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) + result_array = self.context.body.arrays[result_name] + result_node = self.context.state.add_access(result_name) + + deref_connectors = ["_inp"] + [ + f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices + ] + deref_nodes = [iterator.field] + [ + iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices + ] + deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [ + dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:] + ] + + # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset + deref_sdfg = dace.SDFG("deref") + deref_sdfg.add_array( + "_inp", field_array.shape, iterator.dtype, strides=field_array.strides + ) + for connector in deref_connectors[1:]: + deref_sdfg.add_scalar(connector, _INDEX_DTYPE) + deref_sdfg.add_array("_out", result_shape, iterator.dtype) + deref_init_state = deref_sdfg.add_state("init", True) + deref_access_state = deref_sdfg.add_state("access") + deref_sdfg.add_edge( + deref_init_state, + deref_access_state, + dace.InterstateEdge( + assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]} + ), + ) + # we access the size in source field shape as symbols set on the nested sdfg + source_subset = tuple( + f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}" + for dim, size in zip(sorted_dims, field_array.shape) + ) + deref_access_state.add_nedge( + deref_access_state.add_access("_inp"), + deref_access_state.add_access("_out"), + dace.Memlet( + data="_out", + subset=subsets.Range.from_array(result_array), + other_subset=",".join(source_subset), + ), + ) + + deref_node = self.context.state.add_nested_sdfg( + deref_sdfg, + self.context.body, + inputs=set(deref_connectors), + outputs={"_out"}, + ) + for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets): + self.context.state.add_edge(node, None, deref_node, connector, memlet) + self.context.state.add_edge( + deref_node, + "_out", + result_node, + None, + dace.Memlet.from_array(result_name, result_array), + ) + return [ValueExpr(result_node, iterator.dtype)] + def _split_shift_args( self, args: list[itir.Expr] ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: @@ -760,6 +846,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: offset_dim = tail[0].value assert isinstance(offset_dim, str) offset_node = self.visit(tail[1])[0] + assert offset_node.dtype in dace.dtypes.INTEGER_TYPES if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): offset_provider = self.offset_provider[offset_dim] @@ -769,7 +856,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: target_dim = offset_provider.neighbor_axis.value args = [ ValueExpr(connectivity, offset_provider.table.dtype), - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] @@ -780,7 +867,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value args = [ - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] @@ -791,14 +878,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_dim = self.offset_provider[offset_dim].value target_dim = shifted_dim args = [ - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "shift" + list(zip(args, internals)), expr, offset_node.dtype, "shift" )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -811,7 +898,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() - self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True) + self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) offset_node = self.context.state.add_access(offset_var) tasklet_node = self.context.state.add_tasklet( "get_offset", {}, {"__out"}, f"__out = {offset}" @@ -906,7 +993,7 @@ def _visit_reduce(self, node: itir.FunCall): # initialize the reduction result based on type of operation init_value = get_reduce_identity_value(op_name.id, result_dtype) - init_state = self.context.body.add_state_before(self.context.state, "init") + init_state = self.context.body.add_state_before(self.context.state, "init", True) init_tasklet = init_state.add_tasklet( "init_reduce", {}, {"__out"}, f"__out = {init_value}" ) @@ -1044,13 +1131,13 @@ def closure_to_tasklet_sdfg( node_types: dict[int, next_typing.Type], ) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") - state = body.add_state("tasklet_toplevel_entry") + state = body.add_state("tasklet_toplevel_entry", True) symbol_map: dict[str, TaskletExpr] = {} idx_accesses = {} for dim, idx in domain.items(): name = f"{idx}_value" - body.add_scalar(name, dtype=dace.int64, transient=True) + body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") access = state.add_access(name) idx_accesses[dim] = access @@ -1058,15 +1145,10 @@ def closure_to_tasklet_sdfg( for name, ty in inputs: if isinstance(ty, ts.FieldType): ndim = len(ty.dims) - shape = [ - dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(ndim) - ] - stride = [ - dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim) - ] + shape, strides = new_array_symbols(name, ndim) dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=stride, dtype=dtype) + body.add_array(name, shape=shape, strides=strides, dtype=dtype) field = state.add_access(name) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) @@ -1076,9 +1158,8 @@ def closure_to_tasklet_sdfg( body.add_scalar(name, dtype=dtype) symbol_map[name] = ValueExpr(state.add_access(name), dtype) for arr, name in connectivities: - shape = [dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(2)] - stride = [dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(2)] - body.add_array(name, shape=shape, strides=stride, dtype=arr.dtype) + shape, strides = new_array_symbols(name, ndim=2) + body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) translator = PythonTaskletCodegen(offset_provider, context, node_types) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index c17a39ef2d..5ae4676cd7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -166,6 +166,13 @@ def unique_var_name(): return unique_name("__var") +def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: + dtype = dace.int64 + shape = [dace.symbol(unique_name(f"{name}_shp{i}"), dtype) for i in range(ndim)] + strides = [dace.symbol(unique_name(f"{name}_strd{i}"), dtype) for i in range(ndim)] + return shape, strides + + def flatten_list(node_list: list[Any]) -> list[Any]: return list( itertools.chain.from_iterable( diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index a6a302e143..84287e209f 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -122,12 +122,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index cf273a4524..7f37b41383 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -698,6 +698,7 @@ def testee( ) +@pytest.mark.uses_constant_fields @pytest.mark.uses_unstructured_shift @pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): From 8a22ba7f1f6f49c2a0065b9b14fb6c417d3bbb78 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 4 Dec 2023 12:27:30 +0100 Subject: [PATCH 50/67] feat[next][dace]: Support for reduce-unroll special case (#1381) During integration of icon4py stencils with the DaCe backend, it was found that reduce-unroll can generate an ITIR containing can_deref on a scalar value. Such expression should always evaluate to true, so it can be evaluated at compile-time. Note that in theory such case could be detected by the ITIR pass, once ITIR type inference is replaced by a new solution. At that time, the solution proposed here should be removed. --- .../runners/dace_iterator/itir_to_tasklet.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index de18446bbe..4fa5ae239c 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -289,10 +289,25 @@ def builtin_can_deref( assert shift_callable.fun.id == "shift" iterator = transformer._visit_shift(can_deref_callable) - # this iterator is accessing a neighbor table, so it should return an index - assert iterator.dtype in dace.dtypes.INTEGER_TYPES + # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it + if not isinstance(iterator, IteratorExpr): + assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr) + # We can always deref a value expression, therefore hard-code `can_deref` to True. + # Returning a SymbolExpr would be preferable, but it requires update to type-checking. + result_name = unique_var_name() + transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) + result_node = transformer.context.state.add_access(result_name) + transformer.context.state.add_edge( + transformer.context.state.add_tasklet("can_always_deref", {}, {"_out"}, "_out = True"), + "_out", + result_node, + None, + dace.Memlet.simple(result_name, "0"), + ) + return [ValueExpr(result_node, dace.dtypes.bool)] + # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(access_node, iterator.dtype) for access_node in iterator.indices.values()] + args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] expr_code = " and ".join([f"{v} >= 0" for v in internals]) @@ -833,7 +848,7 @@ def _make_shift_for_rest(self, rest, iterator): fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] ) - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -841,6 +856,12 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) else: iterator = self.visit(node.args[0]) + if not isinstance(iterator, IteratorExpr): + # shift cannot be applied because the argument is not iterable + # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it + assert isinstance(iterator, list) and len(iterator) == 1 + assert isinstance(iterator[0], ValueExpr) + return iterator assert isinstance(tail[0], itir.OffsetLiteral) offset_dim = tail[0].value From ebb70b65487d29ef53c1f1fe95d74509c33146aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 4 Dec 2023 12:56:12 +0100 Subject: [PATCH 51/67] fix[next]: Proper calling signature in DaCe (#1374) This commit adds positional arguments to the generated SDFG. It also improves the naming of some automatically generated symbols, such as the shape and stride. --- .../runners/dace_iterator/__init__.py | 31 ++++++++++++++++--- .../runners/dace_iterator/itir_to_sdfg.py | 9 ++++-- .../runners/dace_iterator/utility.py | 15 ++++----- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index d77792664e..735c6b6284 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -185,12 +185,15 @@ def get_cache_id( return m.hexdigest() -def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: +def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Optional[dace.SDFG]: # build parameters auto_optimize = kwargs.get("auto_optimize", False) build_cache = kwargs.get("build_cache", None) build_type = kwargs.get("build_type", "RelWithDebInfo") run_on_gpu = kwargs.get("run_on_gpu", False) + # Return parameter + return_sdfg = kwargs.get("return_sdfg", False) + run_sdfg = kwargs.get("run_sdfg", True) # ITIR parameters column_axis = kwargs.get("column_axis", None) lift_mode = ( @@ -212,6 +215,18 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu) sdfg = sdfg_genenerator.visit(program) + + # All arguments required by the SDFG, regardless if explicit and implicit, are added + # as positional arguments. In the front are all arguments to the Fencil, in that + # order, they are followed by the arguments created by the translation process, + # their order is determined by DaCe and unspecific. + assert len(sdfg.arg_names) == 0 + arg_list = [str(a) for a in program.params] + sig_list = sdfg.signature_arglist(with_types=False) + implicit_args = set(sig_list) - set(arg_list) + call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] + sdfg.arg_names = call_params + sdfg.simplify() # run DaCe auto-optimization heuristics @@ -256,10 +271,16 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: if key in sdfg.signature_arglist(with_types=False) } - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("frontend", "check_args", value=True) - sdfg_program(**expected_args) + if run_sdfg: + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + dace.config.Config.set("frontend", "check_args", value=True) + sdfg_program(**expected_args) + # + + if return_sdfg: + return sdfg + return None def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 271a79c04b..7a6f359771 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -42,6 +42,8 @@ flatten_list, get_sorted_dims, map_nested_sdfg_symbols, + new_array_symbols, + unique_name, unique_var_name, ) @@ -114,10 +116,9 @@ def __init__( def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): - shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] - strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + shape, strides = new_array_symbols(name, len(type_.dims)) offset = ( - [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + [dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))] if has_offset else None ) @@ -130,8 +131,10 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset sdfg.add_array( name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage ) + elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) + else: raise NotImplementedError() self.storage_types[name] = type_ diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 5ae4676cd7..cb14b89e8a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -153,23 +153,20 @@ def add_mapped_nested_sdfg( return nsdfg_node, map_entry, map_exit -_unique_id = 0 - - def unique_name(prefix): - global _unique_id - _unique_id += 1 - return f"{prefix}_{_unique_id}" + unique_id = getattr(unique_name, "_unique_id", 0) # noqa: B010 # static variable + setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 # static variable + return f"{prefix}_{unique_id}" def unique_var_name(): - return unique_name("__var") + return unique_name("_var") def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: dtype = dace.int64 - shape = [dace.symbol(unique_name(f"{name}_shp{i}"), dtype) for i in range(ndim)] - strides = [dace.symbol(unique_name(f"{name}_strd{i}"), dtype) for i in range(ndim)] + shape = [dace.symbol(unique_name(f"{name}_shape{i}"), dtype) for i in range(ndim)] + strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i in range(ndim)] return shape, strides From d7cf10fb31de4e60b33c25a6807e07605d5ecde0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 4 Dec 2023 09:19:33 -0500 Subject: [PATCH 52/67] DaCe 0.15 suffix state struct hotfix (#1382) --- src/gt4py/cartesian/backend/dace_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 11cd1fa895..b1e559a41e 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -451,7 +451,7 @@ class DaCeComputationCodegen: const int __I = domain[0]; const int __J = domain[1]; const int __K = domain[2]; - ${name}_${state_suffix} dace_handle; + ${name}${state_suffix} dace_handle; ${backend_specifics} auto allocator = gt::sid::cached_allocator(&${allocator}); ${"\\n".join(tmp_allocs)} @@ -566,7 +566,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S try: dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") except (KeyError, TypeError): - dace_state_suffix = "t" # old structure name + dace_state_suffix = "_t" # old structure name interface = cls.template.definition.render( name=sdfg.name, From 9f2ed1e41b50bd1d01a2a861999b5b44d6c9114b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 5 Dec 2023 08:07:57 +0100 Subject: [PATCH 53/67] feat[next]: Separates ITIR -> SDFG translation from running (#1379) Before it was only possible to translate ITIR to SDFG and execute it and it was not possible to extract the SDFG. This commits splits this task into two parts and thus allows to perform the ITIR to SDFG translation without executing it. --- .../runners/dace_iterator/__init__.py | 117 +++++++++++------- .../runners/dace_iterator/itir_to_sdfg.py | 10 ++ 2 files changed, 79 insertions(+), 48 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 735c6b6284..34ba2d2d95 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -185,58 +185,85 @@ def get_cache_id( return m.hexdigest() -def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Optional[dace.SDFG]: +def build_sdfg_from_itir( + program: itir.FencilDefinition, + *args, + offset_provider: dict[str, Any], + auto_optimize: bool = False, + on_gpu: bool = False, + column_axis: Optional[Dimension] = None, + lift_mode: LiftMode = LiftMode.FORCE_INLINE, +) -> dace.SDFG: + """Translate a Fencil into an SDFG. + + Args: + program: The Fencil that should be translated. + *args: Arguments for which the fencil should be called. + offset_provider: The set of offset providers that should be used. + auto_optimize: Apply DaCe's `auto_optimize` heuristic. + on_gpu: Performs the translation for GPU, defaults to `False`. + column_axis: The column axis to be used, defaults to `None`. + lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. + + Notes: + Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored. + """ + # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force + # `lift_more` to `FORCE_INLINE` mode. + lift_mode = LiftMode.FORCE_INLINE + + arg_types = [type_translation.from_value(arg) for arg in args] + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + # visit ITIR and generate SDFG + program = preprocess_program(program, offset_provider, lift_mode) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) + sdfg = sdfg_genenerator.visit(program) + sdfg.simplify() + + # run DaCe auto-optimization heuristics + if auto_optimize: + # TODO Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. + symbols: dict[str, int] = {} + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + + return sdfg + + +def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): # build parameters - auto_optimize = kwargs.get("auto_optimize", False) build_cache = kwargs.get("build_cache", None) build_type = kwargs.get("build_type", "RelWithDebInfo") - run_on_gpu = kwargs.get("run_on_gpu", False) - # Return parameter - return_sdfg = kwargs.get("return_sdfg", False) - run_sdfg = kwargs.get("run_sdfg", True) + on_gpu = kwargs.get("on_gpu", False) + auto_optimize = kwargs.get("auto_optimize", False) + lift_mode = kwargs.get("lift_mode", LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) - lift_mode = ( - LiftMode.FORCE_INLINE - ) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) + sdfg: Optional[dace.SDFG] = None if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] sdfg = sdfg_program.sdfg + else: - # visit ITIR and generate SDFG - program = preprocess_program(program, offset_provider, lift_mode) - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu) - sdfg = sdfg_genenerator.visit(program) - - # All arguments required by the SDFG, regardless if explicit and implicit, are added - # as positional arguments. In the front are all arguments to the Fencil, in that - # order, they are followed by the arguments created by the translation process, - # their order is determined by DaCe and unspecific. - assert len(sdfg.arg_names) == 0 - arg_list = [str(a) for a in program.params] - sig_list = sdfg.signature_arglist(with_types=False) - implicit_args = set(sig_list) - set(arg_list) - call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] - sdfg.arg_names = call_params - - sdfg.simplify() - - # run DaCe auto-optimization heuristics - if auto_optimize: - # TODO Investigate how symbol definitions improve autoopt transformations, - # in which case the cache table should take the symbols map into account. - symbols: dict[str, int] = {} - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=run_on_gpu) - - # compile SDFG and retrieve SDFG program + sdfg = build_sdfg_from_itir( + program, + *args, + offset_provider=offset_provider, + auto_optimize=auto_optimize, + on_gpu=on_gpu, + column_axis=column_axis, + lift_mode=lift_mode, + ) + sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) @@ -271,16 +298,10 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Option if key in sdfg.signature_arglist(with_types=False) } - if run_sdfg: - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("frontend", "check_args", value=True) - sdfg_program(**expected_args) - # - - if return_sdfg: - return sdfg - return None + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + dace.config.Config.set("frontend", "check_args", value=True) + sdfg_program(**expected_args) def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: @@ -290,7 +311,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: **kwargs, build_cache=_build_cache_cpu, build_type=_build_type, - run_on_gpu=False, + on_gpu=False, ) @@ -308,7 +329,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: **kwargs, build_cache=_build_cache_gpu, build_type=_build_type, - run_on_gpu=True, + on_gpu=True, ) else: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 7a6f359771..b3e6662623 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -208,6 +208,16 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): access_node = last_state.add_access(inner_name) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) + # Create the call signature for the SDFG. + # All arguments required by the SDFG, regardless if explicit and implicit, are added + # as positional arguments. In the front are all arguments to the Fencil, in that + # order, they are followed by the arguments created by the translation process, + arg_list = [str(a) for a in node.params] + sig_list = program_sdfg.signature_arglist(with_types=False) + implicit_args = set(sig_list) - set(arg_list) + call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] + program_sdfg.arg_names = call_params + program_sdfg.validate() return program_sdfg From c547f536930c039f03a734ccb85f464fe4ba062d Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 5 Dec 2023 09:37:34 +0100 Subject: [PATCH 54/67] fix[next][dace]: Add check on neighbor index (#1383) Most part of this PR is about cleaning up and refactoring the code in DaCe backend for translation of neighbor reduction. As part of the refactoring, the backend is now using the reduce library node from DaCe library. However, this PR also contains one functional change, which is a fix. Neighbor reduction should check for validity of neighbor index. This means for neighbor-tables to check that the neighbor index stored in the table is not -1. For neighbor strided offsets, we should check that the neighbor index does not access the origin field out of boundary. --- .../runners/dace_iterator/itir_to_tasklet.py | 308 ++++++------------ .../ffront_tests/test_gt4py_builtins.py | 30 ++ 2 files changed, 137 insertions(+), 201 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 4fa5ae239c..f6f197859b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -20,7 +20,6 @@ import numpy as np from dace import subsets from dace.transformation.dataflow import MapFusion -from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen from gt4py import eve @@ -167,22 +166,19 @@ class Context: state: dace.SDFGState symbol_map: dict[str, TaskletExpr] # if we encounter a reduction node, the reduction state needs to be pushed to child nodes - reduce_limit: int - reduce_wcr: Optional[str] + reduce_identity: Optional[SymbolExpr] def __init__( self, body: dace.SDFG, state: dace.SDFGState, symbol_map: dict[str, TaskletExpr], - reduce_limit: int = 0, - reduce_wcr: Optional[str] = None, + reduce_identity: Optional[SymbolExpr] = None, ): self.body = body self.state = state self.symbol_map = symbol_map - self.reduce_limit = reduce_limit - self.reduce_wcr = reduce_wcr + self.reduce_identity = reduce_identity def builtin_neighbors( @@ -193,42 +189,53 @@ def builtin_neighbors( offset_dim = offset_literal.value assert isinstance(offset_dim, str) iterator = transformer.visit(data) - table: NeighborTableOffsetProvider = transformer.offset_provider[offset_dim] - assert isinstance(table, NeighborTableOffsetProvider) - - offset = transformer.offset_provider[offset_dim] - if isinstance(offset, Dimension): + assert isinstance(iterator, IteratorExpr) + field_desc = iterator.field.desc(transformer.context.body) + + field_index = "__field_idx" + offset_provider = transformer.offset_provider[offset_dim] + if isinstance(offset_provider, NeighborTableOffsetProvider): + neighbor_check = f"{field_index} >= 0" + elif isinstance(offset_provider, StridedNeighborOffsetProvider): + neighbor_check = f"{field_index} < {field_desc.shape[offset_provider.neighbor_axis.value]}" + else: + assert isinstance(offset_provider, Dimension) raise NotImplementedError( "Neighbor reductions for cartesian grids not implemented in DaCe backend." ) + assert transformer.context.reduce_identity is not None + sdfg: dace.SDFG = transformer.context.body state: dace.SDFGState = transformer.context.state - shifted_dim = table.origin_axis.value + shifted_dim = offset_provider.origin_axis.value result_name = unique_var_name() - sdfg.add_array(result_name, dtype=iterator.dtype, shape=(table.max_neighbors,), transient=True) + sdfg.add_array( + result_name, dtype=iterator.dtype, shape=(offset_provider.max_neighbors,), transient=True + ) result_access = state.add_access(result_name) - table_name = connectivity_identifier(offset_dim) - # generate unique map index name to avoid conflict with other maps inside same state - index_name = unique_name("__neigh_idx") + neighbor_index = unique_name("neighbor_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", - ndrange={index_name: f"0:{table.max_neighbors}"}, + ndrange={neighbor_index: f"0:{offset_provider.max_neighbors}"}, ) + table_name = connectivity_identifier(offset_dim) + table_subset = (f"0:{sdfg.arrays[table_name].shape[0]}", neighbor_index) + shift_tasklet = state.add_tasklet( "shift", - code=f"__result = __table[__idx, {index_name}]", + code="__result = __table[__idx]", inputs={"__table", "__idx"}, outputs={"__result"}, ) data_access_tasklet = state.add_tasklet( "data_access", - code="__result = __field[__idx]", - inputs={"__field", "__idx"}, + code=f"__result = __field[{field_index}] if {neighbor_check} else {transformer.context.reduce_identity.value}", + inputs={"__field", field_index}, outputs={"__result"}, ) idx_name = unique_var_name() @@ -237,7 +244,7 @@ def builtin_neighbors( state.add_access(table_name), me, shift_tasklet, - memlet=create_memlet_full(table_name, sdfg.arrays[table_name]), + memlet=create_memlet_at(table_name, table_subset), dst_conn="__table", ) state.add_memlet_path( @@ -247,17 +254,11 @@ def builtin_neighbors( memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), dst_conn="__idx", ) - state.add_edge( - shift_tasklet, - "__result", - data_access_tasklet, - "__idx", - dace.Memlet.simple(idx_name, "0"), - ) + state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet()) # select full shape only in the neighbor-axis dimension field_subset = tuple( - f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" - for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) + f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) ) state.add_memlet_path( iterator.field, @@ -270,7 +271,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet.simple(result_name, index_name), + memlet=dace.Memlet.simple(result_name, neighbor_index), src_conn="__result", ) @@ -508,14 +509,16 @@ def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): raise NotImplementedError() def visit_Lambda( - self, node: itir.Lambda, args: Sequence[TaskletExpr] + self, node: itir.Lambda, args: Sequence[TaskletExpr], use_neighbor_tables: bool = True ) -> tuple[ Context, list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], list[ValueExpr], ]: func_name = f"lambda_{abs(hash(node)):x}" - neighbor_tables = filter_neighbor_tables(self.offset_provider) + neighbor_tables = ( + filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else [] + ) connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # Create the SDFG for the lambda's body @@ -557,13 +560,12 @@ def visit_Lambda( lambda_sdfg, lambda_state, lambda_symbols_pass.symbol_refs, - reduce_limit=self.context.reduce_limit, - reduce_wcr=self.context.reduce_wcr, + reduce_identity=self.context.reduce_identity, ) lambda_taskgen = PythonTaskletCodegen(self.offset_provider, lambda_context, self.node_types) results: list[ValueExpr] = [] - # We are flattening the returned list of value expressions because the multiple outputs of a lamda + # We are flattening the returned list of value expressions because the multiple outputs of a lambda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. for expr in flatten_list(lambda_taskgen.visit(node.expr)): if isinstance(expr, ValueExpr): @@ -573,8 +575,7 @@ def visit_Lambda( lambda_state.add_nedge( expr.value, result_access, - # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr), + dace.Memlet.simple(result_access.data, "0"), ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: @@ -700,60 +701,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: args: list[ValueExpr] sorted_dims = sorted(iterator.dimensions) - if self.context.reduce_limit: - # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing - result_name = unique_var_name() - self.context.body.add_array( - result_name, - dtype=iterator.dtype, - shape=(self.context.reduce_limit,), - transient=True, - ) - result_access = self.context.state.add_access(result_name) - - # generate unique map index name to avoid conflict with other maps inside same state - index_name = unique_name("__deref_idx") - me, mx = self.context.state.add_map( - "deref_map", - ndrange={index_name: f"0:{self.context.reduce_limit}"}, - ) - - # if dim is not found in iterator indices, we take the neighbor index over the reduction domain - flat_index = [ - f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name - for dim in sorted_dims - ] - args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in iterator.indices - ] - internals = [f"{arg.value.data}_v" for arg in args] - - deref_tasklet = self.context.state.add_tasklet( - name="deref", - inputs=set(internals), - outputs={"__result"}, - code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", - ) - - for arg, internal in zip(args, internals): - input_memlet = create_memlet_full( - arg.value.data, self.context.body.arrays[arg.value.data] - ) - self.context.state.add_memlet_path( - arg.value, me, deref_tasklet, memlet=input_memlet, dst_conn=internal - ) - - self.context.state.add_memlet_path( - deref_tasklet, - mx, - result_access, - memlet=dace.Memlet.simple(result_name, index_name), - src_conn="__result", - ) - - return [ValueExpr(value=result_access, dtype=iterator.dtype)] - - elif all([dim in iterator.indices for dim in iterator.dimensions]): + if all([dim in iterator.indices for dim in iterator.dimensions]): # The deref iterator has index values on all dimensions: the result will be a scalar args = [ValueExpr(iterator.field, iterator.dtype)] + [ ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims @@ -930,8 +878,9 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] def _visit_reduce(self, node: itir.FunCall): - result_name = unique_var_name() - result_access = self.context.state.add_access(result_name) + node_type = self.node_types[id(node)] + assert isinstance(node_type, itir_typing.Val) + reduce_dtype = itir_type_as_dace_type(node_type.dtype) if len(node.args) == 1: assert ( @@ -939,131 +888,70 @@ def _visit_reduce(self, node: itir.FunCall): and isinstance(node.args[0].fun, itir.SymRef) and node.args[0].fun.id == "neighbors" ) - args = self.visit(node.args) - assert len(args) == 1 - args = args[0] - assert len(args) == 1 - neighbors_expr = args[0] - result_dtype = neighbors_expr.dtype assert isinstance(node.fun, itir.FunCall) op_name = node.fun.args[0] assert isinstance(op_name, itir.SymRef) - init = node.fun.args[1] + reduce_identity = node.fun.args[1] + assert isinstance(reduce_identity, itir.Literal) - reduce_array_desc = neighbors_expr.value.desc(self.context.body) + # set reduction state + self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) + + args = self.visit(node.args) + + assert len(args) == 1 and len(args[0]) == 1 + reduce_input_node = args[0][0].value - self.context.body.add_scalar(result_name, result_dtype, transient=True) - op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") - reduce_tasklet = self.context.state.add_tasklet( - "reduce", - code=f"__result = {init}\nfor __idx in range({reduce_array_desc.shape[0]}):\n __result = {op_str}", - inputs={"__values"}, - outputs={"__result"}, - ) - self.context.state.add_edge( - args[0].value, - None, - reduce_tasklet, - "__values", - create_memlet_full(neighbors_expr.value.data, reduce_array_desc), - ) - self.context.state.add_edge( - reduce_tasklet, - "__result", - result_access, - None, - dace.Memlet.simple(result_name, "0"), - ) else: assert isinstance(node.fun, itir.FunCall) assert isinstance(node.fun.args[0], itir.Lambda) fun_node = node.fun.args[0] + assert isinstance(fun_node.expr, itir.FunCall) - args = [] - for node_arg in node.args: - if ( - isinstance(node_arg, itir.FunCall) - and isinstance(node_arg.fun, itir.SymRef) - and node_arg.fun.id == "neighbors" - ): - expr = self.visit(node_arg) - args.append(*expr) - else: - args.append(None) - - # first visit only arguments for neighbor selection, all other arguments are none - neighbor_args = [arg for arg in args if arg] - - # check that all neighbors expression have the same range - assert ( - len( - set([self.context.body.arrays[expr.value.data].shape for expr in neighbor_args]) - ) - == 1 - ) + op_name = fun_node.expr.fun + assert isinstance(op_name, itir.SymRef) + reduce_identity = get_reduce_identity_value(op_name.id, reduce_dtype) - nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0] - nreduce_domain = {"__idx": f"0:{nreduce}"} + # set reduction state in visit context + self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - result_dtype = neighbor_args[0].dtype - self.context.body.add_scalar(result_name, result_dtype, transient=True) + args = flatten_list(self.visit(node.args)) - assert isinstance(fun_node.expr, itir.FunCall) - op_name = fun_node.expr.fun - assert isinstance(op_name, itir.SymRef) + # clear context + self.context.reduce_identity = None - # initialize the reduction result based on type of operation - init_value = get_reduce_identity_value(op_name.id, result_dtype) - init_state = self.context.body.add_state_before(self.context.state, "init", True) - init_tasklet = init_state.add_tasklet( - "init_reduce", {}, {"__out"}, f"__out = {init_value}" - ) - init_state.add_edge( - init_tasklet, - "__out", - init_state.add_access(result_name), - None, - dace.Memlet.simple(result_name, "0"), + # check that all neighbor expressions have the same shape + nreduce_shape = args[1].value.desc(self.context.body).shape + assert all( + [arg.value.desc(self.context.body).shape == nreduce_shape for arg in args[2:]] ) - # set reduction state to enable dereference of neighbors in input fields and to set WCR on reduce tasklet - self.context.reduce_limit = nreduce - self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format( - "x", "y" - ) + nreduce_index = tuple(f"_i{i}" for i in range(len(nreduce_shape))) + nreduce_domain = {idx: f"0:{size}" for idx, size in zip(nreduce_index, nreduce_shape)} - # visit child nodes for input arguments - for i, node_arg in enumerate(node.args): - if not args[i]: - args[i] = self.visit(node_arg)[0] + reduce_input_name = unique_var_name() + self.context.body.add_array( + reduce_input_name, nreduce_shape, reduce_dtype, transient=True + ) lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) - lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) - - # clear context - self.context.reduce_limit = 0 - self.context.reduce_wcr = None - - # the connectivity arrays (neighbor tables) are not needed inside the reduce lambda SDFG - neighbor_tables = filter_neighbor_tables(self.offset_provider) - for conn, _ in neighbor_tables: - var = connectivity_identifier(conn) - lambda_context.body.remove_data(var) - # cleanup symbols previously used for shape and stride of connectivity arrays - p = RemoveUnusedSymbols() - p.apply_pass(lambda_context.body, {}) - - input_memlets = [ - dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) - ] - output_memlet = dace.Memlet.simple(result_name, "0") + lambda_context, inner_inputs, inner_outputs = self.visit( + lambda_node, args=args, use_neighbor_tables=False + ) - input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)} - output_mapping = {inner_outputs[0].value.data: output_memlet} + input_mapping = { + param: create_memlet_at(arg.value.data, nreduce_index) + for (param, _), arg in zip(inner_inputs, args) + } + output_mapping = { + inner_outputs[0].value.data: create_memlet_at(reduce_input_name, nreduce_index) + } symbol_mapping = map_nested_sdfg_symbols( self.context.body, lambda_context.body, input_mapping ) + reduce_input_node = self.context.state.add_access(reduce_input_name) + nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( self.context.state, sdfg=lambda_context.body, @@ -1072,14 +960,32 @@ def _visit_reduce(self, node: itir.FunCall): outputs=output_mapping, symbol_mapping=symbol_mapping, input_nodes={arg.value.data: arg.value for arg in args}, - output_nodes={result_name: result_access}, + output_nodes={reduce_input_name: reduce_input_node}, ) - # we apply map fusion only to the nested-SDFG which is generated for the reduction operator - # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG - self.context.body.apply_transformations_repeated([MapFusion], validate=False) + reduce_input_desc = reduce_input_node.desc(self.context.body) + + result_name = unique_var_name() + # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node + self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) + result_access = self.context.state.add_access(result_name) + + reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") + reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) + self.context.state.add_nedge( + reduce_input_node, + reduce_node, + dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc), + ) + self.context.state.add_nedge( + reduce_node, result_access, dace.Memlet.simple(result_name, "0") + ) + + # we apply map fusion only to the nested-SDFG which is generated for the reduction operator + # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG + self.context.body.apply_transformations_repeated([MapFusion], validate=False) - return [ValueExpr(result_access, result_dtype)] + return [ValueExpr(result_access, reduce_dtype)] def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index bbbac6c139..e8d0c8b163 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,6 +46,16 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + if unstructured_case.backend in [ gtfn.run_gtfn, gtfn.run_gtfn_gpu, @@ -69,6 +79,16 @@ def testee(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -82,6 +102,16 @@ def minover(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: return neighbor_sum(edge_f(V2E), axis=V2EDim) From 8e644585361aac30bf97f753e652117c0884bde5 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 12 Dec 2023 13:04:03 +0100 Subject: [PATCH 55/67] feat[next][dace]: Add support for if expressions with tuple argument (#1393) Some icon4py stencils require support for if expressions with tuple arguments. This PR adds support to the DaCe backend in the visitor of builtin_if function. Additionally, this PR contains one fix in the result of builtin_tuple_get, which should return a list. --- .../runners/dace_iterator/__init__.py | 1 - .../runners/dace_iterator/itir_to_tasklet.py | 42 ++++++++++++++----- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 34ba2d2d95..acfa06b456 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -247,7 +247,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) - sdfg: Optional[dace.SDFG] = None if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index f6f197859b..32b8cbf2b1 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -321,16 +321,36 @@ def builtin_can_deref( def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: - args = [arg for li in transformer.visit(node_args) for arg in li] - expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args + args = transformer.visit(node_args) + assert len(args) == 3 + if_node = args[0][0] if isinstance(args[0], list) else args[0] + + # the argument could be a list of elements on each branch representing the result of `make_tuple` + # however, the normal case is to find one value expression + assert len(args[1]) == len(args[2]) + if_expr_args = [ + (a[0] if isinstance(a, list) else a, b[0] if isinstance(b, list) else b) + for a, b in zip(args[1], args[2]) ] - expr = "({1} if {0} else {2})".format(*internals) - node_type = transformer.node_types[id(node)] - assert isinstance(node_type, itir_typing.Val) - type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(expr_args, expr, type_, "if") + + # in case of tuple arguments, generate one if-tasklet for each element of the output tuple + if_expr_values = [] + for a, b in if_expr_args: + assert a.dtype == b.dtype + expr_args = [ + (arg, f"{arg.value.data}_v") + for arg in (if_node, a, b) + if not isinstance(arg, SymbolExpr) + ] + internals = [ + arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" + for arg in (if_node, a, b) + ] + expr = "({1} if {0} else {2})".format(*internals) + if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if") + if_expr_values.append(if_expr[0]) + + return if_expr_values def builtin_list_get( @@ -356,7 +376,7 @@ def builtin_list_get( def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: - args = [transformer.visit(node_args[0])[0]] + args = transformer.visit(node_args[0]) internals = [f"{arg.value.data}_v" for arg in args] target_type = node_args[1] assert isinstance(target_type, itir.SymRef) @@ -380,7 +400,7 @@ def builtin_tuple_get( elements = transformer.visit(node_args[1]) index = node_args[0] if isinstance(index, itir.Literal): - return elements[int(index.value)] + return [elements[int(index.value)]] raise ValueError("Tuple can only be subscripted with compile-time constants") From a14ad09f6dd3043114238fc820d68621480cfc4e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 12 Dec 2023 13:24:51 +0100 Subject: [PATCH 56/67] feat[next]: Embedded field scan (#1365) Adds the scalar scan operator for embedded field view. --- .gitpod.yml | 2 +- src/gt4py/next/embedded/common.py | 17 ++ src/gt4py/next/embedded/context.py | 4 +- src/gt4py/next/embedded/nd_array_field.py | 8 +- src/gt4py/next/embedded/operators.py | 168 ++++++++++++++++++ src/gt4py/next/ffront/decorator.py | 95 ++++------ src/gt4py/next/field_utils.py | 22 +++ src/gt4py/next/iterator/embedded.py | 19 +- src/gt4py/next/utils.py | 22 ++- tests/next_tests/exclusion_matrices.py | 1 - tests/next_tests/integration_tests/cases.py | 6 +- .../ffront_tests/test_execution.py | 80 +++++++++ .../iterator_tests/test_column_stencil.py | 4 +- .../unit_tests/embedded_tests/test_common.py | 14 +- .../iterator_tests/test_embedded_internals.py | 8 +- 15 files changed, 372 insertions(+), 98 deletions(-) create mode 100644 src/gt4py/next/embedded/operators.py create mode 100644 src/gt4py/next/field_utils.py diff --git a/.gitpod.yml b/.gitpod.yml index 1d579d88eb..802d87796a 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -5,7 +5,7 @@ image: tasks: - name: Setup venv and dev tools init: | - ln -s /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode + ln -sfn /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode python -m venv .venv source .venv/bin/activate pip install --upgrade pip setuptools wheel diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index d796189ab3..558730cb82 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -14,6 +14,10 @@ from __future__ import annotations +import functools +import itertools +import operator + from gt4py.eve.extended_typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -90,6 +94,19 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) +def intersect_domains(*domains: common.Domain) -> common.Domain: + return functools.reduce( + operator.and_, + domains, + common.Domain(dims=tuple(), ranges=tuple()), + ) + + +def iterate_domain(domain: common.Domain): + for i in itertools.product(*[list(r) for r in domain.ranges]): + yield tuple(zip(domain.dims, i)) + + def _expand_ellipsis( indices: common.RelativeIndexSequence, target_size: int ) -> tuple[common.IntIndex | slice, ...]: diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py index 5fbdbc6f25..93942a5959 100644 --- a/src/gt4py/next/embedded/context.py +++ b/src/gt4py/next/embedded/context.py @@ -24,7 +24,7 @@ #: Column range used in column mode (`column_axis != None`) in the current embedded iterator #: closure execution context. -closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range") +closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range") _undefined_offset_provider: common.OffsetProvider = {} @@ -37,7 +37,7 @@ @contextlib.contextmanager def new_context( *, - closure_column_range: range | eve.NothingType = eve.NOTHING, + closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING, offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING, ): import gt4py.next.embedded.context as this_module diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ff6a2ceac7..6b69e8f8cc 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -16,7 +16,6 @@ import dataclasses import functools -import operator from collections.abc import Callable, Sequence from types import ModuleType from typing import ClassVar @@ -49,11 +48,10 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) - domain_intersection = functools.reduce( - operator.and_, - [f.domain for f in fields if common.is_field(f)], - common.Domain(dims=tuple(), ranges=tuple()), + domain_intersection = embedded_common.intersect_domains( + *[f.domain for f in fields if common.is_field(f)] ) + transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] for f in fields: if common.is_field(f): diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py new file mode 100644 index 0000000000..f50ace7687 --- /dev/null +++ b/src/gt4py/next/embedded/operators.py @@ -0,0 +1,168 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar + +from gt4py import eve +from gt4py._core import definitions as core_defs +from gt4py.next import common, constructors, utils +from gt4py.next.embedded import common as embedded_common, context as embedded_context + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +@dataclasses.dataclass(frozen=True) +class EmbeddedOperator(Generic[_R, _P]): + fun: Callable[_P, _R] + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + return self.fun(*args, **kwargs) + + +@dataclasses.dataclass(frozen=True) +class ScanOperator(EmbeddedOperator[_R, _P]): + forward: bool + init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] + axis: common.Dimension + + def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun + scan_range = embedded_context.closure_column_range.get() + assert self.axis == scan_range[0] + scan_axis = scan_range[0] + domain_intersection = _intersect_scan_args(*args, *kwargs.values()) + non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) + + out_domain = common.Domain( + *[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection] + ) + if scan_axis not in out_domain.dims: + # even if the scan dimension is not in the input, we can scan over it + out_domain = common.Domain(*out_domain, (scan_range)) + + res = _construct_scan_array(out_domain)(self.init) + + def scan_loop(hpos): + acc = self.init + for k in scan_range[1] if self.forward else reversed(scan_range[1]): + pos = (*hpos, (scan_axis, k)) + new_args = [_tuple_at(pos, arg) for arg in args] + new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} + acc = self.fun(acc, *new_args, **new_kwargs) + _tuple_assign_value(pos, res, acc) + + if len(non_scan_domain) == 0: + # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop + scan_loop(()) + else: + for hpos in embedded_common.iterate_domain(non_scan_domain): + scan_loop(hpos) + + return res + + +def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): + if "out" in kwargs: + # called from program or direct field_operator as program + offset_provider = kwargs.pop("offset_provider", None) + + new_context_kwargs = {} + if embedded_context.within_context(): + # called from program + assert offset_provider is None + else: + # field_operator as program + new_context_kwargs["offset_provider"] = offset_provider + + out = kwargs.pop("out") + domain = kwargs.pop("domain", None) + + flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) + assert all(f.domain == flattened_out[0].domain for f in flattened_out) + + out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain + + new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) + + with embedded_context.new_context(**new_context_kwargs) as ctx: + res = ctx.run(op, *args, **kwargs) + _tuple_assign_field( + out, + res, + domain=out_domain, + ) + else: + # called from other field_operator + return op(*args, **kwargs) + + +def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: + vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL] + assert len(vertical_dim_filtered) <= 1 + return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING + + +def _tuple_assign_field( + target: tuple[common.MutableField | tuple, ...] | common.MutableField, + source: tuple[common.Field | tuple, ...] | common.Field, + domain: common.Domain, +): + @utils.tree_map + def impl(target: common.MutableField, source: common.Field): + target[domain] = source[domain] + + impl(target, source) + + +def _intersect_scan_args( + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] +) -> common.Domain: + return embedded_common.intersect_domains( + *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] + ) + + +def _construct_scan_array(domain: common.Domain): + @utils.tree_map + def impl(init: core_defs.Scalar) -> common.Field: + return constructors.empty(domain, dtype=type(init)) + + return impl + + +def _tuple_assign_value( + pos: Sequence[common.NamedIndex], + target: common.MutableField | tuple[common.MutableField | tuple, ...], + source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...], +) -> None: + @utils.tree_map + def impl(target: common.MutableField, source: core_defs.Scalar): + target[pos] = source + + impl(target, source) + + +def _tuple_at( + pos: Sequence[common.NamedIndex], + field: common.Field | core_defs.Scalar | tuple[common.Field | core_defs.Scalar | tuple, ...], +) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]: + @utils.tree_map + def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: + res = field[pos] if common.is_field(field) else field + assert core_defs.is_scalar_type(res) + return res + + return impl(field) # type: ignore[return-value] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index e06c651b13..8202cda6f5 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,8 +32,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, common, embedded as next_embedded +from gt4py.next import allocators as next_allocators, embedded as next_embedded from gt4py.next.common import Dimension, DimensionKind, GridType +from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, @@ -550,6 +551,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): definition: Optional[types.FunctionType] = None backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None + operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field(default_factory=dict) @classmethod @@ -586,6 +588,7 @@ def from_function( definition=definition, backend=backend, grid_type=grid_type, + operator_attributes=operator_attributes, ) def __gt_type__(self) -> ts.CallableType: @@ -692,68 +695,38 @@ def __call__( *args, **kwargs, ) -> None: - # TODO(havogt): Don't select mode based on existence of kwargs, - # because now we cannot provide nice error messages. E.g. set context var - # if we are reaching this from a program call. - if "out" in kwargs: - out = kwargs.pop("out") + if not next_embedded.context.within_context() and self.backend is not None: + # non embedded execution offset_provider = kwargs.pop("offset_provider", None) - if self.backend is not None: - # "out" and "offset_provider" -> field_operator as program - # When backend is None, we are in embedded execution and for now - # we disable the program generation since it would involve generating - # Python source code from a PAST node. - args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) - # TODO(tehrengruber): check all offset providers are given - # deduce argument types - arg_types = [] - for arg in args: - arg_types.append(type_translation.from_value(arg)) - kwarg_types = {} - for name, arg in kwargs.items(): - kwarg_types[name] = type_translation.from_value(arg) - - return self.as_program(arg_types, kwarg_types)( - *args, out, offset_provider=offset_provider, **kwargs - ) - else: - # "out" -> field_operator called from program in embedded execution or - # field_operator called directly from Python in embedded execution - domain = kwargs.pop("domain", None) - if not next_embedded.context.within_context(): - # field_operator from Python in embedded execution - with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: - res = ctx.run(self.definition, *args, **kwargs) - else: - # field_operator from program in embedded execution (offset_provicer is already set) - assert ( - offset_provider is None - or next_embedded.context.offset_provider.get() is offset_provider - ) - res = self.definition(*args, **kwargs) - _tuple_assign_field( - out, res, domain=None if domain is None else common.domain(domain) - ) - return + out = kwargs.pop("out") + args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) + # TODO(tehrengruber): check all offset providers are given + # deduce argument types + arg_types = [] + for arg in args: + arg_types.append(type_translation.from_value(arg)) + kwarg_types = {} + for name, arg in kwargs.items(): + kwarg_types[name] = type_translation.from_value(arg) + + return self.as_program(arg_types, kwarg_types)( + *args, out, offset_provider=offset_provider, **kwargs + ) else: - # field_operator called from other field_operator in embedded execution - assert self.backend is None - return self.definition(*args, **kwargs) - - -def _tuple_assign_field( - target: tuple[common.Field | tuple, ...] | common.Field, - source: tuple[common.Field | tuple, ...] | common.Field, - domain: Optional[common.Domain], -): - if isinstance(target, tuple): - if not isinstance(source, tuple): - raise RuntimeError(f"Cannot assign {source} to {target}.") - for t, s in zip(target, source): - _tuple_assign_field(t, s, domain) - else: - domain = domain or target.domain - target[domain] = source[domain] + if self.operator_attributes is not None and any( + has_scan_op_attribute := [ + attribute in self.operator_attributes + for attribute in ["init", "axis", "forward"] + ] + ): + assert all(has_scan_op_attribute) + forward = self.operator_attributes["forward"] + init = self.operator_attributes["init"] + axis = self.operator_attributes["axis"] + op = embedded_operators.ScanOperator(self.definition, forward, init, axis) + else: + op = embedded_operators.EmbeddedOperator(self.definition) + return embedded_operators.field_operator_call(op, args, kwargs) @typing.overload diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py new file mode 100644 index 0000000000..14b7c3838c --- /dev/null +++ b/src/gt4py/next/field_utils.py @@ -0,0 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +from gt4py.next import common, utils + + +@utils.tree_map +def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: + return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b02d6c8d72..b00e53bfd9 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -196,7 +196,7 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: #: Column range used in column mode (`column_axis != None`) in the current closure execution context. -column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range +column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range #: Offset provider dict in the current closure execution context. offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider @@ -211,8 +211,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin): def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 - column_range = column_range_cvar.get() - self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range), data) + column_range: common.NamedRange = column_range_cvar.get() + self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -746,7 +746,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] assert column_range is not None col: list[ @@ -823,7 +823,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range)) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -864,7 +864,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1479,7 +1479,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") @@ -1532,7 +1532,10 @@ def closure( column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] - column_range = column.col_range + column_range = ( + column_axis, + common.UnitRange(column.col_range.start, column.col_range.stop), + ) out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index baae8361c5..ec459906e0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -15,10 +15,6 @@ import functools from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast -import numpy as np - -from gt4py.next import common - class RecursionGuard: """ @@ -57,7 +53,6 @@ def __exit__(self, *exc): _T = TypeVar("_T") - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -66,8 +61,17 @@ def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) +# TODO(havogt): remove flatten duplications in the whole codebase +def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]: + if isinstance(value, tuple): + return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting + else: + return (value,) + + def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: - """Apply `fun` to each entry of (possibly nested) tuples. + """ + Apply `fun` to each entry of (possibly nested) tuples. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) @@ -88,9 +92,3 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: ) # mypy doesn't understand that `args` at this point is of type `_P.args` return impl - - -# TODO(havogt): consider moving to module like `field_utils` -@tree_map -def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: - return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 84287e209f..3c42a180dd 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -130,7 +130,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ] diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 81f216397b..b1e26b40cb 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,7 +28,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common, constructors, utils +from gt4py.next import common, constructors, field_utils from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -436,8 +436,8 @@ def verify( out_comp = out or inout assert out_comp is not None - out_comp_ndarray = utils.asnumpy(out_comp) - ref_ndarray = utils.asnumpy(ref) + out_comp_ndarray = field_utils.asnumpy(out_comp) + ref_ndarray = field_utils.asnumpy(ref) assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7f37b41383..51f853d41d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -292,6 +292,7 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_cartesian_shift @pytest.mark.uses_scan @pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures @@ -802,6 +803,85 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: cases.verify(cartesian_case, simple_scan_operator, (inp1, inp2), out=out, ref=expected) +@pytest.mark.uses_scan +def test_scan_different_domain_in_tuple(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp1_np = np.ones( + ( + i_size + 1, + k_size, + ) + ) # i_size bigger than in the other argument + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp1 = cartesian_case.as_field([IDim, KDim], inp1_np) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + inp1_np[:-1, k] + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo( + inp1: gtx.Field[[IDim, KDim], float], inp2: gtx.Field[[IDim, KDim], float] + ) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, inp1, inp2, out=out, ref=expected) + + +@pytest.mark.uses_scan +def test_scan_tuple_field_scalar_mixed(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + 1.0 + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo(inp1: float, inp2: gtx.Field[[IDim, KDim], float]) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, 1.0, inp2, out=out, ref=expected) + + def test_docstring(cartesian_case): @gtx.field_operator def fieldop_with_docstring(a: cases.IField) -> cases.IField: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index fd571514ac..9ba8eef3a3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -16,7 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import utils +from gt4py.next import field_utils from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset @@ -158,7 +158,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct k_size = 5 inp = inp_function(k_size) - ref = ref_function(utils.asnumpy(inp)) + ref = ref_function(field_utils.asnumpy(inp)) out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32)) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 640ed326bb..de511fdabb 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -19,7 +19,7 @@ from gt4py.next import common from gt4py.next.common import UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions -from gt4py.next.embedded.common import _slice_range, sub_domain +from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain @pytest.mark.parametrize( @@ -135,3 +135,15 @@ def test_sub_domain(domain, index, expected): expected = common.domain(expected) result = sub_domain(domain, index) assert result == expected + + +def test_iterate_domain(): + domain = common.domain({I: 2, J: 3}) + ref = [] + for i in domain[I][1]: + for j in domain[J][1]: + ref.append(((I, i), (J, j))) + + testee = list(iterate_domain(domain)) + + assert testee == ref diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py index 3a35570ca2..9238cd4f7a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py @@ -19,13 +19,14 @@ import numpy as np import pytest +from gt4py.next import common from gt4py.next.iterator import embedded def _run_within_context( func: Callable[[], Any], *, - column_range: Optional[range] = None, + column_range: Optional[common.NamedRange] = None, offset_provider: Optional[embedded.OffsetProvider] = None, ) -> Any: def wrapped_func(): @@ -59,7 +60,10 @@ def test_func(data_a: int, data_b: int): # Setting an invalid column_range here shouldn't affect other contexts embedded.column_range_cvar.set(range(2, 999)) - _run_within_context(lambda: test_func(2, 3), column_range=range(0, 3)) + _run_within_context( + lambda: test_func(2, 3), + column_range=(common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3)), + ) def test_column_ufunc_with_scalar(): From 3f595ffd6206b5bf3344b7288f98ac8e82adba52 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 12 Dec 2023 17:29:32 +0100 Subject: [PATCH 57/67] feat[next][dace]: Fix for broken DaCe test (#1396) Fix for broken DaCe test in baseline: - use `flatten_list` to get `ValueExpr` arguments to numeric builtin function Additionally, enable test for DaCe backend (left-over from PR #1393). --- .../runners/dace_iterator/itir_to_tasklet.py | 4 +--- .../feature_tests/iterator_tests/test_conditional.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 32b8cbf2b1..d10a14a1ee 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -1010,9 +1010,7 @@ def _visit_reduce(self, node: itir.FunCall): def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args: list[SymbolExpr | ValueExpr] = list( - itertools.chain(*[self.visit(arg) for arg in node.args]) - ) + args = flatten_list(self.visit(node.args)) expr_args = [ (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr) ] diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 8536dbea90..db7776b2f4 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -31,7 +31,6 @@ def stencil_conditional(inp): return tuple_get(0, tmp) + tuple_get(1, tmp) -@pytest.mark.uses_tuple_returns def test_conditional_w_tuple(program_processor): program_processor, validate = program_processor From a5b2450e282add00fe90b8cf98cd68d96d42b1ea Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Wed, 13 Dec 2023 11:41:16 +0100 Subject: [PATCH 58/67] style[next]: standardize error messages. (#1386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - add style guide to the coding guidelines - fix existing error messages in next - deal with ensuing qa errorrs / test updates - unshadow one test and fix the code it wasn't testing Co-authored-by: Rico Häuselmann Co-authored-by: Enrique González Paredes --- CODING_GUIDELINES.md | 38 +++++ src/gt4py/_core/definitions.py | 22 ++- src/gt4py/next/allocators.py | 8 +- src/gt4py/next/common.py | 91 ++++++------ src/gt4py/next/constructors.py | 15 +- src/gt4py/next/embedded/common.py | 6 +- src/gt4py/next/embedded/nd_array_field.py | 37 ++--- src/gt4py/next/errors/exceptions.py | 10 +- .../next/ffront/ast_passes/simple_assign.py | 2 +- .../ffront/ast_passes/single_static_assign.py | 2 +- src/gt4py/next/ffront/decorator.py | 29 ++-- src/gt4py/next/ffront/fbuiltins.py | 7 +- src/gt4py/next/ffront/foast_introspection.py | 2 +- .../foast_passes/closure_var_folding.py | 2 +- .../ffront/foast_passes/type_deduction.py | 133 +++++++++--------- src/gt4py/next/ffront/foast_pretty_printer.py | 2 +- src/gt4py/next/ffront/foast_to_itir.py | 10 +- src/gt4py/next/ffront/func_to_foast.py | 23 +-- src/gt4py/next/ffront/func_to_past.py | 4 +- .../next/ffront/past_passes/type_deduction.py | 42 +++--- src/gt4py/next/ffront/past_to_itir.py | 36 ++--- src/gt4py/next/ffront/source_utils.py | 8 +- src/gt4py/next/ffront/type_info.py | 2 +- src/gt4py/next/iterator/dispatcher.py | 2 +- src/gt4py/next/iterator/embedded.py | 26 ++-- src/gt4py/next/iterator/ir.py | 8 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 +- src/gt4py/next/iterator/runtime.py | 2 +- src/gt4py/next/iterator/tracing.py | 6 +- src/gt4py/next/iterator/transforms/cse.py | 4 +- .../next/iterator/transforms/pass_manager.py | 4 +- .../next/iterator/transforms/unroll_reduce.py | 6 +- src/gt4py/next/iterator/type_inference.py | 29 ++-- src/gt4py/next/otf/binding/nanobind.py | 2 +- .../compilation/build_systems/cmake_lists.py | 4 +- src/gt4py/next/otf/compilation/compiler.py | 2 +- src/gt4py/next/otf/stages.py | 2 +- src/gt4py/next/otf/workflow.py | 6 +- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 6 +- .../codegens/gtfn/gtfn_module.py | 6 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 20 +-- .../program_processors/processor_interface.py | 88 ++++++++---- .../runners/dace_iterator/__init__.py | 4 +- .../runners/dace_iterator/itir_to_tasklet.py | 4 +- .../runners/dace_iterator/utility.py | 2 +- .../next/program_processors/runners/gtfn.py | 6 +- src/gt4py/next/type_system/type_info.py | 48 ++++--- .../next/type_system/type_translation.py | 34 ++--- tests/next_tests/integration_tests/cases.py | 16 +-- .../ffront_tests/ffront_test_utils.py | 3 +- .../ffront_tests/test_arg_call_interface.py | 8 +- .../ffront_tests/test_execution.py | 8 +- .../test_math_builtin_execution.py | 2 +- .../ffront_tests/test_math_unary_builtins.py | 4 +- .../ffront_tests/test_program.py | 4 +- .../ffront_tests/test_scalar_if.py | 4 +- .../ffront_tests/test_type_deduction.py | 68 ++++----- .../iterator_tests/test_builtins.py | 2 +- tests/next_tests/unit_tests/conftest.py | 2 +- .../embedded_tests/test_nd_array_field.py | 2 +- .../ffront_tests/test_func_to_foast.py | 14 +- .../ffront_tests/test_func_to_past.py | 18 +-- .../ffront_tests/test_past_to_itir.py | 4 +- .../iterator_tests/test_runtime_domain.py | 2 +- .../test_processor_interface.py | 4 +- .../next_tests/unit_tests/test_allocators.py | 2 +- tests/next_tests/unit_tests/test_common.py | 2 +- .../unit_tests/test_constructors.py | 4 +- .../test_type_translation.py | 2 +- 69 files changed, 571 insertions(+), 458 deletions(-) diff --git a/CODING_GUIDELINES.md b/CODING_GUIDELINES.md index 957df0fb04..9376644064 100644 --- a/CODING_GUIDELINES.md +++ b/CODING_GUIDELINES.md @@ -51,6 +51,44 @@ We deviate from the [Google Python Style Guide][google-style-guide] only in the - Client code (like tests, doctests and examples) should use the above style for public FieldView API - Library code should always import the defining module and use qualified names. +### Error messages + +Error messages should be written as sentences, starting with a capital letter and ending with a period (avoid exclamation marks). Try to be informative without being verbose. Code objects such as 'ClassNames' and 'function_names' should be enclosed in single quotes, and so should string values used for message interpolation. + +Examples: + +```python +raise ValueError(f"Invalid argument 'dimension': should be of type 'Dimension', got '{dimension.type}'.") +``` + +Interpolated integer values do not need double quotes, if they are indicating an amount. Example: + +```python +raise ValueError(f"Invalid number of arguments: expected 3 arguments, got {len(args)}.") +``` + +The double quotes can also be dropped when presenting a sequence of values. In this case the message should be rephrased so the sequence is separated from the text by a colon ':'. + +```python +raise ValueError(f"unexpected keyword arguments: {', '.join(set(kwarg_names} - set(expected_kwarg_names)))}.") +``` + +The message should be kept to one sentence if reasonably possible. Ideally the sentence should be kept short and avoid unneccessary words. Examples: + +```python +# too many sentences +raise ValueError(f"Received an unexpeted number of arguments. Should receive 5 arguments, but got {len(args)}. Please provide the correct number of arguments.") +# better +raise ValueError(f"Wrong number of arguments: expected 5, got {len(args)}.") + +# less extreme +raise TypeError(f"Wrong argument type. Can only accept 'int's, got '{type(arg)}' instead.") +# but can still be improved +raise TypeError(f"Wrong argument type: 'int' expected, got '{type(arg)}'") +``` + +The terseness vs. helpfulness tradeoff should be more in favor of terseness for internal error messages and more in favor of helpfulness for `DSLError` and it's subclassses, where additional sentences are encouraged if they point out likely hidden sources of the problem or common fixes. + ### Docstrings We generate the API documentation automatically from the docstrings using [Sphinx][sphinx] and some extensions such as [Sphinx-autodoc][sphinx-autodoc] and [Sphinx-napoleon][sphinx-napoleon]. These follow the Google Python Style Guide docstring conventions to automatically format the generated documentation. A complete overview can be found here: [Example Google Style Python Docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google). diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 0e6301ae0f..091fa77e3f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -73,17 +73,23 @@ BoolScalar: TypeAlias = Union[bool_, bool] BoolT = TypeVar("BoolT", bound=BoolScalar) -BOOL_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], BoolScalar.__args__) # type: ignore[attr-defined] +BOOL_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], BoolScalar.__args__ # type: ignore[attr-defined] +) IntScalar: TypeAlias = Union[int8, int16, int32, int64, int] IntT = TypeVar("IntT", bound=IntScalar) -INT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], IntScalar.__args__) # type: ignore[attr-defined] +INT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], IntScalar.__args__ # type: ignore[attr-defined] +) UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64] UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar) -UINT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], UnsignedIntScalar.__args__) # type: ignore[attr-defined] +UINT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], UnsignedIntScalar.__args__ # type: ignore[attr-defined] +) IntegralScalar: TypeAlias = Union[IntScalar, UnsignedIntScalar] @@ -93,7 +99,9 @@ FloatingScalar: TypeAlias = Union[float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) -FLOAT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], FloatingScalar.__args__) # type: ignore[attr-defined] +FLOAT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], FloatingScalar.__args__ # type: ignore[attr-defined] +) #: Type alias for all scalar types supported by GT4Py @@ -195,7 +203,7 @@ def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: if issubclass(sc_type, numbers.Complex): return DTypeKind.COMPLEX - raise TypeError("Unknown scalar type kind") + raise TypeError("Unknown scalar type kind.") @dataclasses.dataclass(frozen=True) @@ -491,10 +499,10 @@ def __rtruediv__(self, other: Any) -> NDArrayObject: def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... - def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... def __gt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 58600d8cda..97e83276fe 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -142,7 +142,9 @@ def get_allocator( elif not strict or is_field_allocator(default): return default else: - raise TypeError(f"Object {obj} is neither a field allocator nor a field allocator factory") + raise TypeError( + f"Object '{obj}' is neither a field allocator nor a field allocator factory." + ) @dataclasses.dataclass(frozen=True) @@ -331,7 +333,7 @@ def allocate( """ if device is None and allocator is None: - raise ValueError("No 'device' or 'allocator' specified") + raise ValueError("No 'device' or 'allocator' specified.") actual_allocator = get_allocator(allocator) if actual_allocator is None: assert device is not None # for mypy @@ -339,7 +341,7 @@ def allocate( elif device is None: device = core_defs.Device(actual_allocator.__gt_device_type__, 0) elif device.device_type != actual_allocator.__gt_device_type__: - raise ValueError(f"Device {device} and allocator {actual_allocator} are incompatible") + raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") return actual_allocator.__gt_allocate__( domain=common.domain(domain), diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 7f1ad8c0bb..3e1fe52f31 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -125,7 +125,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: - raise ValueError("UnitRange: step required to be `1`.") + raise ValueError("'UnitRange': step required to be '1'.") new_start = self.start + (start or 0) new_stop = (self.start if stop > 0 else self.stop) + stop return UnitRange(new_start, new_stop) @@ -136,7 +136,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re if 0 <= index < len(self): return self.start + index else: - raise IndexError("UnitRange index out of range") + raise IndexError("'UnitRange' index out of range") def __and__(self, other: Set[int]) -> UnitRange: if isinstance(other, UnitRange): @@ -144,7 +144,9 @@ def __and__(self, other: Set[int]) -> UnitRange: stop = min(self.stop, other.stop) return UnitRange(start, stop) else: - raise NotImplementedError("Can only find the intersection between UnitRange instances.") + raise NotImplementedError( + "Can only find the intersection between 'UnitRange' instances." + ) def __le__(self, other: Set[int]): if isinstance(other, UnitRange): @@ -167,7 +169,7 @@ def __add__(self, other: int | Set[int]) -> UnitRange: ) ) else: - raise NotImplementedError("Can only compute union with int instances.") + raise NotImplementedError("Can only compute union with 'int' instances.") def __sub__(self, other: int | Set[int]) -> UnitRange: if isinstance(other, int): @@ -178,7 +180,7 @@ def __sub__(self, other: int | Set[int]) -> UnitRange: else: return self + (-other) else: - raise NotImplementedError("Can only compute substraction with int instances.") + raise NotImplementedError("Can only compute substraction with 'int' instances.") __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented @@ -199,7 +201,7 @@ def unit_range(r: RangeLike) -> UnitRange: return r if isinstance(r, range): if r.step != 1: - raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") + raise ValueError(f"'UnitRange' requires step size 1, got '{r.step}'.") return UnitRange(r.start, r.stop) # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) # once the related mypy bug (#16358) gets fixed @@ -211,7 +213,7 @@ def unit_range(r: RangeLike) -> UnitRange: return UnitRange(r[0], r[1]) if isinstance(r, core_defs.INTEGRAL_TYPES): return UnitRange(0, cast(core_defs.IntegralScalar, r)) - raise ValueError(f"`{r!r}` cannot be interpreted as `UnitRange`.") + raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") IntIndex: TypeAlias = int | core_defs.IntegralScalar @@ -296,20 +298,20 @@ def __init__( ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: - raise ValueError("Either both none of `dims` and `ranges` must be specified.") + raise ValueError("Either both none of 'dims' and 'ranges' must be specified.") if len(args) > 0: raise ValueError( - "No extra `args` allowed when constructing fomr `dims` and `ranges`." + "No extra 'args' allowed when constructing fomr 'dims' and 'ranges'." ) assert dims is not None and ranges is not None # for mypy if not all(isinstance(dim, Dimension) for dim in dims): raise ValueError( - f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." + f"'dims' argument needs to be a 'tuple[Dimension, ...]', got '{dims}'." ) if not all(isinstance(rng, UnitRange) for rng in ranges): raise ValueError( - f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." + f"'ranges' argument needs to be a 'tuple[UnitRange, ...]', got '{ranges}'." ) if len(dims) != len(ranges): raise ValueError( @@ -320,13 +322,15 @@ def __init__( object.__setattr__(self, "ranges", tuple(ranges)) else: if not all(is_named_range(arg) for arg in args): - raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") + raise ValueError( + f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." + ) dims, ranges = zip(*args) if args else ((), ()) object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) if len(set(self.dims)) != len(self.dims): - raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") + raise NotImplementedError(f"Domain dimensions must be unique, not '{self.dims}'.") def __len__(self) -> int: return len(self.ranges) @@ -365,7 +369,7 @@ def __getitem__( # noqa: F811 # redefine unused index_pos = self.dims.index(index) return self.dims[index_pos], self.ranges[index_pos] except ValueError: - raise KeyError(f"No Dimension of type {index} is present in the Domain.") + raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") else: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") @@ -415,10 +419,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: if isinstance(index, Dimension): dim_index = self.dim_index(index) if dim_index is None: - raise ValueError(f"Dimension {index} not found in Domain.") + raise ValueError(f"Dimension '{index}' not found in Domain.") index = dim_index if not (-len(self.dims) <= index < len(self.dims)): - raise IndexError(f"Index {index} out of bounds for Domain of length {len(self.dims)}.") + raise IndexError( + f"Index '{index}' out of bounds for Domain of length {len(self.dims)}." + ) if index < 0: index += len(self.dims) new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) @@ -462,13 +468,16 @@ def domain(domain_like: DomainLike) -> Domain: if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()): return Domain( dims=tuple(domain_like.keys()), - ranges=tuple(UnitRange(0, s) for s in domain_like.values()), # type: ignore[arg-type] # type of `s` is checked in condition + ranges=tuple( + UnitRange(0, s) # type: ignore[arg-type] # type of `s` is checked in condition + for s in domain_like.values() + ), ) return Domain( dims=tuple(domain_like.keys()), ranges=tuple(unit_range(r) for r in domain_like.values()), ) - raise ValueError(f"`{domain_like}` is not `DomainLike`.") + raise ValueError(f"'{domain_like}' is not 'DomainLike'.") def _broadcast_ranges( @@ -670,7 +679,8 @@ class ConnectivityKind(enum.Flag): @extended_runtime_checkable -class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): # type: ignore[misc] # DimT should be covariant, but break in another place +# type: ignore[misc] # DimT should be covariant, but break in another place +class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod def codomain(self) -> DimT: @@ -690,61 +700,61 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa # Operators def __abs__(self) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __neg__(self) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __invert__(self) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __eq__(self, other: Any) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __ne__(self, other: Any) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def is_connectivity_field( @@ -845,7 +855,7 @@ def __gt_dims__(self) -> tuple[Dimension, ...]: @property def __gt_origin__(self) -> Never: - raise TypeError("CartesianConnectivity does not support this operation") + raise TypeError("'CartesianConnectivity' does not support this operation.") @property def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: @@ -877,7 +887,7 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa if not isinstance(image_range, UnitRange): if image_range[0] != self.codomain: raise ValueError( - f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." ) image_range = image_range[1] @@ -1017,3 +1027,4 @@ def register_builtin_func( @classmethod def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Callable[_P, _R]: return cls._builtin_func_map.get(func, NotImplemented) + return cls._builtin_func_map.get(func, NotImplemented) diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 63fde1cfde..9bb4cf17e5 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -254,12 +254,12 @@ def as_field( domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: raise ValueError( - f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." ) if origin: domain_dims = set(domain) if unknown_dims := set(origin.keys()) - domain_dims: - raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}") + raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}.") else: origin = {} actual_domain = common.domain( @@ -277,7 +277,7 @@ def as_field( # already the correct layout and device. shape = storage_utils.asarray(data).shape if shape != actual_domain.shape: - raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + raise ValueError(f"Cannot construct 'Field' from array of shape '{shape}'.") if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) @@ -334,20 +334,20 @@ def as_connectivity( domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: raise ValueError( - f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." ) actual_domain = common.domain([(d, (0, s)) for d, s in zip(domain, data.shape)]) else: actual_domain = common.domain(cast(common.DomainLike, domain)) if not isinstance(codomain, common.Dimension): - raise ValueError(f"Invalid codomain dimension `{codomain}`") + raise ValueError(f"Invalid codomain dimension '{codomain}'.") # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has # already the correct layout and device. shape = storage_utils.asarray(data).shape if shape != actual_domain.shape: - raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + raise ValueError(f"Cannot construct 'Field' from array of shape '{shape}'.") if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) @@ -356,7 +356,8 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) - buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] # TODO(havogt): consider addin MutableNDArrayObject + # TODO(havogt): consider addin MutableNDArrayObject + buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] connectivity_field = common.connectivity( buffer.ndarray, codomain=codomain, domain=actual_domain ) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 558730cb82..87e0800a10 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -32,7 +32,7 @@ def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Doma if common.is_relative_index_sequence(index_sequence): return _relative_sub_domain(domain, index_sequence) - raise IndexError(f"Unsupported index type: {index}") + raise IndexError(f"Unsupported index type: '{index}'.") def _relative_sub_domain( @@ -42,7 +42,9 @@ def _relative_sub_domain( expanded = _expand_ellipsis(index, len(domain)) if len(domain) < len(expanded): - raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") + raise IndexError( + f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions." + ) expanded += (slice(None),) * (len(domain) - len(expanded)) for (dim, rng), idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 6b69e8f8cc..fbfe64ac42 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -174,7 +174,7 @@ def remap( dim = connectivity.codomain dim_idx = self.domain.dim_index(dim) if dim_idx is None: - raise ValueError(f"Incompatible index field, expected a field with dimension {dim}.") + raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") current_range: common.UnitRange = self.domain[dim_idx][1] new_ranges = connectivity.inverse_image(current_range) @@ -226,7 +226,7 @@ def __setitem__( if common.is_field(value): if not value.domain == target_domain: raise ValueError( - f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + f"Incompatible 'Domain' in assignment. Source domain = '{value.domain}', target domain = '{target_domain}'." ) value = value.ndarray @@ -268,28 +268,28 @@ def __setitem__( def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_and", "logical_and")(self, other) - raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__and__' not implemented for non-'bool' fields.") __rand__ = __and__ def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_or", "logical_or")(self, other) - raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__or__' not implemented for non-'bool' fields.") __ror__ = __or__ def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_xor", "logical_xor")(self, other) - raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__xor__' not implemented for non-'bool' fields.") __rxor__ = __xor__ def __invert__(self) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("invert", "invert")(self) - raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__invert__' not implemented for non-'bool' fields.") def _slice( self, index: common.AnyIndexSpec @@ -322,7 +322,8 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig raise NotImplementedError() @property - def codomain(self) -> common.DimT: # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + def codomain(self) -> common.DimT: return self._codomain @functools.cached_property @@ -378,7 +379,7 @@ def inverse_image( ): # TODO(havogt): cleanup duplication with CartesianConnectivity if image_range[0] != self.codomain: raise ValueError( - f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." ) image_range = image_range[1] @@ -423,7 +424,7 @@ def inverse_image( if non_contiguous_dims: raise ValueError( - f"Restriction generates non-contiguous dimensions {non_contiguous_dims}" + f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'." ) return new_dims @@ -446,8 +447,12 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ # -- Specialized implementations for builtin operations on array fields -- -NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] -NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined] +NdArrayField.register_builtin_func( + fbuiltins.abs, NdArrayField.__abs__ # type: ignore[attr-defined] +) +NdArrayField.register_builtin_func( + fbuiltins.power, NdArrayField.__pow__ # type: ignore[attr-defined] +) # TODO gamma for name in ( @@ -480,7 +485,7 @@ def _builtin_op( if not axis.kind == common.DimensionKind.LOCAL: raise ValueError("Can only reduce local dimensions.") if axis not in field.domain.dims: - raise ValueError(f"Field doesn't have dimension {axis}. Cannot reduce.") + raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.") reduce_dim_index = field.domain.dims.index(axis) new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) return field.__class__.from_array( @@ -547,7 +552,7 @@ def __setitem__( value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` - raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") + raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") common.field.register(jnp.ndarray, JaxArrayField.from_array) @@ -572,7 +577,7 @@ def _builtins_broadcast( ) -> common.Field: # separated for typing reasons if common.is_field(field): return _broadcast(field, new_dimensions) - raise AssertionError("Scalar case not reachable from `fbuiltins.broadcast`.") + raise AssertionError("Scalar case not reachable from 'fbuiltins.broadcast'.") NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) @@ -581,7 +586,7 @@ def _builtins_broadcast( def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdArrayField: if isinstance(field, NdArrayField): return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) - raise AssertionError("This is the NdArrayField implementation of `fbuiltins.astype`.") + raise AssertionError("This is the NdArrayField implementation of 'fbuiltins.astype'.") NdArrayField.register_builtin_func(fbuiltins.astype, _astype) @@ -643,4 +648,4 @@ def _compute_slice( elif common.is_int_index(rng): return rng - domain.ranges[pos].start else: - raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") + raise ValueError(f"Can only use integer or UnitRange ranges, provided type: '{type(rng)}'.") diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index e956858549..081453c023 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -61,7 +61,7 @@ class UnsupportedPythonFeatureError(DSLError): feature: str def __init__(self, location: Optional[SourceLocation], feature: str) -> None: - super().__init__(location, f"unsupported Python syntax: '{feature}'") + super().__init__(location, f"Unsupported Python syntax: '{feature}'.") self.feature = feature @@ -69,7 +69,7 @@ class UndefinedSymbolError(DSLError): sym_name: str def __init__(self, location: Optional[SourceLocation], name: str) -> None: - super().__init__(location, f"name '{name}' is not defined") + super().__init__(location, f"Name '{name}' is not defined.") self.sym_name = name @@ -77,7 +77,7 @@ class MissingAttributeError(DSLError): attr_name: str def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: - super().__init__(location, f"object does not have attribute '{attr_name}'") + super().__init__(location, f"Object does not have attribute '{attr_name}'.") self.attr_name = attr_name @@ -90,7 +90,7 @@ class MissingParameterAnnotationError(TypeError_): param_name: str def __init__(self, location: Optional[SourceLocation], param_name: str) -> None: - super().__init__(location, f"parameter '{param_name}' is missing type annotations") + super().__init__(location, f"Parameter '{param_name}' is missing type annotations.") self.param_name = param_name @@ -100,7 +100,7 @@ class InvalidParameterAnnotationError(TypeError_): def __init__(self, location: Optional[SourceLocation], param_name: str, type_: Any) -> None: super().__init__( - location, f"parameter '{param_name}' has invalid type annotation '{type_}'" + location, f"Parameter '{param_name}' has invalid type annotation '{type_}'." ) self.param_name = param_name self.annotated_type = type_ diff --git a/src/gt4py/next/ffront/ast_passes/simple_assign.py b/src/gt4py/next/ffront/ast_passes/simple_assign.py index e2e6439e37..8b079bb8c1 100644 --- a/src/gt4py/next/ffront/ast_passes/simple_assign.py +++ b/src/gt4py/next/ffront/ast_passes/simple_assign.py @@ -22,7 +22,7 @@ class NodeYielder(ast.NodeTransformer): def apply(cls, node: ast.AST) -> ast.AST: result = list(cls().visit(node)) if len(result) != 1: - raise ValueError("AST was split or lost during the pass. Use `.visit()` instead.") + raise ValueError("AST was split or lost during the pass, use '.visit()' instead.") return result[0] def visit(self, node: ast.AST) -> Iterator[ast.AST]: diff --git a/src/gt4py/next/ffront/ast_passes/single_static_assign.py b/src/gt4py/next/ffront/ast_passes/single_static_assign.py index 4181d7f449..ee1e29a8e8 100644 --- a/src/gt4py/next/ffront/ast_passes/single_static_assign.py +++ b/src/gt4py/next/ffront/ast_passes/single_static_assign.py @@ -65,7 +65,7 @@ class _AssignmentTracker: def define(self, name: str) -> None: if name in self.names(): - raise ValueError(f"Variable {name} is already defined.") + raise ValueError(f"Variable '{name}' is already defined.") # -1 signifies a self._counts[name] = -1 diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 8202cda6f5..4abd8f156a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -88,7 +88,7 @@ def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any raise NotImplementedError( f"Using closure vars with same name but different value " f"across functions is not implemented yet. \n" - f"Collisions: {'`, `'.join(collisions)}" + f"Collisions: '{', '.join(collisions)}'." ) all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) @@ -125,7 +125,7 @@ def is_cartesian_offset(o: FieldOffset): if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED: raise ValueError( - "grid_type == GridType.CARTESIAN was requested, but unstructured `FieldOffset` or local `Dimension` was found." + "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." ) return deduced_grid_type if requested_grid_type is None else requested_grid_type @@ -147,7 +147,7 @@ def _field_constituents_shape_and_dims( elif isinstance(arg_type, ts.ScalarType): yield (None, []) else: - raise ValueError("Expected `FieldType` or `TupleType` thereof.") + raise ValueError("Expected 'FieldType' or 'TupleType' thereof.") # TODO(tehrengruber): Decide if and how programs can call other programs. As a @@ -208,7 +208,7 @@ def __post_init__(self): ] if misnamed_functions: raise RuntimeError( - f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}" + f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}." ) undefined_symbols = [ @@ -218,7 +218,7 @@ def __post_init__(self): ] if undefined_symbols: raise RuntimeError( - f"The following closure variables are undefined: {', '.join(undefined_symbols)}" + f"The following closure variables are undefined: {', '.join(undefined_symbols)}." ) @functools.cached_property @@ -228,7 +228,7 @@ def __gt_allocator__( if self.backend: return self.backend.__gt_allocator__ else: - raise RuntimeError(f"Program {self} does not have a backend set.") + raise RuntimeError(f"Program '{self}' does not have a backend set.") def with_backend(self, backend: ppi.ProgramExecutor) -> Program: return dataclasses.replace(self, backend=backend) @@ -263,7 +263,7 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: """ for key in kwargs.keys(): if all(key != param.id for param in self.past_node.params): - raise TypeError(f"Keyword argument `{key}` is not a valid program parameter.") + raise TypeError(f"Keyword argument '{key}' is not a valid program parameter.") return ProgramWithBoundArgs( bound_args=kwargs, @@ -344,7 +344,7 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise TypeError(f"Invalid argument types in call to `{self.past_node.id}`!") from err + raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) @@ -397,9 +397,10 @@ def _column_axis(self): ] raise TypeError( - "Only `ScanOperator`s defined on the same axis " - + "can be used in a `Program`, but found:\n" + "Only 'ScanOperator's defined on the same axis " + + "can be used in a 'Program', found:\n" + "\n".join(scanops_per_axis_strs) + + "." ) return iter(scanops_per_axis.keys()).__next__() @@ -436,7 +437,7 @@ def _process_args(self, args: tuple, kwargs: dict): # a better error message. for name in self.bound_args.keys(): if name in kwargs: - raise ValueError(f"Parameter `{name}` already set as a bound argument.") + raise ValueError(f"Parameter '{name}' already set as a bound argument.") type_info.accepts_args( new_type, @@ -445,10 +446,10 @@ def _process_args(self, args: tuple, kwargs: dict): raise_exception=True, ) except ValueError as err: - bound_arg_names = ", ".join([f"`{bound_arg}`" for bound_arg in self.bound_args.keys()]) + bound_arg_names = ", ".join([f"'{bound_arg}'" for bound_arg in self.bound_args.keys()]) raise TypeError( - f"Invalid argument types in call to program `{self.past_node.id}` with " - f"bound arguments {bound_arg_names}!" + f"Invalid argument types in call to program '{self.past_node.id}' with " + f"bound arguments '{bound_arg_names}'." ) from err full_args = [*args] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8230e35a35..93f17b1eb8 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -139,13 +139,16 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( - f"Either both or none can be tuple in {true_field=} and {false_field=}." # TODO(havogt) find a strategy to unify parsing and embedded error messages + # TODO(havogt) find a strategy to unify parsing and embedded error messages + f"Either both or none can be tuple in '{true_field=}' and '{false_field=}'." ) if len(true_field) != len(false_field): raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return tuple( + where(mask, t, f) for t, f in zip(true_field, false_field) + ) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) diff --git a/src/gt4py/next/ffront/foast_introspection.py b/src/gt4py/next/ffront/foast_introspection.py index 805df465b8..404b99d1a0 100644 --- a/src/gt4py/next/ffront/foast_introspection.py +++ b/src/gt4py/next/ffront/foast_introspection.py @@ -73,4 +73,4 @@ def deduce_stmt_return_kind(node: foast.Stmt) -> StmtReturnKind: elif isinstance(node, (foast.Assign, foast.TupleTargetAssign)): return StmtReturnKind.NO_RETURN else: - raise AssertionError(f"Statements of type `{type(node).__name__}` not understood.") + raise AssertionError(f"Statements of type '{type(node).__name__}' not understood.") diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 9afd22de2c..0561a80659 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -56,7 +56,7 @@ def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) raise errors.MissingAttributeError(node.location, node.attr) - raise errors.DSLError(node.location, "attribute access only applicable to constants") + raise errors.DSLError(node.location, "Attribute access only applicable to constants.") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 95c9128f87..639e5ff009 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -53,7 +53,7 @@ def with_altered_scalar_kind( elif isinstance(type_spec, ts.ScalarType): return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) else: - raise ValueError(f"Expected field or scalar type, but got {type_spec}.") + raise ValueError(f"Expected field or scalar type, got '{type_spec}'.") def construct_tuple_type( @@ -113,7 +113,9 @@ def promote_to_mask_type( item in input_type.dims for item in mask_type.dims ): return_dtype = input_type.dtype if isinstance(input_type, ts.FieldType) else input_type - return type_info.promote(input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype)) # type: ignore + return type_info.promote( + input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype) + ) # type: ignore else: return input_type @@ -148,7 +150,7 @@ def deduce_stmt_return_type( else: raise errors.DSLError( stmt.location, - f"If statement contains return statements with inconsistent types:" + "If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", ) return_type = return_types[0] or return_types[1] @@ -160,12 +162,12 @@ def deduce_stmt_return_type( elif isinstance(stmt, (foast.Assign, foast.TupleTargetAssign)): return_type = None else: - raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") + raise AssertionError(f"Nodes of type '{type(stmt).__name__}' not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: raise errors.DSLError( stmt.location, - f"If statement contains return statements with inconsistent types:" + "If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", ) @@ -179,7 +181,7 @@ def deduce_stmt_return_type( # If the node was constructed by the foast parsing we should never get here, but instead # we should have gotten an error there. raise AssertionError( - "Malformed block statement. Expected a return statement in this context, " + "Malformed block statement: expected a return statement in this context, " "but none was found. Please submit a bug report." ) @@ -195,7 +197,7 @@ def apply(cls, node: foast.LocatedNode) -> None: cls().visit(node, incomplete_nodes=incomplete_nodes) if incomplete_nodes: - raise AssertionError("FOAST expression is not fully typed.") + raise AssertionError("'FOAST' expression is not fully typed.") def visit_LocatedNode( self, node: foast.LocatedNode, *, incomplete_nodes: list[foast.LocatedNode] @@ -251,7 +253,7 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): raise errors.DSLError( node.location, - f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", + f"Function must return 'DataType', 'DeferredType', or 'VoidType', got '{return_type}'.", ) new_type = ts.FunctionType( pos_only_args=[], @@ -283,17 +285,17 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp if not isinstance(new_axis.type, ts.DimensionType): raise errors.DSLError( node.location, - f"Argument `axis` to scan operator `{node.id}` must be a dimension.", + f"Argument 'axis' to scan operator '{node.id}' must be a dimension.", ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: raise errors.DSLError( node.location, - f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", + f"Argument 'axis' to scan operator '{node.id}' must be a vertical dimension.", ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: raise errors.DSLError( - node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." + node.location, f"Argument 'forward' to scan operator '{node.id}' must be a boolean." ) new_init = self.visit(node.init, **kwargs) if not all( @@ -302,8 +304,8 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp ): raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must " - f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", + f"Argument 'init' to scan operator '{node.id}' must " + "be an arithmetic type or a logical type or a composite of arithmetic and logical types.", ) new_definition = self.visit(node.definition, **kwargs) new_def_type = new_definition.type @@ -311,15 +313,15 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp if new_init.type != new_def_type.returns: raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must have same type as its return. " - f"Expected `{new_def_type.returns}`, but got `{new_init.type}`", + f"Argument 'init' to scan operator '{node.id}' must have same type as its return: " + f"expected '{new_def_type.returns}', got '{new_init.type}'.", ) elif new_init.type != carry_type: carry_arg_name = list(new_def_type.pos_or_kw_args.keys())[0] raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must have same type as `{carry_arg_name}` argument. " - f"Expected `{carry_type}`, but got `{new_init.type}`", + f"Argument 'init' to scan operator '{node.id}' must have same type as '{carry_arg_name}' argument: " + f"expected '{carry_type}', got '{new_init.type}'.", ) new_type = ts_ffront.ScanOperatorType( @@ -339,7 +341,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise errors.DSLError(node.location, f"Undeclared symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared symbol '{node.id}'.") symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -362,9 +364,9 @@ def visit_TupleTargetAssign( targets: TargetType = node.targets indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) - if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: + if not any(isinstance(i, tuple) for i in indices) and len(targets) != num_elts: raise errors.DSLError( - node.location, f"Too many values to unpack (expected {len(indices)})." + node.location, f"Too many values to unpack (expected {len(targets)})." ) new_targets: TargetType = [] @@ -396,7 +398,7 @@ def visit_TupleTargetAssign( new_targets.append(new_target) else: raise errors.DSLError( - node.location, f"Assignment value must be of type tuple! Got: {values.type}" + node.location, f"Assignment value must be of type tuple, got '{values.type}'." ) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -416,15 +418,14 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: if not isinstance(new_node.condition.type, ts.ScalarType): raise errors.DSLError( node.location, - "Condition for `if` must be scalar. " - f"But got `{new_node.condition.type}` instead.", + "Condition for 'if' must be scalar, " f"got '{new_node.condition.type}' instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: raise errors.DSLError( node.location, - "Condition for `if` must be of boolean type. " - f"But got `{new_node.condition.type}` instead.", + "Condition for 'if' must be of boolean type, " + f"got '{new_node.condition.type}' instead.", ) for sym in node.annex.propagated_symbols.keys(): @@ -433,8 +434,8 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ): raise errors.DSLError( node.location, - f"Inconsistent types between two branches for variable `{sym}`. " - f"Got types `{true_type}` and `{false_type}.", + f"Inconsistent types between two branches for variable '{sym}': " + f"got types '{true_type}' and '{false_type}.", ) # TODO: properly patch symtable (new node?) symtable[sym].type = new_node.annex.propagated_symbols[ @@ -455,8 +456,8 @@ def visit_Symbol( raise errors.DSLError( node.location, ( - "type inconsistency: expression was deduced to be " - f"of type {refine_type}, instead of the expected type {node.type}" + "Type inconsistency: expression was deduced to be " + f"of type '{refine_type}', instead of the expected type '{node.type}'." ), ) new_node: foast.Symbol = foast.Symbol( @@ -490,7 +491,7 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = new_value.type case _: raise errors.DSLError( - new_value.location, "Could not deduce type of subscript expression!" + new_value.location, "Could not deduce type of subscript expression." ) return foast.Subscript( @@ -531,13 +532,13 @@ def _deduce_ternaryexpr_type( if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): raise errors.DSLError( condition.location, - f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", + f"Condition is of type '{condition.type}', should be of type 'bool'.", ) if true_expr.type != false_expr.type: raise errors.DSLError( node.location, - f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", + f"Left and right types are not the same: '{true_expr.type}' and '{false_expr.type}'", ) return true_expr.type @@ -556,7 +557,7 @@ def _deduce_compare_type( for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -571,8 +572,8 @@ def _deduce_compare_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left.type}` and `{right.type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left.type}' and '{right.type}' to common type" + f" in call to '{node.op}'.", ) from ex def _deduce_binop_type( @@ -594,7 +595,7 @@ def _deduce_binop_type( for arg in (left, right): if not is_compatible(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) left_type = cast(ts.FieldType | ts.ScalarType, left.type) @@ -608,7 +609,7 @@ def _deduce_binop_type( ): raise errors.DSLError( arg.location, - f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: @@ -616,8 +617,8 @@ def _deduce_binop_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left_type}` and `{right_type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left_type}' and '{right_type}' to common type" + f" in call to '{node.op}'.", ) from ex def _check_operand_dtypes_match( @@ -627,7 +628,7 @@ def _check_operand_dtypes_match( if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): raise errors.DSLError( node.location, - f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", + f"Incompatible datatypes in operator '{node.op}': '{left.type}' and '{right.type}'.", ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: @@ -644,7 +645,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: if not is_compatible(new_operand.type): raise errors.DSLError( node.location, - f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", + f"Incompatible type for unary operator '{node.op}': '{new_operand.type}'.", ) return foast.UnaryOp( op=node.op, operand=new_operand, location=node.location, type=new_operand.type @@ -674,13 +675,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise errors.DSLError(node.location, "Functions can only be called directly!") + raise errors.DSLError(node.location, "Functions can only be called directly.") elif isinstance(new_func.type, ts.FieldType): pass else: raise errors.DSLError( node.location, - f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", + f"Expression of type '{new_func.type}' is not callable, must be a 'Function', 'FieldOperator', 'ScanOperator' or 'Field'.", ) # ensure signature is valid @@ -693,7 +694,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: ) except ValueError as err: raise errors.DSLError( - node.location, f"Invalid argument types in call to `{new_func}`!" + node.location, f"Invalid argument types in call to '{new_func}'." ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) @@ -727,7 +728,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: func_name = cast(foast.Name, node.func).id # validate arguments - error_msg_preamble = f"Incompatible argument in call to `{func_name}`." + error_msg_preamble = f"Incompatible argument in call to '{func_name}'." error_msg_for_validator = { type_info.is_arithmetic: "an arithmetic", type_info.is_floating_point: "a floating point", @@ -741,13 +742,13 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: elif func_name in fbuiltins.BINARY_MATH_NUMBER_BUILTIN_NAMES: arg_validator = type_info.is_arithmetic else: - raise AssertionError(f"Unknown math builtin `{func_name}`.") + raise AssertionError(f"Unknown math builtin '{func_name}'.") error_msgs = [] for i, arg in enumerate(node.args): if not arg_validator(arg.type): error_msgs.append( - f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." + f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, got '{arg.type}'." ) if error_msgs: raise errors.DSLError( @@ -756,7 +757,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: ) if func_name == "power" and all(type_info.is_integral(arg.type) for arg in node.args): - print(f"Warning: return type of {func_name} might be inconsistent (not implemented).") + print(f"Warning: return type of '{func_name}' might be inconsistent (not implemented).") # deduce return type return_type: Optional[ts.FieldType | ts.ScalarType] = None @@ -777,7 +778,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: except ValueError as ex: raise errors.DSLError(node.location, error_msg_preamble) from ex else: - raise AssertionError(f"Unknown math builtin `{func_name}`.") + raise AssertionError(f"Unknown math builtin '{func_name}'.") return foast.Call( func=node.func, @@ -796,9 +797,9 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) raise errors.DSLError( node.location, - f"Incompatible field argument in call to `{str(node.func)}`. " - f"Expected a field with dimension {reduction_dim}, but got " - f"{field_dims_str}.", + f"Incompatible field argument in call to '{str(node.func)}'. " + f"Expected a field with dimension '{reduction_dim}', got " + f"'{field_dims_str}'.", ) return_type = ts.FieldType( dims=[dim for dim in field_type.dims if dim != reduction_dim], @@ -834,7 +835,7 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: ]: raise errors.DSLError( node.location, - f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", + f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.", ) return_type = type_info.apply_to_primitive_constituents( @@ -860,16 +861,16 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_integral(arg_1): raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. " - f"Excepted integer for offset field dtype, but got {arg_1.dtype}" + f"Incompatible argument in call to '{str(node.func)}': " + f"expected integer for offset field dtype, got '{arg_1.dtype}'. " f"{node.location}", ) if arg_0.source not in arg_1.dims: raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. " - f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " + f"Incompatible argument in call to '{str(node.func)}': " + f"'{arg_0.source}' not in list of offset field dimensions '{arg_1.dims}'. " f"{node.location}", ) @@ -889,8 +890,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_logical(mask_type): raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. Expected " - f"a field with dtype `bool`, but got `{mask_type}`.", + f"Incompatible argument in call to '{str(node.func)}': expected " + f"a field with dtype 'bool', got '{mask_type}'.", ) try: @@ -907,8 +908,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: ): raise errors.DSLError( node.location, - f"Return arguments need to be of same type in {str(node.func)}, but got: " - f"{node.args[1].type} and {node.args[2].type}", + f"Return arguments need to be of same type in '{str(node.func)}', got " + f"'{node.args[1].type}' and '{node.args[2].type}'.", ) else: true_branch_fieldtype = cast(ts.FieldType, true_branch_type) @@ -919,7 +920,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: except ValueError as ex: raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`.", + f"Incompatible argument in call to '{str(node.func)}'.", ) from ex return foast.Call( @@ -937,8 +938,8 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): raise errors.DSLError( node.location, - f"Incompatible broadcast dimension type in {str(node.func)}. Expected " - f"all broadcast dimensions to be of type Dimension.", + f"Incompatible broadcast dimension type in '{str(node.func)}': expected " + f"all broadcast dimensions to be of type 'Dimension'.", ) broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] @@ -946,8 +947,8 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): raise errors.DSLError( node.location, - f"Incompatible broadcast dimensions in {str(node.func)}. Expected " - f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", + f"Incompatible broadcast dimensions in '{str(node.func)}': expected " + f"broadcast dimension(s) '{set(arg_dims).difference(set(broadcast_dims))}' missing", ) return_type = ts.FieldType( diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 3b81c85265..9275cdda95 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -110,7 +110,7 @@ def apply(cls, node: foast.LocatedNode, **kwargs) -> str: # type: ignore[overri node_type_name = type(node).__name__ if not hasattr(cls, node_type_name) and not hasattr(cls, f"visit_{node_type_name}"): raise NotImplementedError( - f"Pretty printer does not support nodes of type " f"`{node_type_name}`." + f"Pretty printer does not support nodes of type '{node_type_name}'." ) return cls().visit(node, **kwargs) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 3030c03fd1..c4d518d279 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -230,7 +230,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: dtype = type_info.extract_dtype(node.type) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"{node.op} is only supported on `bool`s.") + raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") return self._map("not_", node.operand) return self._map( @@ -313,7 +313,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: return im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) raise AssertionError( - f"Call to object of type {type(node.func.type).__name__} not understood." + f"Call to object of type '{type(node.func.type).__name__}' not understood." ) def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -371,7 +371,9 @@ def _visit_type_constr(self, node: foast.Call, **kwargs) -> itir.Expr: im.literal(str(bool(source_type(node.args[0].value))), "bool") ) return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) - raise FieldOperatorLoweringError(f"Encountered a type cast, which is not supported: {node}") + raise FieldOperatorLoweringError( + f"Encountered a type cast, which is not supported: {node}." + ) def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; @@ -388,7 +390,7 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: elif isinstance(type_, ts.ScalarType): typename = type_.kind.name.lower() return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type {type_}.") + raise ValueError(f"Unsupported literal type '{type_}'.") def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr: return self._make_literal(node.value, node.type) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index c7c4c3a23f..0fd263308e 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -107,8 +107,9 @@ def _postprocess_dialect_ast( if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented raise errors.DSLError( foast_node.location, - f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - f", but got `{annotated_return_type}`.", + "Annotated return type does not match deduced return type: expected " + f"'{foast_node.type.returns}'" # type: ignore[union-attr] # revisit when 'type_info.return_type' is implemented + f", got '{annotated_return_type}'.", ) return foast_node @@ -167,7 +168,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise errors.DSLError(loc, "Function is expected to return a value.") + raise errors.DSLError(loc, "'Function' is expected to return a value.") return foast.FunctionDefinition( id=node.name, @@ -224,7 +225,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple ) if not isinstance(target, ast.Name): - raise errors.DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "Can only assign to names.") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -246,7 +247,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise errors.DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "Can only assign to names.") if node.annotation is not None: assert isinstance( @@ -281,14 +282,14 @@ def _match_index(node: ast.expr) -> int: return -node.operand.value if isinstance(node.op, ast.UAdd): return node.operand.value - raise ValueError(f"Not an index: {node}") + raise ValueError(f"Not an index: '{node}'.") def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: raise errors.DSLError( - self.get_location(node.slice), "expected an integral index" + self.get_location(node.slice), "eXpected an integral index." ) from None return foast.Subscript( @@ -310,7 +311,7 @@ def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: loc = self.get_location(node) if not node.value: - raise errors.DSLError(loc, "must return a value, not None") + raise errors.DSLError(loc, "Must return a value, not None") return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: @@ -442,11 +443,11 @@ def _verify_builtin_type_constructor(self, node: ast.Call): if len(node.args) > 0 and not isinstance(node.args[0], ast.Constant): raise errors.DSLError( self.get_location(node), - f"{self._func_name(node)}() only takes literal arguments!", + f"'{self._func_name(node)}()' only takes literal arguments.", ) def _func_name(self, node: ast.Call) -> str: - return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. + return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call: # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? @@ -468,7 +469,7 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: type_ = type_translation.from_value(node.value) except ValueError: raise errors.DSLError( - loc, f"constants of type {type(node.value)} are not permitted" + loc, f"Constants of type {type(node.value)} are not permitted." ) from None return foast.Constant( diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 7b04e90902..5b4dd934b9 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -129,7 +129,7 @@ def visit_Call(self, node: ast.Call) -> past.Call: new_func = self.visit(node.func) if not isinstance(new_func, past.Name): raise errors.DSLError( - loc, "functions must be referenced by their name in function calls" + loc, "Functions must be referenced by their name in function calls." ) return past.Call( @@ -166,7 +166,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) return past.Constant(value=-node.operand.value, type=symbol_type, location=loc) - raise errors.DSLError(loc, "unary operators are only applicable to literals") + raise errors.DSLError(loc, "Unary operators are only applicable to literals.") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index ed3bdae3ff..fc353d64e4 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -33,7 +33,7 @@ def _ensure_no_sliced_field(entry: past.Expr): For example, if argument is of type past.Subscript, this function will throw an error as both slicing and domain are being applied """ if not isinstance(entry, past.Name) and not isinstance(entry, past.TupleExpr): - raise ValueError("Either only domain or slicing allowed") + raise ValueError("Either only domain or slicing allowed.") elif isinstance(entry, past.TupleExpr): for param in entry.elts: _ensure_no_sliced_field(param) @@ -57,20 +57,18 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), ): raise ValueError( - f"Only calls `FieldOperator`s and `ScanOperator`s " - f"allowed in `Program`, but got `{new_func.type}`." + f"Only calls to 'FieldOperators' and 'ScanOperators' " + f"allowed in 'Program', got '{new_func.type}'." ) if "out" not in new_kwargs: - raise ValueError("Missing required keyword argument(s) `out`.") + raise ValueError("Missing required keyword argument 'out'.") if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) domain_kwarg = new_kwargs["domain"] if not isinstance(domain_kwarg, past.Dict): - raise ValueError( - f"Only Dictionaries allowed in domain, but got `{type(domain_kwarg)}`." - ) + raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(domain_kwarg)}'.") if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0: raise ValueError("Empty domain not allowed.") @@ -78,18 +76,18 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): for dim in domain_kwarg.keys_: if not isinstance(dim.type, ts.DimensionType): raise ValueError( - f"Only Dimension allowed in domain dictionary keys, but got `{dim}` which is of type `{dim.type}`." + f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." ) for domain_values in domain_kwarg.values_: if len(domain_values.elts) != 2: raise ValueError( - f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`." + f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." ) if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( domain_values.elts[1] ): raise ValueError( - f"Only integer values allowed in domain range, but got {domain_values.elts[0].type} and {domain_values.elts[1].type}." + f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." ) @@ -149,7 +147,7 @@ def _deduce_binop_type( for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) left_type = cast(ts.ScalarType, left.type) @@ -163,7 +161,7 @@ def _deduce_binop_type( ): raise errors.DSLError( arg.location, - f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: @@ -171,8 +169,8 @@ def _deduce_binop_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left_type}` and `{right_type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left_type}' and '{right_type}' to common type" + f" in call to '{node.op}'.", ) from ex def visit_BinOp(self, node: past.BinOp, **kwargs) -> past.BinOp: @@ -214,24 +212,24 @@ def visit_Call(self, node: past.Call, **kwargs): ) if operator_return_type != new_kwargs["out"].type: raise ValueError( - f"Expected keyword argument `out` to be of " - f"type {operator_return_type}, but got " - f"{new_kwargs['out'].type}." + "Expected keyword argument 'out' to be of " + f"type '{operator_return_type}', got " + f"'{new_kwargs['out'].type}'." ) elif new_func.id in ["minimum", "maximum"]: if new_args[0].type != new_args[1].type: raise ValueError( - f"First and second argument in {new_func.id} must be the same type." - f"Got `{new_args[0].type}` and `{new_args[1].type}`." + f"First and second argument in '{new_func.id}' must be of the same type." + f"Got '{new_args[0].type}' and '{new_args[1].type}'." ) return_type = new_args[0].type else: raise AssertionError( - "Only calls `FieldOperator`s, `ScanOperator`s or minimum and maximum builtins allowed" + "Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed." ) except ValueError as ex: - raise errors.DSLError(node.location, f"Invalid call to `{node.func.id}`.") from ex + raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex return past.Call( func=new_func, @@ -244,6 +242,6 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise errors.DSLError(node.location, f"Undeclared or untyped symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared or untyped symbol '{node.id}'.") return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 2c5dfc6e2f..709912077b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -37,7 +37,7 @@ def _flatten_tuple_expr( for e in node.elts: result.extend(_flatten_tuple_expr(e)) return result - raise ValueError("Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed.") + raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -174,7 +174,7 @@ def _visit_slice_bound( else: lowered_bound = self.visit(slice_bound, **kwargs) else: - raise AssertionError("Expected `None` or `past.Constant`.") + raise AssertionError("Expected 'None' or 'past.Constant'.") return lowered_bound def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: @@ -189,8 +189,8 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: ) else: raise ValueError( - "Unexpected `out` argument. Must be a `past.Name`, `past.Subscript`" - " or a `past.TupleExpr` thereof." + "Unexpected 'out' argument. Must be a 'past.Name', 'past.Subscript'" + " or a 'past.TupleExpr' thereof." ) def _construct_itir_domain_arg( @@ -209,9 +209,9 @@ def _construct_itir_domain_arg( for out_field_type in out_field_types ): raise AssertionError( - f"Expected constituents of `{out_field.id}` argument to be" - f" fields defined on the same dimensions. This error should be " - f" caught in type deduction already." + f"Expected constituents of '{out_field.id}' argument to be" + " fields defined on the same dimensions. This error should be " + " caught in type deduction already." ) for dim_i, dim in enumerate(out_dims): @@ -232,7 +232,7 @@ def _construct_itir_domain_arg( ) if dim.kind == DimensionKind.LOCAL: - raise ValueError(f"Dimension {dim.value} must not be local.") + raise ValueError(f"Dimension '{dim.value}' must not be local.") domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), @@ -259,8 +259,8 @@ def _construct_itir_initialized_domain_arg( keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim if keys_dims_types != dim: raise ValueError( - f"Dimensions in out field and field domain are not equivalent" - f"Expected {dim}, but got {keys_dims_types} " + "Dimensions in out field and field domain are not equivalent:" + f"expected '{dim}', got '{keys_dims_types}'." ) return [self.visit(bound) for bound in node_domain.values_[dim_i].elts] @@ -277,13 +277,13 @@ def _compute_field_slice(node: past.Subscript): out_field_slice_ = [node.slice_] else: raise AssertionError( - "Unexpected `out` argument. Must be tuple of slices or slice expression." + "Unexpected 'out' argument, must be tuple of slices or slice expression." ) node_dims_ls = cast(ts.FieldType, node.type).dims assert isinstance(node_dims_ls, list) if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls): raise ValueError( - f"Too many indices for field {out_field_name}: field is {len(node_dims_ls)}" + f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}" f"-dimensional, but {len(out_field_slice_)} were indexed." ) return out_field_slice_ @@ -321,7 +321,11 @@ def _visit_stencil_call_out_arg( isinstance(field, past.Subscript) for field in flattened ), "Incompatible field in tuple: either all fields or no field must be sliced." assert all( - concepts.eq_nonlocated(first_field.slice_, field.slice_) for field in flattened # type: ignore[union-attr] # mypy cannot deduce type + concepts.eq_nonlocated( + first_field.slice_, + field.slice_, # type: ignore[union-attr] # mypy cannot deduce type + ) + for field in flattened ), "Incompatible field in tuple: all fields must be sliced in the same way." field_slice = self._compute_field_slice(first_field) first_field = first_field.value @@ -332,7 +336,7 @@ def _visit_stencil_call_out_arg( ) else: raise AssertionError( - "Unexpected `out` argument. Must be a `past.Subscript`, `past.Name` or `past.TupleExpr` node." + "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." ) def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal: @@ -340,7 +344,7 @@ def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal: match node.type.kind: case ts.ScalarKind.STRING: raise NotImplementedError( - f"Scalars of kind {node.type.kind} not supported currently." + f"Scalars of kind '{node.type.kind}' not supported currently." ) typename = node.type.kind.name.lower() return itir.Literal(value=str(node.value), type=typename) @@ -373,5 +377,5 @@ def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: ) else: raise AssertionError( - "Only `minimum` and `maximum` builtins supported supported currently." + "Only 'minimum' and 'maximum' builtins supported supported currently." ) diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index 17b2050b1b..baf3037d5e 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -37,7 +37,7 @@ def make_source_definition_from_function(func: Callable) -> SourceDefinition: filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) if not filename: raise ValueError( - "Can not create field operator from a function that is not in a source file!" + "Can not create field operator from a function that is not in a source file." ) source_lines, line_offset = inspect.getsourcelines(func) source_code = textwrap.dedent(inspect.getsource(func)) @@ -47,7 +47,7 @@ def make_source_definition_from_function(func: Callable) -> SourceDefinition: return SourceDefinition(source_code, filename, line_offset - 1, column_offset) except OSError as err: - raise ValueError(f"Can not get source code of passed function ({func})") from err + raise ValueError(f"Can not get source code of passed function '{func}'.") from err def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) -> SymbolNames: @@ -55,13 +55,13 @@ def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) mod_st = symtable.symtable(source, filename, "exec") except SyntaxError as err: raise ValueError( - f"Unexpected error when parsing provided source code (\n{source}\n)" + f"Unexpected error when parsing provided source code: \n{source}\n" ) from err assert mod_st.get_type() == "module" if len(children := mod_st.get_children()) != 1: raise ValueError( - f"Sources with multiple function definitions are not yet supported (\n{source}\n)" + f"Sources with multiple function definitions are not yet supported: \n{source}\n" ) assert children[0].get_type() == "function" diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 7f56f5d92b..affae8fbca 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -51,7 +51,7 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec: if type_info.extract_dtype(param_el) == type_info.extract_dtype(arg_el): return param_el else: - raise ValueError(f"{arg_el} is not compatible with {param_el}.") + raise ValueError(f"'{arg_el}' is not compatible with '{param_el}'.") return arg_el return type_info.apply_to_primitive_constituents(arg, _as_field, with_path_arg=True) diff --git a/src/gt4py/next/iterator/dispatcher.py b/src/gt4py/next/iterator/dispatcher.py index b2ca39df04..626c51ed1c 100644 --- a/src/gt4py/next/iterator/dispatcher.py +++ b/src/gt4py/next/iterator/dispatcher.py @@ -57,7 +57,7 @@ def register_key(self, key): def push_key(self, key): if key not in self._funs: - raise RuntimeError(f"Key {key} not registered") + raise RuntimeError(f"Key '{key}' not registered.") self.key_stack.append(key) def pop_key(self): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b00e53bfd9..a4f32929db 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -238,7 +238,7 @@ def _validate_kstart(self, args): set(arg.kstart for arg in args if isinstance(arg, Column)) - {self.kstart} ): raise ValueError( - "Incompatible Column.kstart: it should be '{self.kstart}' but found other values: {wrong_kstarts}" + "Incompatible 'Column.kstart': it should be '{self.kstart}' but found other values: {wrong_kstarts}." ) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Column: @@ -486,7 +486,7 @@ def promote_scalars(val: CompositeOfScalarOrField): return constant_field(val) else: raise ValueError( - f"Expected a `Field` or a number (`float`, `np.int64`, ...), but got {val_type}." + f"Expected a 'Field' or a number ('float', 'np.int64', ...), got '{val_type}'." ) @@ -566,7 +566,7 @@ def execute_shift( return new_pos - raise AssertionError("Unknown object in `offset_provider`") + raise AssertionError("Unknown object in 'offset_provider'.") def _is_list_of_complete_offsets( @@ -878,7 +878,7 @@ def make_in_iterator( return SparseListIterator(it, sparse_dimensions[0]) else: raise NotImplementedError( - f"More than one local dimension is currently not supported, got {sparse_dimensions}" + f"More than one local dimension is currently not supported, got {sparse_dimensions}." ) else: return it @@ -925,7 +925,7 @@ def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if common.is_mutable_field(self._ndarrayfield): self._ndarrayfield[self._translate_named_indices(named_indices)] = value else: - raise RuntimeError("Assigment into a non-mutable Field.") + raise RuntimeError("Assigment into a non-mutable Field is not allowed.") @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1023,7 +1023,7 @@ def np_as_located_field( def _maker(a) -> common.Field: if a.ndim != len(axes): - raise TypeError("ndarray.ndim incompatible with number of given dimensions") + raise TypeError("'ndarray.ndim' is incompatible with number of given dimensions.") ranges = [] for d, s in zip(axes, a.shape): offset = origin.get(d, 0) @@ -1071,7 +1071,7 @@ def dtype(self) -> core_defs.Int32DType: @property def ndarray(self) -> core_defs.NDArrayObject: - raise AttributeError("Cannot get `ndarray` of an infinite Field.") + raise AttributeError("Cannot get 'ndarray' of an infinite 'Field'.") def asnumpy(self) -> np.ndarray: raise NotImplementedError() @@ -1190,7 +1190,7 @@ def codomain(self) -> type[core_defs.ScalarT]: @property def ndarray(self) -> core_defs.NDArrayObject: - raise AttributeError("Cannot get `ndarray` of an infinite Field.") + raise AttributeError("Cannot get 'ndarray' of an infinite 'Field'.") def asnumpy(self) -> np.ndarray: raise NotImplementedError() @@ -1440,7 +1440,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: if isinstance(field, tuple): if len(field) != len(value): raise RuntimeError( - f"Tuple of incompatible size, expected tuple of len={len(field)}, got len={len(value)}" + f"Tuple of incompatible size, expected tuple of 'len={len(field)}', got 'len={len(value)}'." ) for f, v in zip(field, value): _tuple_assign(f, v, named_indices) @@ -1459,7 +1459,7 @@ def field_getitem(self, named_indices: NamedFieldIndices) -> Any: def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if not isinstance(value, tuple): - raise RuntimeError(f"Value needs to be tuple, got `{value}`.") + raise RuntimeError(f"Value needs to be tuple, got '{value}'.") _tuple_assign(self.data, value, named_indices) @@ -1503,13 +1503,13 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: if isinstance(domain, runtime.CartesianDomain): if any(isinstance(o, common.Connectivity) for o in offset_provider.values()): raise RuntimeError( - "Got a `CartesianDomain`, but found a `Connectivity` in `offset_provider`, expected `UnstructuredDomain`." + "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'." ) def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): if "offset_provider" not in kwargs: - raise RuntimeError("offset_provider not provided") + raise RuntimeError("'offset_provider' not provided.") offset_provider = kwargs["offset_provider"] @@ -1523,7 +1523,7 @@ def closure( _validate_domain(domain_, kwargs["offset_provider"]) domain: dict[Tag, range] = _dimension_to_tag(domain_) if not (common.is_field(out) or is_tuple_of_field(out)): - raise TypeError("Out needs to be a located field.") + raise TypeError("'Out' needs to be a located field.") column_range = None column: Optional[ColumnDescriptor] = None diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 535648cc47..e6ee20e227 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -49,13 +49,13 @@ class Sym(Node): # helper @datamodels.validator("kind") def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): if value and value not in ["Iterator", "Value"]: - raise ValueError(f"Invalid kind `{value}`, must be one of `Iterator`, `Value`.") + raise ValueError(f"Invalid kind '{value}', must be one of 'Iterator', 'Value'.") @datamodels.validator("dtype") def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): if value and value[0] not in TYPEBUILTINS: raise ValueError( - f"Invalid dtype `{value}`, must be one of `{'`, `'.join(TYPEBUILTINS)}`." + f"Invalid dtype '{value}', must be one of '{', '.join(TYPEBUILTINS)}'." ) @@ -71,7 +71,7 @@ class Literal(Expr): @datamodels.validator("type") def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): if value not in TYPEBUILTINS: - raise ValueError(f"{value} is not a valid builtin type.") + raise ValueError(f"'{value}' is not a valid builtin type.") class NoneLiteral(Expr): @@ -115,7 +115,7 @@ class StencilClosure(Node): @datamodels.validator("output") def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): - raise ValueError("Only FunCall to `make_tuple` allowed.") + raise ValueError("Only FunCall to 'make_tuple' allowed.") UNARY_MATH_NUMBER_BUILTINS = {"abs"} diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index f7086ada0c..94a2646422 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -295,7 +295,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Literal(value='True', type='bool') """ if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673 - raise ValueError(f"Value must be a scalar, but got {type(val).__name__}") + raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.") # At the time this has been written the iterator module has its own type system that is # uncoupled from the one used in the frontend. However since we decided to eventually replace diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ffc00e474b..e12ae84dbc 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -96,7 +96,7 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs): backend(self.itir(*args, **kwargs), *args, **kwargs) else: if fendef_embedded is None: - raise RuntimeError("Embedded execution is not registered") + raise RuntimeError("Embedded execution is not registered.") fendef_embedded(self.function, *args, **kwargs) def format_itir(self, *args, formatter: ProgramFormatter, **kwargs) -> str: diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index d1f6bba8d6..30fec1f9fd 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -164,7 +164,7 @@ def make_node(o): return NoneLiteral() if hasattr(o, "fun"): return SymRef(id=o.fun.__name__) - raise NotImplementedError(f"Cannot handle {o}") + raise NotImplementedError(f"Cannot handle '{o}'.") def trace_function_call(fun, *, args=None): @@ -269,7 +269,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: # the last parameter info might also be a keyword or variadic keyword argument, but # they are not supported. raise NotImplementedError( - "Only `POSITIONAL_OR_KEYWORD` or `VAR_POSITIONAL` parameters are supported." + "Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported." ) param_info = param_infos[-1] @@ -279,7 +279,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: param_name = param_info.name else: raise NotImplementedError( - "Only `POSITIONAL_OR_KEYWORD` or `VAR_POSITIONAL` parameters are supported." + "Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported." ) kind, dtype = None, None diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index cc70e11413..034a39d68f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -123,7 +123,7 @@ def generic_visit(self, *args, **kwargs): depth = kwargs.pop("depth") return super().generic_visit(*args, depth=depth + 1, **kwargs) - def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. + def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node): return super().visit(node, **kwargs) @@ -289,7 +289,7 @@ def extract_subexpression( # `_subexpr_2`: `x + y + (x + y)` raise NotImplementedError( "Results of the current implementation not meaningful for " - "`deepest_expr_first == True` and `once_only == True`." + "'deepest_expr_first == True' and 'once_only == True'." ) ignored_children = False diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e2feb79c44..2e05391634 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -68,7 +68,7 @@ def _inline_into_scan(ir, *, max_iter=10): break ir = inlined else: - raise RuntimeError(f"Inlining into scan did not converge with {max_iter} iterations.") + raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") return ir @@ -117,7 +117,7 @@ def apply_common_transforms( break ir = inlined else: - raise RuntimeError("Inlining lift and lambdas did not converge.") + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 60a5db7e96..861052bb25 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -81,7 +81,7 @@ def _get_connectivity( ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a `reduce` object, i.e. `reduce(...)(...)`.") + raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): @@ -90,11 +90,11 @@ def _get_connectivity( connectivities.append(conn) if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of reduce.") + raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to reduce have incompatible partial shifts.") + raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") return connectivities[0] diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 14f3e95e10..2375118cd1 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -74,7 +74,7 @@ def from_elems(cls: typing.Type[T], *elems: Type) -> typing.Union[T, EmptyTuple] def __iter__(self) -> abc.Iterator[Type]: yield self.front if not isinstance(self.others, (Tuple, EmptyTuple)): - raise ValueError(f"Can not iterate over partially defined tuple {self}") + raise ValueError(f"Can not iterate over partially defined tuple '{self}'.") yield from self.others def __len__(self) -> int: @@ -286,7 +286,7 @@ def handle_constraint( if self.name != other.name: raise TypeError( - f"Can not satisfy constraint on primitive types: {self.name} ≡ {other.name}" + f"Can not satisfy constraint on primitive types: '{self.name}' ≡ '{other.name}'." ) return True @@ -300,7 +300,7 @@ def handle_constraint( self, other: Type, add_constraint: abc.Callable[[Type, Type], None] ) -> bool: if isinstance(other, UnionPrimitive): - raise AssertionError("`UnionPrimitive` may only appear on one side of a constraint.") + raise AssertionError("'UnionPrimitive' may only appear on one side of a constraint.") if not isinstance(other, Primitive): return False @@ -551,7 +551,8 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): current_loc_out = current_loc_in for arg in shift_args: if not isinstance(arg, ir.OffsetLiteral): - continue # probably some dynamically computed offset, thus we assume it’s a number not an axis and just ignore it (see comment below) + # probably some dynamically computed offset, thus we assume it’s a number not an axis and just ignore it (see comment below) + continue offset = arg.value if isinstance(offset, int): continue # ignore ‘application’ of (partial) shifts @@ -639,7 +640,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: elif node.id in ir.GRAMMAR_BUILTINS: raise TypeError( f"Builtin '{node.id}' is only allowed as applied/called function by the type " - f"inference." + "inference." ) elif node.id in ir.TYPEBUILTINS: # TODO(tehrengruber): Implement propagating types of values referring to types, e.g. @@ -649,10 +650,10 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: # `typing.Type`. raise NotImplementedError( f"Type builtin '{node.id}' is only supported as literal argument by the " - f"type inference." + "type inference." ) else: - raise NotImplementedError(f"Missing type definition for builtin '{node.id}'") + raise NotImplementedError(f"Missing type definition for builtin '{node.id}'.") elif node.id in symtable: sym_decl = symtable[node.id] assert isinstance(sym_decl, TYPED_IR_NODES) @@ -696,13 +697,13 @@ def _visit_make_tuple(self, node: ir.FunCall, **kwargs) -> Type: def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: # Calls to `tuple_get` are handled as being part of the grammar, not as function calls. if len(node.args) != 2: - raise TypeError("`tuple_get` requires exactly two arguments.") + raise TypeError("'tuple_get' requires exactly two arguments.") if ( not isinstance(node.args[0], ir.Literal) or node.args[0].type != ir.INTEGER_INDEX_BUILTIN ): raise TypeError( - f"The first argument to `tuple_get` must be a literal of type `{ir.INTEGER_INDEX_BUILTIN}`." + f"The first argument to 'tuple_get' must be a literal of type '{ir.INTEGER_INDEX_BUILTIN}'." ) self.visit(node.args[0], **kwargs) # visit index so that its type is collected idx = int(node.args[0].value) @@ -725,9 +726,9 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: if len(node.args) != 2: - raise TypeError("`neighbors` requires exactly two arguments.") + raise TypeError("'neighbors' requires exactly two arguments.") if not (isinstance(node.args[0], ir.OffsetLiteral) and isinstance(node.args[0].value, str)): - raise TypeError("The first argument to `neighbors` must be an `OffsetLiteral` tag.") + raise TypeError("The first argument to 'neighbors' must be an 'OffsetLiteral' tag.") # Visit arguments such that their type is also inferred self.visit(node.args, **kwargs) @@ -766,11 +767,11 @@ def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: def _visit_cast_(self, node: ir.FunCall, **kwargs) -> Type: if len(node.args) != 2: - raise TypeError("`cast_` requires exactly two arguments.") + raise TypeError("'cast_' requires exactly two arguments.") val_arg_type = self.visit(node.args[0], **kwargs) type_arg = node.args[1] if not isinstance(type_arg, ir.SymRef) or type_arg.id not in ir.TYPEBUILTINS: - raise TypeError("The second argument to `cast_` must be a type literal.") + raise TypeError("The second argument to 'cast_' must be a type literal.") size = TypeVar.fresh() @@ -964,7 +965,7 @@ def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: and child_node.id in ir.GRAMMAR_BUILTINS | ir.TYPEBUILTINS ): raise AssertionError( - f"Expected a type to be inferred for node `{child_node}`, but none was found." + f"Expected a type to be inferred for node '{child_node}', but none was found." ) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 5d54512bd0..bfb3b0d474 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -206,7 +206,7 @@ def create_bindings( """ if program_source.language not in [languages.Cpp, languages.Cuda]: raise ValueError( - f"Can only create bindings for C++ program sources, received {program_source.language}." + f"Can only create bindings for C++ program sources, received '{program_source.language}'." ) wrapper_name = program_source.entry_point.name + "_wrapper" diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 5ea4ba0519..2c0511ebf4 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -101,7 +101,7 @@ def visit_FindDependency(self, dep: FindDependency): return f"find_package(GridTools REQUIRED PATHS {gridtools_cpp.get_cmake_dir()} NO_DEFAULT_PATH)" case _: - raise ValueError("Library {name} is not supported".format(name=dep.name)) + raise ValueError(f"Library '{dep.name}' is not supported") def visit_LinkDependency(self, dep: LinkDependency): # TODO(ricoh): do not add more libraries here @@ -115,7 +115,7 @@ def visit_LinkDependency(self, dep: LinkDependency): case "gridtools_gpu": lib_name = "GridTools::fn_gpu" case _: - raise ValueError("Library {name} is not supported".format(name=dep.name)) + raise ValueError(f"Library '{dep.name}' is not supported") cfg = "" if dep.name == "nanobind": diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index dacb444207..9fd20b16e2 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -80,7 +80,7 @@ def __call__( if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): raise CompilationError( - "On-the-fly compilation unsuccessful for {inp.source_module.entry_point.name}!" + f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) return getattr( diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 0370b5eeb3..a21bc83c0b 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -59,7 +59,7 @@ class ProgramSource(Generic[SrcL, SettingT]): def __post_init__(self): if not isinstance(self.language_settings, self.language.settings_class): raise TypeError( - f"Wrong language settings type for {self.language}, must be subclass of {self.language.settings_class}" + f"Wrong language settings type for '{self.language}', must be subclass of '{self.language.settings_class}'." ) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 6b6b91a310..ed8b768972 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -80,7 +80,7 @@ def replace(self, **kwargs: Any) -> Self: TypeError: If `self` is not a dataclass. """ if not dataclasses.is_dataclass(self): - raise TypeError(f"{self.__class__} is not a dataclass") + raise TypeError(f"'{self.__class__}' is not a dataclass.") assert not isinstance(self, type) return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? @@ -242,7 +242,9 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] + hash_function: Callable[[StartT], HashT] = dataclasses.field( + default=hash + ) # type: ignore[assignment] _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index f412386bb3..74fbbfc93f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -99,7 +99,7 @@ def _get_connectivity( ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a `reduce` object, i.e. `reduce(...)(...)`.") + raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): @@ -108,11 +108,11 @@ def _get_connectivity( connectivities.append(conn) if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of reduce.") + raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to reduce have incompatible partial shifts.") + raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") return connectivities[0] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 7bf310f4e1..4abdaa6eea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -135,7 +135,7 @@ def _process_connectivity_args( if isinstance(connectivity, Connectivity): if connectivity.index_type not in [np.int32, np.int64]: raise ValueError( - "Neighbor table indices must be of type `np.int32` or `np.int64`." + "Neighbor table indices must be of type 'np.int32' or 'np.int64'." ) # parameter @@ -165,8 +165,8 @@ def _process_connectivity_args( pass else: raise AssertionError( - f"Expected offset provider `{name}` to be a `Connectivity` or `Dimension`, " - f"but got {type(connectivity).__name__}." + f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " + f"got '{type(connectivity).__name__}'." ) return parameters, arg_exprs diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index f78a052679..842080f8ae 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -59,7 +59,7 @@ def pytype_to_cpptype(t: str): "axis_literal": None, # TODO: domain? }[t] except KeyError: - raise TypeError(f"Unsupported type '{t}'") from None + raise TypeError(f"Unsupported type '{t}'.") from None _vertical_dimension = "gtfn::unstructured::dim::vertical" @@ -83,7 +83,7 @@ def _get_gridtype(closures: list[itir.StencilClosure]) -> common.GridType: grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( - f"Found StencilClosures with more than one GridType: {grid_types}. This is currently not supported." + f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported." ) return grid_types.pop() @@ -109,7 +109,7 @@ def _collect_dimensions_from_domain( offset_definitions[dim_name] = TagDefinition(name=Sym(id=dim_name)) elif domain.fun == itir.SymRef(id="unstructured_domain"): if len(domain.args) > 2: - raise ValueError("unstructured_domain must not have more than 2 arguments.") + raise ValueError("Unstructured_domain must not have more than 2 arguments.") if len(domain.args) > 0: horizontal_range = domain.args[0] assert isinstance(horizontal_range, itir.FunCall) @@ -126,7 +126,7 @@ def _collect_dimensions_from_domain( ) else: raise AssertionError( - "Expected either a call to `cartesian_domain` or to `unstructured_domain`." + "Expected either a call to 'cartesian_domain' or to 'unstructured_domain'." ) return offset_definitions @@ -181,7 +181,7 @@ def _collect_offset_definitions( ) else: raise AssertionError( - "Elements of offset provider need to be either `Dimension` or `Connectivity`." + "Elements of offset provider need to be either 'Dimension' or 'Connectivity'." ) return offset_definitions @@ -233,7 +233,7 @@ def apply( fencil_definition = node else: raise TypeError( - f"Expected a `FencilDefinition` or `FencilWithTemporaries`, but got `{type(node).__name__}`." + f"Expected a 'FencilDefinition' or 'FencilWithTemporaries', got '{type(node).__name__}'." ) grid_type = _get_gridtype(fencil_definition.closures) @@ -303,7 +303,7 @@ def _make_domain(self, node: itir.FunCall): isinstance(named_range, itir.FunCall) and named_range.fun == itir.SymRef(id="named_range") ): - raise ValueError("Arguments to `domain` need to be calls to `named_range`.") + raise ValueError("Arguments to 'domain' need to be calls to 'named_range'.") tags.append(self.visit(named_range.args[0])) sizes.append( BinaryExpr( @@ -410,9 +410,9 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs: Any) -> Node: # special handling of applied builtins is handled in `_visit_` return getattr(self, visit_method)(node, **kwargs) elif node.fun.id == "shift": - raise ValueError("unapplied shift call not supported: {node}") + raise ValueError("Unapplied shift call not supported: '{node}'.") elif node.fun.id == "scan": - raise ValueError("scans are only supported at the top level of a stencil closure") + raise ValueError("Scans are only supported at the top level of a stencil closure.") if isinstance(node.fun, itir.FunCall): if node.fun.fun == itir.SymRef(id="shift"): assert len(node.args) == 1 @@ -440,7 +440,7 @@ def _visit_output_argument(self, node: itir.Expr): return self.visit(node) elif isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="make_tuple"): return SidComposite(values=[self._visit_output_argument(v) for v in node.args]) - raise ValueError("Expected `SymRef` or `make_tuple` in output argument.") + raise ValueError("Expected 'SymRef' or 'make_tuple' in output argument.") @staticmethod def _bool_from_literal(node: itir.Node): diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index d9f8b36301..95d3d2ca35 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -56,6 +56,62 @@ def kind(self) -> type[ProgramFormatter]: return ProgramFormatter +def _make_arg_filter( + accept_args: None | int | Literal["all"] = "all", +) -> Callable[[tuple[Any, ...]], tuple[Any, ...]]: + match accept_args: + case None: + + def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]: + return () + + case "all": + + def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]: + return args + + case int(): + if accept_args < 0: + raise ValueError( + f"Number of accepted arguments cannot be a negative number, got {accept_args}." + ) + + def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]: + return args[:accept_args] + + case _: + raise ValueError(f"Invalid 'accept_args' value: {accept_args}.") + return arg_filter + + +def _make_kwarg_filter( + accept_kwargs: None | Sequence[str] | Literal["all"] = "all", +) -> Callable[[dict[str, Any]], dict[str, Any]]: + match accept_kwargs: + case None: + + def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]: + return {} + + case "all": + + def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]: + return kwargs + + case Sequence(): + if not all(isinstance(a, str) for a in accept_kwargs): + raise ValueError( + f"Provided invalid list of keyword argument names: '{accept_kwargs}'." + ) + + def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in kwargs.items() if key in accept_kwargs} + + case _: + raise ValueError(f"Invalid 'accept_kwargs' value: {accept_kwargs}") + return kwarg_filter + + def make_program_processor( func: ProgramProcessorCallable[OutputT], kind: type[ProcessorKindT], @@ -80,33 +136,9 @@ def make_program_processor( Raises: ValueError: If the value of `accept_args` or `accept_kwargs` is invalid. """ - args_filter: Callable[[Sequence], Sequence] - if accept_args is None: - args_filter = lambda args: () # noqa: E731 # use def instead of named lambdas - elif accept_args == "all": - args_filter = lambda args: args # noqa: E731 - elif isinstance(accept_args, int): - if accept_args < 0: - raise ValueError( - f"Number of accepted arguments cannot be a negative number ({accept_args})" - ) - args_filter = lambda args: args[:accept_args] # type: ignore[misc] # noqa: E731 - else: - raise ValueError(f"Invalid ({accept_args}) accept_args value") - - filtered_kwargs: Callable[[dict[str, Any]], dict[str, Any]] - if accept_kwargs is None: - filtered_kwargs = lambda kwargs: {} # noqa: E731 # use def instead of named lambdas - elif accept_kwargs == "all": # don't swap with 'isinstance(..., Sequence)' - filtered_kwargs = lambda kwargs: kwargs # noqa: E731 - elif isinstance(accept_kwargs, Sequence): - if not all(isinstance(a, str) for a in accept_kwargs): - raise ValueError(f"Provided invalid list of keyword argument names ({accept_args})") - filtered_kwargs = lambda kwargs: { # noqa: E731 - key: value for key, value in kwargs.items() if key in accept_kwargs # type: ignore[operator] # key in accept_kwargs - } - else: - raise ValueError(f"Invalid ({accept_kwargs}) 'accept_kwargs' value") + args_filter = _make_arg_filter(accept_args) + + filtered_kwargs = _make_kwarg_filter(accept_kwargs) @functools.wraps(func) def _wrapper(program: itir.FencilDefinition, *args, **kwargs) -> OutputT: @@ -195,7 +227,7 @@ def ensure_processor_kind( obj: ProgramProcessor[OutputT, ProcessorKindT], kind: type[ProcessorKindT] ) -> None: if not is_processor_kind(obj, kind): - raise TypeError(f"{obj} is not a {kind.__name__}!") + raise TypeError(f"'{obj}' is not a '{kind.__name__}'.") class ProgramBackend( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index acfa06b456..65f9d9d71a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -148,7 +148,7 @@ def get_stride_args( stride, remainder = divmod(stride_size, value.itemsize) if remainder != 0: raise ValueError( - f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)" + f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)." ) stride_args[str(sym)] = stride @@ -334,7 +334,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: else: def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: - raise RuntimeError("Missing `cupy` dependency for GPU execution.") + raise RuntimeError("Missing 'cupy' dependency for GPU execution.") run_dace_gpu = otf_exec.OTFBackend( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d10a14a1ee..d08476847f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -401,7 +401,7 @@ def builtin_tuple_get( index = node_args[0] if isinstance(index, itir.Literal): return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants") + raise ValueError("Tuple can only be subscripted with compile-time constants.") _GENERAL_BUILTIN_MAPPING: dict[ @@ -640,7 +640,7 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: elif builtin_name in _GENERAL_BUILTIN_MAPPING: return self._visit_general_builtin(node) else: - raise NotImplementedError(f"{builtin_name} not implemented") + raise NotImplementedError(f"'{builtin_name}' not implemented.") return self._visit_call(node) def _visit_call(self, node: itir.FunCall): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index cb14b89e8a..55717326a3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -32,7 +32,7 @@ def as_dace_type(type_: ts.ScalarType): return dace.float32 elif type_.kind == ts.ScalarKind.FLOAT64: return dace.float64 - raise ValueError(f"scalar type {type_} not supported") + raise ValueError(f"Scalar type '{type_}' not supported.") def filter_neighbor_tables(offset_provider: dict[str, Any]): diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 5d4b450d39..baa45ddc0e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -83,7 +83,7 @@ def extract_connectivity_args( if isinstance(conn, common.Connectivity): if not isinstance(conn, common.NeighborTable): raise NotImplementedError( - "Only `NeighborTable` connectivities implemented at this point." + "Only 'NeighborTable' connectivities implemented at this point." ) # copying to device here is a fallback for easy testing and might be removed later conn_arg = _ensure_is_on_device(conn.table, device) @@ -92,8 +92,8 @@ def extract_connectivity_args( pass else: raise AssertionError( - f"Expected offset provider `{name}` to be a `Connectivity` or `Dimension`, " - f"but got {type(conn).__name__}." + f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " + f"but got '{type(conn).__name__}'." ) return args diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 564df7fd1a..20fa8bd791 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -75,13 +75,15 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: match symbol_type: case ts.DeferredType(constraint): if constraint is None: - raise ValueError(f"No type information available for {symbol_type}!") + raise ValueError(f"No type information available for '{symbol_type}'.") elif isinstance(constraint, tuple): - raise ValueError(f"Not sufficient type information available for {symbol_type}!") + raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") return constraint case ts.TypeSpec() as concrete_type: return concrete_type.__class__ - raise ValueError(f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!") + raise ValueError( + f"Invalid type for TypeInfo: requires '{ts.TypeSpec}', got '{type(symbol_type)}'." + ) def primitive_constituents( @@ -163,7 +165,7 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: return dtype case ts.ScalarType() as dtype: return dtype - raise ValueError(f"Can not unambiguosly extract data type from {symbol_type}!") + raise ValueError(f"Can not unambiguosly extract data type from '{symbol_type}'.") def is_floating_point(symbol_type: ts.TypeSpec) -> bool: @@ -320,7 +322,7 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: return [] case ts.FieldType(dims): return dims - raise ValueError(f"Can not extract dimensions from {symbol_type}!") + raise ValueError(f"Can not extract dimensions from '{symbol_type}'.") def is_local_field(type_: ts.FieldType) -> bool: @@ -435,7 +437,7 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types))) return ts.FieldType(dims=dims, dtype=dtype) - raise TypeError("Expected a FieldType or ScalarType.") + raise TypeError("Expected a 'FieldType' or 'ScalarType'.") @functools.singledispatch @@ -446,7 +448,7 @@ def return_type( with_kwargs: dict[str, ts.TypeSpec], ): raise NotImplementedError( - f"Return type deduction of type " f"{type(callable_type).__name__} not implemented." + f"Return type deduction of type " f"'{type(callable_type).__name__}' not implemented." ) @@ -473,7 +475,7 @@ def return_type_field( raise ValueError("Could not deduce return type of invalid remap operation.") from ex if not isinstance(with_args[0], ts.OffsetType): - raise ValueError(f"First argument must be of type {ts.OffsetType}, got {with_args[0]}.") + raise ValueError(f"First argument must be of type '{ts.OffsetType}', got '{with_args[0]}'.") source_dim = with_args[0].source target_dims = with_args[0].target @@ -500,7 +502,7 @@ def canonicalize_arguments( ignore_errors=False, use_signature_ordering=False, ) -> tuple[list, dict]: - raise NotImplementedError(f"Not implemented for type {type(func_type).__name__}.") + raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.") @canonicalize_arguments.register @@ -526,7 +528,7 @@ def canonicalize_function_arguments( cargs[args_idx] = ckwargs.pop(name) elif not ignore_errors: raise AssertionError( - f"Error canonicalizing function arguments. Got multiple values for argument `{name}`." + f"Error canonicalizing function arguments. Got multiple values for argument '{name}'." ) a, b = set(func_type.kw_only_args.keys()), set(ckwargs.keys()) @@ -534,7 +536,7 @@ def canonicalize_function_arguments( if invalid_kw_args and (not ignore_errors or use_signature_ordering): # this error can not be ignored as otherwise the invariant that no arguments are dropped # is invalidated. - raise AssertionError(f"Invalid keyword arguments {', '.join(invalid_kw_args)}.") + raise AssertionError(f"Invalid keyword arguments '{', '.join(invalid_kw_args)}'.") if use_signature_ordering: ckwargs = {k: ckwargs[k] for k in func_type.kw_only_args.keys() if k in ckwargs} @@ -566,7 +568,7 @@ def structural_function_signature_incompatibilities( if args_idx < len(args): # remove the argument here such that later errors stay comprehensible kwargs.pop(name) - yield f"Got multiple values for argument `{name}`." + yield f"Got multiple values for argument '{name}'." num_pos_params = len(func_type.pos_only_args) + len(func_type.pos_or_kw_args) num_pos_args = len(args) - args.count(UNDEFINED_ARG) @@ -582,17 +584,17 @@ def structural_function_signature_incompatibilities( range(len(func_type.pos_only_args), num_pos_params), func_type.pos_or_kw_args.keys() ): if args[i] is UNDEFINED_ARG: - missing_positional_args.append(f"`{arg_type}`") + missing_positional_args.append(f"'{arg_type}'") if missing_positional_args: yield f"Missing {len(missing_positional_args)} required positional argument{'s' if len(missing_positional_args) != 1 else ''}: {', '.join(missing_positional_args)}" # check for missing or extra keyword arguments kw_a_m_b = set(func_type.kw_only_args.keys()) - set(kwargs.keys()) if len(kw_a_m_b) > 0: - yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} `{'`, `'.join(kw_a_m_b)}`." + yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} '{', '.join(kw_a_m_b)}'." kw_b_m_a = set(kwargs.keys()) - set(func_type.kw_only_args.keys()) if len(kw_b_m_a) > 0: - yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} `{'`, `'.join(kw_b_m_a)}`." + yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} '{', '.join(kw_b_m_a)}'." @functools.singledispatch @@ -604,7 +606,7 @@ def function_signature_incompatibilities( Note that all types must be concrete/complete. """ - raise NotImplementedError(f"Not implemented for type {type(func_type).__name__}.") + raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.") @function_signature_incompatibilities.register @@ -639,14 +641,14 @@ def function_signature_incompatibilities_func( # noqa: C901 if i < len(func_type.pos_only_args): arg_repr = f"{_number_to_ordinal_number(i+1)} argument" else: - arg_repr = f"argument `{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}`" - yield f"Expected {arg_repr} to be of type `{a_arg}`, but got `{b_arg}`." + arg_repr = f"argument '{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}'" + yield f"Expected {arg_repr} to be of type '{a_arg}', got '{b_arg}'." for kwarg in set(func_type.kw_only_args.keys()) & set(kwargs.keys()): if (a_kwarg := func_type.kw_only_args[kwarg]) != ( b_kwarg := kwargs[kwarg] ) and not is_concretizable(a_kwarg, to_type=b_kwarg): - yield f"Expected keyword argument `{kwarg}` to be of type `{func_type.kw_only_args[kwarg]}`, but got `{kwargs[kwarg]}`." + yield f"Expected keyword argument '{kwarg}' to be of type '{func_type.kw_only_args[kwarg]}', got '{kwargs[kwarg]}'." @function_signature_incompatibilities.register @@ -660,11 +662,11 @@ def function_signature_incompatibilities_field( return if not isinstance(args[0], ts.OffsetType): - yield f"Expected first argument to be of type {ts.OffsetType}, but got {args[0]}." + yield f"Expected first argument to be of type '{ts.OffsetType}', got '{args[0]}'." return if kwargs: - yield f"Got unexpected keyword argument(s) `{'`, `'.join(kwargs.keys())}`." + yield f"Got unexpected keyword argument(s) '{', '.join(kwargs.keys())}'." return source_dim = args[0].source @@ -705,7 +707,7 @@ def accepts_args( """ if not isinstance(callable_type, ts.CallableType): if raise_exception: - raise ValueError(f"Expected a callable type, but got `{callable_type}`.") + raise ValueError(f"Expected a callable type, got '{callable_type}'.") return False errors = function_signature_incompatibilities(callable_type, with_args, with_kwargs) @@ -713,7 +715,7 @@ def accepts_args( error_list = list(errors) if len(error_list) > 0: raise ValueError( - f"Invalid call to function of type `{callable_type}`:\n" + f"Invalid call to function of type '{callable_type}':\n" + ("\n".join([f" - {error}" for error in error_list])) ) return True diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 007a83844c..88a8347fe4 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -37,7 +37,7 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: try: dt = np.dtype(dtype) except TypeError as err: - raise ValueError(f"Invalid scalar type definition ({dtype})") from err + raise ValueError(f"Invalid scalar type definition ('{dtype}').") from err if dt.shape == () and dt.fields is None: match dt: @@ -54,9 +54,9 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: case np.str_: return ts.ScalarKind.STRING case _: - raise ValueError(f"Impossible to map '{dtype}' value to a ScalarKind") + raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: - raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported") + raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported.") def from_type_hint( @@ -76,7 +76,7 @@ def from_type_hint( type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns) except Exception as error: raise ValueError( - f"Type annotation ({type_hint}) has undefined forward references!" + f"Type annotation '{type_hint}' has undefined forward references." ) from error # Annotated @@ -98,50 +98,50 @@ def from_type_hint( case builtins.tuple: if not args: - raise ValueError(f"Tuple annotation ({type_hint}) requires at least one argument!") + raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") if Ellipsis in args: - raise ValueError(f"Unbound tuples ({type_hint}) are not allowed!") + raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") return ts.TupleType(types=[recursive_make_symbol(arg) for arg in args]) case common.Field: if (n_args := len(args)) != 2: - raise ValueError(f"Field type requires two arguments, got {n_args}! ({type_hint})") + raise ValueError(f"Field type requires two arguments, got {n_args}: '{type_hint}'.") dims: Union[Ellipsis, list[common.Dimension]] = [] dim_arg, dtype_arg = args if isinstance(dim_arg, list): for d in dim_arg: if not isinstance(d, common.Dimension): - raise ValueError(f"Invalid field dimension definition '{d}'") + raise ValueError(f"Invalid field dimension definition '{d}'.") dims.append(d) elif dim_arg is Ellipsis: dims = dim_arg else: - raise ValueError(f"Invalid field dimensions '{dim_arg}'") + raise ValueError(f"Invalid field dimensions '{dim_arg}'.") try: dtype = recursive_make_symbol(dtype_arg) except ValueError as error: raise ValueError( - f"Field dtype argument must be a scalar type (got '{dtype_arg}')!" + f"Field dtype argument must be a scalar type (got '{dtype_arg}')." ) from error if not isinstance(dtype, ts.ScalarType) or dtype.kind == ts.ScalarKind.STRING: - raise ValueError("Field dtype argument must be a scalar type (got '{dtype}')!") + raise ValueError("Field dtype argument must be a scalar type (got '{dtype}').") return ts.FieldType(dims=dims, dtype=dtype) case collections.abc.Callable: if not args: - raise ValueError("Not annotated functions are not supported!") + raise ValueError("Unannotated functions are not supported.") try: arg_types, return_type = args args = [recursive_make_symbol(arg) for arg in arg_types] except Exception as error: - raise ValueError(f"Invalid callable annotations in {type_hint}") from error + raise ValueError(f"Invalid callable annotations in '{type_hint}'.") from error kwargs_info = [arg for arg in extra_args if isinstance(arg, xtyping.CallableKwargsInfo)] if len(kwargs_info) != 1: - raise ValueError(f"Invalid callable annotations in {type_hint}") + raise ValueError(f"Invalid callable annotations in '{type_hint}'.") kwargs = { arg: recursive_make_symbol(arg_type) for arg, arg_type in kwargs_info[0].data.items() @@ -155,7 +155,7 @@ def from_type_hint( returns=recursive_make_symbol(return_type), ) - raise ValueError(f"'{type_hint}' type is not supported") + raise ValueError(f"'{type_hint}' type is not supported.") def from_value(value: Any) -> ts.TypeSpec: @@ -178,7 +178,7 @@ def from_value(value: Any) -> ts.TypeSpec: break if not symbol_type: raise ValueError( - f"Value `{value}` is out of range to be representable as `INT32` or `INT64`." + f"Value '{value}' is out of range to be representable as 'INT32' or 'INT64'." ) return candidate_type elif isinstance(value, common.Dimension): @@ -200,4 +200,4 @@ def from_value(value: Any) -> ts.TypeSpec: if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type else: - raise ValueError(f"Impossible to map '{value}' value to a Symbol") + raise ValueError(f"Impossible to map '{value}' value to a 'Symbol'.") diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index b1e26b40cb..6217d3c782 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -127,7 +127,7 @@ class ConstInitializer(DataInitializer): def __init__(self, value: ScalarValue): if not core_defs.is_scalar_type(value): raise ValueError( - "`ConstInitializer` can not be used with non-scalars. Use `Case.as_field` instead." + "'ConstInitializer' can not be used with non-scalars. Use 'Case.as_field' instead." ) self.value = value @@ -162,7 +162,7 @@ class IndexInitializer(DataInitializer): @property def scalar_value(self) -> ScalarValue: - raise AttributeError("`scalar_value` not supported in `IndexInitializer`.") + raise AttributeError("'scalar_value' not supported in 'IndexInitializer'.") def field( self, @@ -172,7 +172,7 @@ def field( ) -> FieldValue: if len(sizes) > 1: raise ValueError( - f"`IndexInitializer` only supports fields with a single `Dimension`, got {sizes}." + f"'IndexInitializer' only supports fields with a single 'Dimension', got {sizes}." ) n_data = list(sizes.values())[0] return constructors.as_field( @@ -244,7 +244,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.partial(*args, **kwargs) def __getattr__(self, name: str) -> Any: - raise AttributeError(f"No setter for argument {name}.") + raise AttributeError(f"No setter for argument '{name}'.") @typing.overload @@ -323,7 +323,7 @@ class NewBuilder(Builder): if 0 < len(args) <= 1 and args[0] is not None: return make_builder_inner(args[0]) if len(args) > 1: - raise ValueError(f"make_builder takes only one positional argument, {len(args)} received!") + raise ValueError(f"make_builder takes only one positional argument, {len(args)} received.") return make_builder_inner @@ -533,7 +533,7 @@ def _allocate_from_type( ) case _: raise TypeError( - f"Can not allocate for type {arg_type} with initializer {strategy or 'default'}" + f"Can not allocate for type '{arg_type}' with initializer '{strategy or 'default'}'." ) @@ -542,7 +542,7 @@ def get_param_types( ) -> dict[str, ts.TypeSpec]: if fieldview_prog.definition is None: raise ValueError( - f"test cases do not support {type(fieldview_prog)} with empty .definition attribute (as you would get from .as_program())!" + f"test cases do not support '{type(fieldview_prog)}' with empty .definition attribute (as you would get from .as_program())." ) annotations = xtyping.get_type_hints(fieldview_prog.definition) return { @@ -559,7 +559,7 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> case ts.TupleType(types): return sum([get_param_size(t, sizes=sizes) for t in types]) case _: - raise TypeError(f"Can not get size for parameter of type {param_type}") + raise TypeError(f"Can not get size for parameter of type '{param_type}'.") def extend_sizes( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index f8a3f6a975..e25576ebde 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -22,7 +22,6 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.runners import gtfn, roundtrip try: @@ -39,7 +38,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: """Temporary default backend to not accidentally test the wrong backend.""" - raise ValueError("No backend selected! Backend selection is mandatory in tests.") + raise ValueError("No backend selected. Backend selection is mandatory in tests.") OPTIONAL_PROCESSORS = [] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 6293ff76bd..b41696a36b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -226,7 +226,7 @@ def testee( def test_scan_wrong_return_type(cartesian_case): with pytest.raises( errors.DSLError, - match=(r"Argument `init` to scan operator `testee_scan` must have same type as its return"), + match=(r"Argument 'init' to scan operator 'testee_scan' must have same type as its return"), ): @scan_operator(axis=KDim, forward=True, init=0) @@ -245,7 +245,7 @@ def test_scan_wrong_state_type(cartesian_case): with pytest.raises( errors.DSLError, match=( - r"Argument `init` to scan operator `testee_scan` must have same type as `state` argument" + r"Argument 'init' to scan operator 'testee_scan' must have same type as 'state' argument" ), ): @@ -276,7 +276,7 @@ def program_bound_args(arg1: bool, arg2: bool, out: cases.IField): def test_bind_invalid_arg(cartesian_case, bound_args_testee): with pytest.raises( - TypeError, match="Keyword argument `inexistent_arg` is not a valid program parameter." + TypeError, match="Keyword argument 'inexistent_arg' is not a valid program parameter." ): bound_args_testee.with_bound_args(inexistent_arg=1) @@ -306,7 +306,7 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te assert ( re.search( - "Parameter `arg2` already set as a bound argument.", exc_info.value.__cause__.args[0] + "Parameter 'arg2' already set as a bound argument.", exc_info.value.__cause__.args[0] ) is not None ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 51f853d41d..a08931628b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -1188,7 +1188,7 @@ def unpack( def test_tuple_unpacking_too_many_values(cartesian_case): with pytest.raises( errors.DSLError, - match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), + match=(r"Too many values to unpack \(expected 3\)."), ): @gtx.field_operator(backend=cartesian_case.backend) @@ -1197,8 +1197,10 @@ def _star_unpack() -> tuple[int32, float64, int32]: return a, b, c -def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises(errors.DSLError, match=(r"Assignment value must be of type tuple!")): +def test_tuple_unpacking_too_few_values(cartesian_case): + with pytest.raises( + errors.DSLError, match=(r"Assignment value must be of type tuple, got 'int32'.") + ): @gtx.field_operator(backend=cartesian_case.backend) def _invalid_unpack() -> tuple[int32, float64, int32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 8cfcff160c..167ccbb0a5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -57,7 +57,7 @@ def make_builtin_field_operator(builtin_name: str): "return": cases.IFloatField, } else: - raise AssertionError(f"Unknown builtin `{builtin_name}`") + raise AssertionError(f"Unknown builtin '{builtin_name}'.") closure_vars = {"IDim": IDim, builtin_name: getattr(fbuiltins, builtin_name)} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index c2ab43773f..f5bf453a09 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -147,9 +147,9 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: def test_unary_not(cartesian_case): pytest.xfail( - "We accidentally supported `not` on fields. This is wrong, we should raise an error." + "We accidentally supported 'not' on fields. This is wrong, we should raise an error." ) - with pytest.raises: # TODO `not` on a field should be illegal + with pytest.raises: # TODO 'not' on a field should be illegal @gtx.field_operator def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 4c0613a33c..c86881ab7c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -228,8 +228,8 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program(inp, out, offset_provider={}) msgs = [ - r"- Expected argument `in_field` to be of type `Field\[\[IDim], float64\]`," - r" but got `Field\[\[JDim\], float64\]`.", + r"- Expected argument 'in_field' to be of type 'Field\[\[IDim], float64\]'," + r" got 'Field\[\[JDim\], float64\]'.", ] for msg in msgs: assert re.search(msg, exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 84b480a23d..af06da3e29 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -334,7 +334,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(errors.DSLError, match="Condition for `if` must be scalar."): + with pytest.raises(errors.DSLError, match="Condition for 'if' must be scalar"): @field_operator def if_non_scalar_condition( @@ -347,7 +347,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): - with pytest.raises(errors.DSLError, match="Condition for `if` must be of boolean type."): + with pytest.raises(errors.DSLError, match="Condition for 'if' must be of boolean type"): @field_operator def if_non_boolean_condition( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index d1a5f24f79..2174871f89 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -88,7 +88,7 @@ def type_info_cases() -> list[tuple[Optional[ts.TypeSpec], dict]]: def callable_type_info_cases(): # reuse all the other test cases not_callable = [ - (symbol_type, [], {}, [r"Expected a callable type, but got "], None) + (symbol_type, [], {}, [r"Expected a callable type, got "], None) for symbol_type, attributes in type_info_cases() if not isinstance(symbol_type, ts.CallableType) ] @@ -165,7 +165,7 @@ def callable_type_info_cases(): nullary_func_type, [], {"foo": bool_type}, - [r"Got unexpected keyword argument `foo`."], + [r"Got unexpected keyword argument 'foo'."], None, ), ( @@ -180,7 +180,7 @@ def callable_type_info_cases(): unary_func_type, [float_type], {}, - [r"Expected 1st argument to be of type `bool`, but got `float64`."], + [r"Expected 1st argument to be of type 'bool', got 'float64'."], None, ), ( @@ -188,7 +188,7 @@ def callable_type_info_cases(): [], {}, [ - r"Missing 1 required positional argument: `foo`", + r"Missing 1 required positional argument: 'foo'", r"Function takes 1 positional argument, but 0 were given.", ], None, @@ -199,31 +199,31 @@ def callable_type_info_cases(): kw_or_pos_arg_func_type, [], {"foo": float_type}, - [r"Expected argument `foo` to be of type `bool`, but got `float64`."], + [r"Expected argument 'foo' to be of type 'bool', got 'float64'."], None, ), ( kw_or_pos_arg_func_type, [], {"bar": bool_type}, - [r"Got unexpected keyword argument `bar`."], + [r"Got unexpected keyword argument 'bar'."], None, ), # function with keyword-only argument - (kw_only_arg_func_type, [], {}, [r"Missing required keyword argument `foo`."], None), + (kw_only_arg_func_type, [], {}, [r"Missing required keyword argument 'foo'."], None), (kw_only_arg_func_type, [], {"foo": bool_type}, [], ts.VoidType()), ( kw_only_arg_func_type, [], {"foo": float_type}, - [r"Expected keyword argument `foo` to be of type `bool`, but got `float64`."], + [r"Expected keyword argument 'foo' to be of type 'bool', got 'float64'."], None, ), ( kw_only_arg_func_type, [], {"bar": bool_type}, - [r"Got unexpected keyword argument `bar`."], + [r"Got unexpected keyword argument 'bar'."], None, ), # function with positional, keyword-or-positional, and keyword-only argument @@ -232,9 +232,9 @@ def callable_type_info_cases(): [], {}, [ - r"Missing 1 required positional argument: `foo`", + r"Missing 1 required positional argument: 'foo'", r"Function takes 2 positional arguments, but 0 were given.", - r"Missing required keyword argument `bar`", + r"Missing required keyword argument 'bar'", ], None, ), @@ -244,7 +244,7 @@ def callable_type_info_cases(): {}, [ r"Function takes 2 positional arguments, but 1 were given.", - r"Missing required keyword argument `bar`", + r"Missing required keyword argument 'bar'", ], None, ), @@ -252,14 +252,14 @@ def callable_type_info_cases(): pos_arg_and_kw_or_pos_arg_and_kw_only_arg_func_type, [bool_type], {"foo": int_type}, - [r"Missing required keyword argument `bar`"], + [r"Missing required keyword argument 'bar'"], None, ), ( pos_arg_and_kw_or_pos_arg_and_kw_only_arg_func_type, [bool_type], {"foo": int_type}, - [r"Missing required keyword argument `bar`"], + [r"Missing required keyword argument 'bar'"], None, ), ( @@ -274,9 +274,9 @@ def callable_type_info_cases(): [int_type], {"bar": bool_type, "foo": bool_type}, [ - r"Expected 1st argument to be of type `bool`, but got `int64`", - r"Expected argument `foo` to be of type `int64`, but got `bool`", - r"Expected keyword argument `bar` to be of type `float64`, but got `bool`", + r"Expected 1st argument to be of type 'bool', got 'int64'", + r"Expected argument 'foo' to be of type 'int64', got 'bool'", + r"Expected keyword argument 'bar' to be of type 'float64', got 'bool'", ], None, ), @@ -299,7 +299,7 @@ def callable_type_info_cases(): [ts.TupleType(types=[float_type, field_type])], {}, [ - r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" + r"Expected 1st argument to be of type 'tuple\[bool, Field\[\[I\], float64\]\]', got 'tuple\[float64, Field\[\[I\], float64\]\]'" ], ts.VoidType(), ), @@ -308,7 +308,7 @@ def callable_type_info_cases(): [int_type], {}, [ - r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" + r"Expected 1st argument to be of type 'tuple\[bool, Field\[\[I\], float64\]\]', got 'int64'" ], ts.VoidType(), ), @@ -330,8 +330,8 @@ def callable_type_info_cases(): ], {}, [ - r"Expected argument `a` to be of type `Field\[\[K\], int64\]`, but got `Field\[\[K\], float64\]`", - r"Expected argument `b` to be of type `Field\[\[K\], int64\]`, but got `Field\[\[K\], float64\]`", + r"Expected argument 'a' to be of type 'Field\[\[K\], int64\]', got 'Field\[\[K\], float64\]'", + r"Expected argument 'b' to be of type 'Field\[\[K\], int64\]', got 'Field\[\[K\], float64\]'", ], ts.FieldType(dims=[KDim], dtype=float_type), ), @@ -393,8 +393,8 @@ def callable_type_info_cases(): ], {}, [ - r"Expected argument `a` to be of type `tuple\[Field\[\[I, J, K\], int64\], " - r"Field\[\[\.\.\.\], int64\]\]`, but got `tuple\[Field\[\[I, J, K\], int64\]\]`." + r"Expected argument 'a' to be of type 'tuple\[Field\[\[I, J, K\], int64\], " + r"Field\[\[\.\.\.\], int64\]\]', got 'tuple\[Field\[\[I, J, K\], int64\]\]'." ], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), @@ -491,7 +491,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): with pytest.raises( errors.DSLError, - match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), + match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'."), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -507,7 +507,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): with pytest.raises( errors.DSLError, match=( - r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." + r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." ), ): _ = FieldOperatorParser.apply_to_function(nonmatching) @@ -519,7 +519,7 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`!"), + match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -530,7 +530,7 @@ def sign_bool(a: Field[[TDim], bool]): with pytest.raises( errors.DSLError, - match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", + match=r"Incompatible type for unary operator '\-': 'Field\[\[TDim\], bool\]'.", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -541,7 +541,7 @@ def not_int(a: Field[[TDim], int64]): with pytest.raises( errors.DSLError, - match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", + match=r"Incompatible type for unary operator 'not': 'Field\[\[TDim\], int64\]'.", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -613,7 +613,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: with pytest.raises( errors.DSLError, - match=(r"Could not promote `float32` and `float64` to common type in call to +."), + match=(r"Could not promote 'float32' and 'float64' to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -643,7 +643,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): with pytest.raises( errors.DSLError, - match=r"Expected broadcast dimension is missing", + match=r"expected broadcast dimension\(s\) \'.*\' missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -658,7 +658,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): with pytest.raises( errors.DSLError, - match=r"Expected all broadcast dimensions to be of type Dimension.", + match=r"expected all broadcast dimensions to be of type 'Dimension'.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -778,7 +778,7 @@ def simple_astype(a: Field[[TDim], float64]): with pytest.raises( errors.DSLError, - match=r"Invalid call to `astype`. Second argument must be a scalar type, but got.", + match=r"Invalid call to 'astype': second argument must be a scalar type, got.", ): _ = FieldOperatorParser.apply_to_function(simple_astype) @@ -806,7 +806,7 @@ def modulo_floats(inp: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=r"Type float64 can not be used in operator `%`", + match=r"Type 'float64' can not be used in operator '%'", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -844,6 +844,6 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): with pytest.raises( errors.DSLError, - match=f"Excepted integer for offset field dtype", + match=f"expected integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c0d565bbf4..d3f3f35699 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -126,7 +126,7 @@ def fenimpl(size, arg0, arg1, arg2, out): closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1, arg2]) else: - raise AssertionError("Add overload") + raise AssertionError("Add overload.") return run_processor(fenimpl, processor, out.shape[0], *inps, out) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 6f91557e46..4177a5aeee 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -109,7 +109,7 @@ def run_processor( elif ppi.is_processor_kind(processor, ppi.ProgramFormatter): print(program.format_itir(*args, formatter=processor, **kwargs)) else: - raise TypeError(f"program processor kind not recognized: {processor}!") + raise TypeError(f"program processor kind not recognized: '{processor}'.") @dataclasses.dataclass diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 2b78eb9114..1a38e5245e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -631,5 +631,5 @@ def test_setitem_wrong_domain(): np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) ) - with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): + with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): field[(1, slice(None))] = value_incompatible diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index e5bbed19fd..96ecc19c0b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -88,7 +88,7 @@ def mistyped(inp: gtx.Field): with pytest.raises( ValueError, - match="Field type requires two arguments, got 0!", + match="Field type requires two arguments, got 0.", ): _ = FieldOperatorParser.apply_to_function(mistyped) @@ -245,7 +245,7 @@ def conditional_wrong_mask_type( ) -> gtx.Field[[TDim], float64]: return where(a, a, a) - msg = r"Expected a field with dtype `bool`." + msg = r"expected a field with dtype 'bool'" with pytest.raises(errors.DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type) @@ -269,7 +269,7 @@ def test_ternary_with_field_condition(): def ternary_with_field_condition(cond: gtx.Field[[], bool]): return 1 if cond else 2 - with pytest.raises(errors.DSLError, match=r"should be .* `bool`"): + with pytest.raises(errors.DSLError, match=r"should be .* 'bool'"): _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition) @@ -288,7 +288,7 @@ def test_adr13_wrong_return_type_annotation(): def wrong_return_type_annotation() -> gtx.Field[[], float]: return 1.0 - with pytest.raises(errors.DSLError, match=r"Expected `float.*`"): + with pytest.raises(errors.DSLError, match=r"expected 'float.*'"): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -395,8 +395,6 @@ def zero_dims_ternary( ): return a if cond == 1 else b - msg = r"Incompatible datatypes in operator `==`" - with pytest.raises(errors.DSLError) as exc_info: + msg = r"Incompatible datatypes in operator '=='" + with pytest.raises(errors.DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) - - assert re.search(msg, exc_info.value.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 1d1a1efad4..cca05f9917 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -113,7 +113,7 @@ def undefined_field_program(in_field: gtx.Field[[IDim], "float64"]): with pytest.raises( errors.DSLError, - match=(r"Undeclared or untyped symbol `out_field`."), + match=(r"Undeclared or untyped symbol 'out_field'."), ): ProgramParser.apply_to_function(undefined_field_program) @@ -165,10 +165,10 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) - assert exc_info.match("Invalid call to `domain_format_1`") + assert exc_info.match("Invalid call to 'domain_format_1'") assert ( - re.search("Only Dictionaries allowed in domain", exc_info.value.__cause__.args[0]) + re.search("Only Dictionaries allowed in 'domain'", exc_info.value.__cause__.args[0]) is not None ) @@ -184,7 +184,7 @@ def domain_format_2_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) - assert exc_info.match("Invalid call to `domain_format_2`") + assert exc_info.match("Invalid call to 'domain_format_2'") assert ( re.search("Only 2 values allowed in domain range", exc_info.value.__cause__.args[0]) @@ -203,10 +203,10 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_3_program) - assert exc_info.match("Invalid call to `domain_format_3`") + assert exc_info.match("Invalid call to 'domain_format_3'") assert ( - re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument\ 'out'", exc_info.value.__cause__.args[0]) is not None ) @@ -224,7 +224,7 @@ def domain_format_4_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) - assert exc_info.match("Invalid call to `domain_format_4`") + assert exc_info.match("Invalid call to 'domain_format_4'") assert ( re.search("Either only domain or slicing allowed", exc_info.value.__cause__.args[0]) @@ -243,7 +243,7 @@ def domain_format_5_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) - assert exc_info.match("Invalid call to `domain_format_5`") + assert exc_info.match("Invalid call to 'domain_format_5'") assert ( re.search("Only integer values allowed in domain range", exc_info.value.__cause__.args[0]) @@ -262,6 +262,6 @@ def domain_format_6_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) - assert exc_info.match("Invalid call to `domain_format_6`") + assert exc_info.match("Invalid call to 'domain_format_6'") assert re.search("Empty domain not allowed.", exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index c4fe30c596..a1a7b79cec 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -177,7 +177,7 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): grid_type=gtx.GridType.CARTESIAN, ) - assert exc_info.match("Invalid call to `identity`") + assert exc_info.match("Invalid call to 'identity'") # TODO(tehrengruber): re-enable again when call signature check doesn't return # immediately after missing `out` argument # assert ( @@ -187,6 +187,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): # is not None # ) assert ( - re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) is not None ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 232995be58..73ad24f42b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -56,7 +56,7 @@ def test_embedded_error_on_wrong_domain(): 1, ), ) - with pytest.raises(RuntimeError, match="expected `UnstructuredDomain`"): + with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): foo[dom]( gtx.as_field([I], np.zeros((1,))), out=out, diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py index 05e982cf0c..1ba35da7c6 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py +++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py @@ -74,12 +74,12 @@ def test_undecorated_formatter_function_is_not_recognized(): def undecorated_formatter(fencil: itir.FencilDefinition, *args, **kwargs) -> str: return "" - with pytest.raises(TypeError, match="is not a ProgramFormatter"): + with pytest.raises(TypeError, match="is not a 'ProgramFormatter'"): ensure_processor_kind(undecorated_formatter, ProgramFormatter) def test_wrong_processor_type_is_caught_at_runtime(dummy_formatter): - with pytest.raises(TypeError, match="is not a ProgramExecutor"): + with pytest.raises(TypeError, match="is not a 'ProgramExecutor'"): ensure_processor_kind(dummy_formatter, ProgramExecutor) diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py index 456654c1d0..599bea75e7 100644 --- a/tests/next_tests/unit_tests/test_allocators.py +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -108,7 +108,7 @@ def test_get_allocator(): with pytest.raises( TypeError, - match=f"Object {invalid_obj} is neither a field allocator nor a field allocator factory", + match=f"Object '{invalid_obj}' is neither a field allocator nor a field allocator factory", ): next_allocators.get_allocator(invalid_obj, strict=True) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index da63536953..bafabfb56e 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -96,7 +96,7 @@ def test_unit_range_slice_error(rng): def test_unit_range_set_intersection(rng): with pytest.raises( - NotImplementedError, match="Can only find the intersection between UnitRange instances." + NotImplementedError, match="Can only find the intersection between 'UnitRange' instances." ): rng & {1, 5} diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index e8b070f0c0..8d95c9951f 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -139,7 +139,7 @@ def test_as_field_origin(): def test_field_wrong_dims(): with pytest.raises( ValueError, - match=(r"Cannot construct `Field` from array of shape"), + match=(r"Cannot construct 'Field' from array of shape"), ): gtx.as_field([I, J], np.random.rand(sizes[I]).astype(gtx.float32)) @@ -147,7 +147,7 @@ def test_field_wrong_dims(): def test_field_wrong_domain(): with pytest.raises( ValueError, - match=(r"Cannot construct `Field` from array of shape"), + match=(r"Cannot construct 'Field' from array of shape"), ): domain = common.Domain( dims=(I, J), diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py index d281f5cd90..0a0b747a28 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py @@ -158,7 +158,7 @@ def test_invalid_symbol_types(): type_translation.from_type_hint(common.Field[[IDim], None]) # Functions - with pytest.raises(ValueError, match="Not annotated functions are not supported"): + with pytest.raises(ValueError, match="Unannotated functions are not supported"): type_translation.from_type_hint(typing.Callable) with pytest.raises(ValueError, match="Invalid callable annotations"): From 0d66829d8c68b89a620c87fa3fbc8f5b64287d27 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Dec 2023 11:45:16 +0100 Subject: [PATCH 59/67] docs[next]: Partially fix Quickstart Guide (#1390) Changes to the quickstart guide to use `field.asnumpy()` (introduced in #1366) instead of `np.asarray(field)`. The quickstart guide is still broken though since the embedded backend (used by default) does not support skip neighbors connectivities. --- docs/user/next/QuickstartGuide.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 1ae1db4d92..dc70f804fd 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -102,7 +102,7 @@ You can call field operators from [programs](#Programs), other field operators, result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) add(a, b, out=result, offset_provider={}) -print("{} + {} = {} ± {}".format(a_value, b_value, np.average(np.asarray(result)), np.std(np.asarray(result)))) +print("{} + {} = {} ± {}".format(a_value, b_value, np.average(result.asnumpy()), np.std(result.asnumpy()))) ``` #### Programs @@ -128,7 +128,7 @@ You can execute the program by simply calling it: result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) run_add(a, b, result, offset_provider={}) -print("{} + {} = {} ± {}".format(b_value, (a_value + b_value), np.average(np.asarray(result)), np.std(np.asarray(result)))) +print("{} + {} = {} ± {}".format(b_value, (a_value + b_value), np.average(result.asnumpy()), np.std(result.asnumpy()))) ``` #### Composing field operators and programs @@ -256,7 +256,7 @@ def run_nearest_cell_to_edge(cell_values: gtx.Field[[CellDim], float64], out : g run_nearest_cell_to_edge(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) -print("0th adjacent cell's value: {}".format(np.asarray(edge_values))) +print("0th adjacent cell's value: {}".format(edge_values.asnumpy())) ``` Running the above snippet results in the following edge field: @@ -283,7 +283,7 @@ def run_sum_adjacent_cells(cells : gtx.Field[[CellDim], float64], out : gtx.Fiel run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) -print("sum of adjacent cells: {}".format(np.asarray(edge_values))) +print("sum of adjacent cells: {}".format(edge_values.asnumpy())) ``` For the border edges, the results are unchanged compared to the previous example, but the inner edges now contain the sum of the two adjacent cells: @@ -317,7 +317,7 @@ def conditional(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, K return where(mask, a, b) conditional(mask, a, b, out=result_where, offset_provider={}) -print("where return: {}".format(np.asarray(result_where))) +print("where return: {}".format(result_where.asnumpy())) ``` **Tuple implementation:** @@ -340,7 +340,7 @@ result_1: gtx.Field[[CellDim, KDim], float64], result_2: gtx.Field[[CellDim, KDi _conditional_tuple(mask, a, b, out=(result_1, result_2)) conditional_tuple(mask, a, b, result_1, result_2, offset_provider={}) -print("where tuple return: {}".format((np.asarray(result_1), np.asarray(result_2)))) +print("where tuple return: {}".format((result_1.asnumpy(), result_2.asnumpy()))) ``` The `where` builtin also allows for nesting of tuples. In this scenario, it will first perform an unrolling: @@ -375,7 +375,7 @@ def conditional_tuple_nested( _conditional_tuple_nested(mask, a, b, c, d, out=((result_1, result_2), (result_2, result_1))) conditional_tuple_nested(mask, a, b, c, d, result_1, result_2, offset_provider={}) -print("where nested tuple return: {}".format(((np.asarray(result_1), np.asarray(result_2)), (np.asarray(result_2), np.asarray(result_1))))) +print("where nested tuple return: {}".format(((result_1.asnumpy(), result_2.asnumpy()), (result_2.asnumpy(), result_1.asnumpy())))) ``` #### Implementing the pseudo-laplacian @@ -447,7 +447,7 @@ run_pseudo_laplacian(cell_values, result_pseudo_lap, offset_provider={"E2C": E2C_offset_provider, "C2E": C2E_offset_provider}) -print("pseudo-laplacian: {}".format(np.asarray(result_pseudo_lap))) +print("pseudo-laplacian: {}".format(result_pseudo_lap.asnumpy())) ``` As a closure, here is an example of chaining field operators, which is very simple to do when working with fields. The field operator below executes the pseudo-laplacian, and then calls the pseudo-laplacian on the result of the first, in effect, calculating the laplacian of a laplacian. From cdcd6537bbc05b050a25ae6abea5b69490ed87db Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 18 Dec 2023 08:52:29 +0100 Subject: [PATCH 60/67] feat[next]: Add missing UnitRange comparison functions (#1363) - Introduce a better Infinity - Make UnitRange Generic to express finite, infinite, left-finite, right-finite properly. - Remove `Set` from UnitRange --- src/gt4py/next/common.py | 228 ++++++++++++------ src/gt4py/next/embedded/common.py | 1 + src/gt4py/next/embedded/nd_array_field.py | 25 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/embedded.py | 2 +- .../runners/dace_iterator/__init__.py | 44 ++-- .../embedded_tests/test_nd_array_field.py | 9 +- tests/next_tests/unit_tests/test_common.py | 138 ++++++++--- 8 files changed, 305 insertions(+), 144 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 3e1fe52f31..29d606ccc0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,9 +20,8 @@ import enum import functools import numbers -import sys import types -from collections.abc import Mapping, Sequence, Set +from collections.abc import Mapping, Sequence import numpy as np import numpy.typing as npt @@ -33,10 +32,12 @@ Any, Callable, ClassVar, + Generic, Never, Optional, ParamSpec, Protocol, + Self, TypeAlias, TypeGuard, TypeVar, @@ -52,16 +53,6 @@ DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) -class Infinity(int): - @classmethod - def positive(cls) -> Infinity: - return cls(sys.maxsize) - - @classmethod - def negative(cls) -> Infinity: - return cls(-sys.maxsize) - - Tag: TypeAlias = str @@ -84,31 +75,86 @@ def __str__(self): return f"{self.value}[{self.kind}]" +class Infinity(enum.Enum): + """Describes an unbounded `UnitRange`.""" + + NEGATIVE = enum.auto() + POSITIVE = enum.auto() + + def __add__(self, _: int) -> Self: + return self + + __radd__ = __add__ + + def __sub__(self, _: int) -> Self: + return self + + __rsub__ = __sub__ + + def __le__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE or other is self.POSITIVE + + def __lt__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE and other is not self + + def __ge__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE or other is self.NEGATIVE + + def __gt__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE and other is not self + + +def _as_int(v: core_defs.IntegralScalar | Infinity) -> int | Infinity: + return v if isinstance(v, Infinity) else int(v) + + +_Left = TypeVar("_Left", int, Infinity) +_Right = TypeVar("_Right", int, Infinity) + + @dataclasses.dataclass(frozen=True, init=False) -class UnitRange(Sequence[int], Set[int]): +class UnitRange(Sequence[int], Generic[_Left, _Right]): """Range from `start` to `stop` with step size one.""" - start: int - stop: int + start: _Left + stop: _Right - def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScalar) -> None: + def __init__( + self, start: core_defs.IntegralScalar | Infinity, stop: core_defs.IntegralScalar | Infinity + ) -> None: if start < stop: - object.__setattr__(self, "start", int(start)) - object.__setattr__(self, "stop", int(stop)) + object.__setattr__(self, "start", _as_int(start)) + object.__setattr__(self, "stop", _as_int(stop)) else: # make UnitRange(0,0) the single empty UnitRange object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) - # TODO: the whole infinity idea and implementation is broken and should be replaced @classmethod - def infinity(cls) -> UnitRange: - return cls(Infinity.negative(), Infinity.positive()) + def infinite( + cls, + ) -> UnitRange: + return cls(Infinity.NEGATIVE, Infinity.POSITIVE) def __len__(self) -> int: - if Infinity.positive() in (abs(self.start), abs(self.stop)): - return Infinity.positive() - return max(0, self.stop - self.start) + if UnitRange.is_finite(self): + return max(0, self.stop - self.start) + raise ValueError("Cannot compute length of open 'UnitRange'.") + + @classmethod + def is_finite(cls, obj: UnitRange) -> TypeGuard[FiniteUnitRange]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE and obj.stop is not Infinity.POSITIVE + + @classmethod + def is_right_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[_Left, int]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.stop is not Infinity.POSITIVE + + @classmethod + def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @@ -122,6 +168,7 @@ def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unuse ... def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused + assert UnitRange.is_finite(self) if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: @@ -138,61 +185,60 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re else: raise IndexError("'UnitRange' index out of range") - def __and__(self, other: Set[int]) -> UnitRange: - if isinstance(other, UnitRange): - start = max(self.start, other.start) - stop = min(self.stop, other.stop) - return UnitRange(start, stop) - else: - raise NotImplementedError( - "Can only find the intersection between 'UnitRange' instances." - ) + def __and__(self, other: UnitRange) -> UnitRange: + return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) + + def __contains__(self, value: Any) -> bool: + return ( + isinstance(value, core_defs.INTEGRAL_TYPES) + and value >= self.start + and value < self.stop + ) + + def __le__(self, other: UnitRange) -> bool: + return self.start >= other.start and self.stop <= other.stop + + def __lt__(self, other: UnitRange) -> bool: + return (self.start > other.start and self.stop <= other.stop) or ( + self.start >= other.start and self.stop < other.stop + ) + + def __ge__(self, other: UnitRange) -> bool: + return self.start <= other.start and self.stop >= other.stop - def __le__(self, other: Set[int]): + def __gt__(self, other: UnitRange) -> bool: + return (self.start < other.start and self.stop >= other.stop) or ( + self.start <= other.start and self.stop > other.stop + ) + + def __eq__(self, other: Any) -> bool: if isinstance(other, UnitRange): - return self.start >= other.start and self.stop <= other.stop - elif len(self) == Infinity.positive(): - return False - else: - return Set.__le__(self, other) - - def __add__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.positive(): - return UnitRange.infinity() - elif other == Infinity.negative(): - return UnitRange(0, 0) - return UnitRange( - *( - s if s in [Infinity.negative(), Infinity.positive()] else s + other - for s in (self.start, self.stop) - ) - ) - else: - raise NotImplementedError("Can only compute union with 'int' instances.") - - def __sub__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.negative(): - return self + Infinity.positive() - elif other == Infinity.positive(): - return self + Infinity.negative() - else: - return self + (-other) + return self.start == other.start and self.stop == other.stop else: - raise NotImplementedError("Can only compute substraction with 'int' instances.") + return False + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) - __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented + def __add__(self, other: int) -> UnitRange: + return UnitRange(self.start + other, self.stop + other) + + def __sub__(self, other: int) -> UnitRange: + return UnitRange(self.start - other, self.stop - other) def __str__(self) -> str: return f"({self.start}:{self.stop})" +FiniteUnitRange: TypeAlias = UnitRange[int, int] + + RangeLike: TypeAlias = ( UnitRange | range | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] | core_defs.IntegralScalar + | None ) @@ -207,18 +253,23 @@ def unit_range(r: RangeLike) -> UnitRange: # once the related mypy bug (#16358) gets fixed if ( isinstance(r, tuple) - and isinstance(r[0], core_defs.INTEGRAL_TYPES) - and isinstance(r[1], core_defs.INTEGRAL_TYPES) + and (isinstance(r[0], core_defs.INTEGRAL_TYPES) or r[0] in (None, Infinity.NEGATIVE)) + and (isinstance(r[1], core_defs.INTEGRAL_TYPES) or r[1] in (None, Infinity.POSITIVE)) ): - return UnitRange(r[0], r[1]) + start = r[0] if r[0] is not None else Infinity.NEGATIVE + stop = r[1] if r[1] is not None else Infinity.POSITIVE + return UnitRange(start, stop) if isinstance(r, core_defs.INTEGRAL_TYPES): return UnitRange(0, cast(core_defs.IntegralScalar, r)) + if r is None: + return UnitRange.infinite() raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple +FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement @@ -245,6 +296,10 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: ) +def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: + return UnitRange.is_finite(v[1]) + + def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) @@ -283,18 +338,27 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: return (v[0], unit_range(v[1])) +_Rng = TypeVar( + "_Rng", + UnitRange[int, int], + UnitRange[Infinity, int], + UnitRange[int, Infinity], + UnitRange[Infinity, Infinity], +) + + @dataclasses.dataclass(frozen=True, init=False) -class Domain(Sequence[NamedRange]): +class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" dims: tuple[Dimension, ...] - ranges: tuple[UnitRange, ...] + ranges: tuple[_Rng, ...] def __init__( self, - *args: NamedRange, + *args: tuple[Dimension, _Rng], dims: Optional[Sequence[Dimension]] = None, - ranges: Optional[Sequence[UnitRange]] = None, + ranges: Optional[Sequence[_Rng]] = None, ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: @@ -343,16 +407,23 @@ def ndim(self) -> int: def shape(self) -> tuple[int, ...]: return tuple(len(r) for r in self.ranges) + @classmethod + def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return all(UnitRange.is_finite(rng) for rng in obj.ranges) + @overload - def __getitem__(self, index: int) -> NamedRange: + def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... @overload - def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused + def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused ... @overload - def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused + def __getitem__( # noqa: F811 # redefine unused + self, index: Dimension + ) -> tuple[Dimension, _Rng]: ... def __getitem__( # noqa: F811 # redefine unused @@ -434,6 +505,9 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) +FiniteDomain: TypeAlias = Domain[FiniteUnitRange] + + DomainLike: TypeAlias = ( Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] ) # `Domain` is `Sequence[NamedRange]` and therefore a subset @@ -484,7 +558,7 @@ def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinite() for d in broadcast_dims ) @@ -847,7 +921,7 @@ def asnumpy(self) -> Never: @functools.cached_property def domain(self) -> Domain: - return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),)) + return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) @property def __gt_dims__(self) -> tuple[Dimension, ...]: diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 87e0800a10..94efe4d61d 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -58,6 +58,7 @@ def _relative_sub_domain( else: # not in new domain assert common.is_int_index(idx) + assert common.UnitRange.is_finite(rng) new_index = (rng.start if idx >= 0 else rng.stop) + idx if new_index < rng.start or new_index >= rng.stop: raise embedded_exceptions.IndexOutOfBounds( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fbfe64ac42..8bd2673db9 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -113,6 +113,7 @@ def __gt_dims__(self) -> tuple[common.Dimension, ...]: @property def __gt_origin__(self) -> tuple[int, ...]: + assert common.Domain.is_finite(self._domain) return tuple(-r.start for _, r in self._domain) @property @@ -386,6 +387,7 @@ def inverse_image( assert isinstance(image_range, common.UnitRange) + assert common.UnitRange.is_finite(image_range) restricted_mask = (self._ndarray >= image_range.start) & ( self._ndarray < image_range.stop ) @@ -566,9 +568,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - named_ranges.append( - (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) - ) + named_ranges.append((dim, common.UnitRange.infinite())) return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -638,14 +638,19 @@ def _compute_slice( ValueError: If `new_rng` is not an integer or a UnitRange. """ if isinstance(rng, common.UnitRange): - if domain.ranges[pos] == common.UnitRange.infinity(): - return slice(None) - else: - return slice( - rng.start - domain.ranges[pos].start, - rng.stop - domain.ranges[pos].start, - ) + start = ( + rng.start - domain.ranges[pos].start + if common.UnitRange.is_left_finite(domain.ranges[pos]) + else None + ) + stop = ( + rng.stop - domain.ranges[pos].start + if common.UnitRange.is_right_finite(domain.ranges[pos]) + else None + ) + return slice(start, stop) elif common.is_int_index(rng): + assert common.Domain.is_finite(domain) return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: '{type(rng)}'.") diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 93f17b1eb8..278dde9180 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -192,7 +192,7 @@ def broadcast( np.asarray(field)[ tuple([np.newaxis] * len(dims)) ], # TODO(havogt) use FunctionField once available - domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinity()] * len(dims))), + domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))), ) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a4f32929db..ef70a2e645 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1059,7 +1059,7 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: - return common.Domain((self._dimension, common.UnitRange.infinity())) + return common.Domain((self._dimension, common.UnitRange.infinite())) @property def codomain(self) -> type[core_defs.int32]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 65f9d9d71a..037c4f3e4d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,10 +24,9 @@ import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.otf_compile_executor as otf_exec import gt4py.next.program_processors.processor_interface as ppi -from gt4py.next.common import Dimension, Domain, UnitRange, is_field -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.otf.compilation import cache +from gt4py.next import common +from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms +from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG @@ -40,7 +39,8 @@ cp = None -def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: +def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRange]: + assert common.Domain.is_finite(domain) sorted_dims = get_sorted_dims(domain.dims) return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] @@ -54,7 +54,7 @@ def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: def convert_arg(arg: Any): - if is_field(arg): + if common.is_field(arg): sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) dim_indices = [dim_index for dim_index, _ in sorted_dims] @@ -67,9 +67,11 @@ def convert_arg(arg: Any): def preprocess_program( - program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode + program: itir.FencilDefinition, + offset_provider: Mapping[str, Any], + lift_mode: itir_transforms.LiftMode, ): - node = apply_common_transforms( + node = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, lift_mode=lift_mode, @@ -81,7 +83,7 @@ def preprocess_program( if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): fencil_definition = node else: - fencil_definition = apply_common_transforms( + fencil_definition = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -109,7 +111,7 @@ def _ensure_is_on_device( def get_connectivity_args( - neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]], + neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]], device: dace.dtypes.DeviceType, ) -> dict[str, Any]: return { @@ -134,7 +136,7 @@ def get_offset_args( return { str(sym): -drange.start for param, arg in zip(params, args) - if is_field(arg) + if common.is_field(arg) for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) } @@ -162,13 +164,19 @@ def get_stride_args( def get_cache_id( program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - column_axis: Optional[Dimension], + column_axis: Optional[common.Dimension], offset_provider: Mapping[str, Any], ) -> str: max_neighbors = [ (k, v.max_neighbors) for k, v in offset_provider.items() - if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)) + if isinstance( + v, + ( + itir_embedded.NeighborTableOffsetProvider, + itir_embedded.StridedNeighborOffsetProvider, + ), + ) ] cache_id_args = [ str(arg) @@ -191,8 +199,8 @@ def build_sdfg_from_itir( offset_provider: dict[str, Any], auto_optimize: bool = False, on_gpu: bool = False, - column_axis: Optional[Dimension] = None, - lift_mode: LiftMode = LiftMode.FORCE_INLINE, + column_axis: Optional[common.Dimension] = None, + lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, ) -> dace.SDFG: """Translate a Fencil into an SDFG. @@ -210,7 +218,7 @@ def build_sdfg_from_itir( """ # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force # `lift_more` to `FORCE_INLINE` mode. - lift_mode = LiftMode.FORCE_INLINE + lift_mode = itir_transforms.LiftMode.FORCE_INLINE arg_types = [type_translation.from_value(arg) for arg in args] device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU @@ -237,7 +245,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) auto_optimize = kwargs.get("auto_optimize", False) - lift_mode = kwargs.get("lift_mode", LiftMode.FORCE_INLINE) + lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] @@ -263,7 +271,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): lift_mode=lift_mode, ) - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1a38e5245e..6863b09c12 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import dataclasses import itertools import math import operator @@ -20,7 +19,7 @@ import numpy as np import pytest -from gt4py.next import common, embedded +from gt4py.next import common from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice @@ -353,7 +352,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinity())), + Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinite())), ) ), ( @@ -362,7 +361,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange.infinity(), UnitRange(0, 10))), + Domain(dims=(IDim, JDim), ranges=(UnitRange.infinite(), UnitRange(0, 10))), ) ), ( @@ -373,7 +372,7 @@ def test_cartesian_remap_implementation(): ), Domain( dims=(IDim, JDim, KDim), - ranges=(UnitRange.infinity(), UnitRange(0, 10), UnitRange.infinity()), + ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), ), ) ), diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index bafabfb56e..7650e90c3c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -14,6 +14,7 @@ import operator from typing import Optional, Pattern +import numpy as np import pytest from gt4py.next.common import ( @@ -41,6 +42,56 @@ def a_domain(): return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) +@pytest.fixture(params=[Infinity.POSITIVE, Infinity.NEGATIVE]) +def unbounded(request): + yield request.param + + +def test_unbounded_add_sub(unbounded): + assert unbounded + 1 == unbounded + assert unbounded - 1 == unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.le, operator.lt]) +def test_unbounded_comparison_less(value, op): + assert not op(Infinity.POSITIVE, value) + assert op(value, Infinity.POSITIVE) + + assert op(Infinity.NEGATIVE, value) + assert not op(value, Infinity.NEGATIVE) + + assert op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.ge, operator.gt]) +def test_unbounded_comparison_greater(value, op): + assert op(Infinity.POSITIVE, value) + assert not op(value, Infinity.POSITIVE) + + assert not op(Infinity.NEGATIVE, value) + assert op(value, Infinity.NEGATIVE) + + assert not op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +def test_unbounded_eq(unbounded): + assert unbounded == unbounded + assert unbounded <= unbounded + assert unbounded >= unbounded + assert not unbounded < unbounded + assert not unbounded > unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +def test_unbounded_max_min(value): + assert max(Infinity.POSITIVE, value) == Infinity.POSITIVE + assert min(Infinity.POSITIVE, value) == value + assert max(Infinity.NEGATIVE, value) == value + assert min(Infinity.NEGATIVE, value) == Infinity.NEGATIVE + + def test_empty_range(): expected = UnitRange(0, 0) assert UnitRange(1, 1) == expected @@ -58,9 +109,20 @@ def test_unit_range_length(rng): assert len(rng) == 10 -@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) -def test_unit_range_like(rng_like): - assert unit_range(rng_like) == UnitRange(2, 4) +@pytest.mark.parametrize( + "rng_like, expected", + [ + ((2, 4), UnitRange(2, 4)), + (range(2, 4), UnitRange(2, 4)), + (UnitRange(2, 4), UnitRange(2, 4)), + ((None, None), UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ((2, None), UnitRange(2, Infinity.POSITIVE)), + ((None, 4), UnitRange(Infinity.NEGATIVE, 4)), + (None, UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ], +) +def test_unit_range_like(rng_like, expected): + assert unit_range(rng_like) == expected def test_unit_range_repr(rng): @@ -94,13 +156,6 @@ def test_unit_range_slice_error(rng): rng[1:2:5] -def test_unit_range_set_intersection(rng): - with pytest.raises( - NotImplementedError, match="Can only find the intersection between 'UnitRange' instances." - ): - rng & {1, 5} - - @pytest.mark.parametrize( "rng1, rng2, expected", [ @@ -121,46 +176,65 @@ def test_unit_range_intersection(rng1, rng2, expected): @pytest.mark.parametrize( "rng1, rng2, expected", [ - (UnitRange(20, Infinity.positive()), UnitRange(10, 15), UnitRange(0, 0)), - (UnitRange(Infinity.negative(), 0), UnitRange(5, 10), UnitRange(0, 0)), - (UnitRange(Infinity.negative(), 0), UnitRange(-10, 0), UnitRange(-10, 0)), - (UnitRange(0, Infinity.positive()), UnitRange(Infinity.negative(), 5), UnitRange(0, 5)), + (UnitRange(20, Infinity.POSITIVE), UnitRange(10, 15), UnitRange(0, 0)), + (UnitRange(Infinity.NEGATIVE, 0), UnitRange(5, 10), UnitRange(0, 0)), + (UnitRange(Infinity.NEGATIVE, 0), UnitRange(-10, 0), UnitRange(-10, 0)), + (UnitRange(0, Infinity.POSITIVE), UnitRange(Infinity.NEGATIVE, 5), UnitRange(0, 5)), ( - UnitRange(Infinity.negative(), 0), - UnitRange(Infinity.negative(), 5), - UnitRange(Infinity.negative(), 0), + UnitRange(Infinity.NEGATIVE, 0), + UnitRange(Infinity.NEGATIVE, 5), + UnitRange(Infinity.NEGATIVE, 0), ), ( - UnitRange(Infinity.negative(), Infinity.positive()), - UnitRange(Infinity.negative(), Infinity.positive()), - UnitRange(Infinity.negative(), Infinity.positive()), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), ), ], ) -def test_unit_range_infinite_intersection(rng1, rng2, expected): +def test_unit_range_unbounded_intersection(rng1, rng2, expected): result = rng1 & rng2 assert result == expected -def test_positive_infinity_range(): - pos_inf_range = UnitRange(Infinity.positive(), Infinity.positive()) - assert len(pos_inf_range) == 0 +@pytest.mark.parametrize( + "rng", + [ + UnitRange(Infinity.NEGATIVE, 0), + UnitRange(0, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + ], +) +def test_positive_infinite_range_len(rng): + with pytest.raises(ValueError, match=r".*open.*"): + len(rng) -def test_mixed_infinity_range(): - mixed_inf_range = UnitRange(Infinity.negative(), Infinity.positive()) - assert len(mixed_inf_range) == Infinity.positive() +def test_range_contains(): + assert 1 in UnitRange(0, 2) + assert 1 not in UnitRange(0, 1) + assert 1 in UnitRange(0, Infinity.POSITIVE) + assert 1 in UnitRange(Infinity.NEGATIVE, 2) + assert 1 in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE) + assert "s" not in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE) @pytest.mark.parametrize( "op, rng1, rng2, expected", [ (operator.le, UnitRange(-1, 2), UnitRange(-2, 3), True), - (operator.le, UnitRange(-1, 2), {-1, 0, 1}, True), - (operator.le, UnitRange(-1, 2), {-1, 0}, False), - (operator.le, UnitRange(-1, 2), {-2, -1, 0, 1, 2}, True), - (operator.le, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True), - (operator.le, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False), + (operator.le, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.ge, UnitRange(-2, 3), UnitRange(-1, 2), True), + (operator.ge, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.lt, UnitRange(-1, 2), UnitRange(-2, 2), True), + (operator.lt, UnitRange(-2, 1), UnitRange(-2, 2), True), + (operator.lt, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.gt, UnitRange(-2, 2), UnitRange(-1, 2), True), + (operator.gt, UnitRange(-2, 2), UnitRange(-2, 1), True), + (operator.gt, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.eq, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), False), ], ) def test_range_comparison(op, rng1, rng2, expected): From 6c7c5d51b440c40175a25fb75fcbde7c919afbd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 18 Dec 2023 10:58:51 +0100 Subject: [PATCH 61/67] feat[dace]: Buildflags to the `ITIR -> SDFG` translation (#1389) Made it possible to also pass build flags to the `ITIR -> SDFG` translator. --- .../runners/dace_iterator/__init__.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 037c4f3e4d..59569de30b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -47,10 +47,6 @@ def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRa """ Default build configuration in DaCe backend """ _build_type = "Release" -# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins -_cpu_args = ( - "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label" -) def convert_arg(arg: Any): @@ -242,6 +238,7 @@ def build_sdfg_from_itir( def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): # build parameters build_cache = kwargs.get("build_cache", None) + compiler_args = kwargs.get("compiler_args", None) # `None` will take default. build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) auto_optimize = kwargs.get("auto_optimize", False) @@ -274,7 +271,10 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) - dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) + if compiler_args is not None: + dace.config.Config.set( + "compiler", "cuda" if on_gpu else "cpu", "args", value=compiler_args + ) sdfg_program = sdfg.compile(validate=False) # store SDFG program in build cache @@ -312,12 +312,21 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + compiler_args = dace.config.Config.get("compiler", "cpu", "args") + + # disable finite-math-only in order to support isfinite/isinf/isnan builtins + if "-ffast-math" in compiler_args: + compiler_args += " -fno-finite-math-only" + if "-ffinite-math-only" in compiler_args: + compiler_args.replace("-ffinite-math-only", "") + run_dace_iterator( program, *args, **kwargs, build_cache=_build_cache_cpu, build_type=_build_type, + compiler_args=compiler_args, on_gpu=False, ) From 315d9203bb667baa3daaea4b797a0846a2b70887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 19 Dec 2023 07:35:51 +0100 Subject: [PATCH 62/67] feat[dace]: Computing SDFG call arguments (#1398) Added a function to get the arguments to call an SDFG. This commit adds a function that allows to generate the arguments needed to call an SDFG, before this was part of `run_dace_iterator()`. This made it very complex to run an SDFG outside this function. One should consider this as an amend to [PR #1379](https://github.com/GridTools/gt4py/pull/1379). --- .../runners/dace_iterator/__init__.py | 79 ++++++++++++------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 59569de30b..97dd90eb54 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -90,8 +90,9 @@ def preprocess_program( return fencil_definition -def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: - return {name.id: convert_arg(arg) for name, arg in zip(params, args)} +def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: + sdfg_params: Sequence[str] = sdfg.arg_names + return {sdfg_param: convert_arg(arg) for sdfg_param, arg in zip(sdfg_params, args)} def _ensure_is_on_device( @@ -127,13 +128,16 @@ def get_shape_args( def get_offset_args( - arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] + sdfg: dace.SDFG, + args: Sequence[Any], ) -> Mapping[str, int]: + sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays + sdfg_params: Sequence[str] = sdfg.arg_names return { str(sym): -drange.start - for param, arg in zip(params, args) + for sdfg_param, arg in zip(sdfg_params, args) if common.is_field(arg) - for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain)) } @@ -189,6 +193,45 @@ def get_cache_id( return m.hexdigest() +def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: + """Extracts the arguments needed to call the SDFG. + + This function can handle the same arguments that are passed to `run_dace_iterator()`. + + Args: + sdfg: The SDFG for which we want to get the arguments. + """ # noqa: D401 + offset_provider = kwargs["offset_provider"] + on_gpu = kwargs.get("on_gpu", False) + + neighbor_tables = filter_neighbor_tables(offset_provider) + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + dace_args = get_args(sdfg, args) + dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_conn_args = get_connectivity_args(neighbor_tables, device) + dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) + dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) + dace_strides = get_stride_args(sdfg.arrays, dace_field_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg, args) + all_args = { + **dace_args, + **dace_conn_args, + **dace_shapes, + **dace_conn_shapes, + **dace_strides, + **dace_conn_strides, + **dace_offsets, + } + expected_args = { + key: value + for key, value in all_args.items() + if key in sdfg.signature_arglist(with_types=False) + } + return expected_args + + def build_sdfg_from_itir( program: itir.FencilDefinition, *args, @@ -248,8 +291,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: @@ -281,29 +322,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): if build_cache is not None: build_cache[cache_id] = sdfg_program - dace_args = get_args(program.params, args) - dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = get_connectivity_args(neighbor_tables, device) - dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) - dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, program.params, args) - - all_args = { - **dace_args, - **dace_conn_args, - **dace_shapes, - **dace_conn_shapes, - **dace_strides, - **dace_conn_strides, - **dace_offsets, - } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = get_sdfg_args(sdfg, *args, **kwargs) with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) From 15a7bd627d9fc818befd5f6ff6e795868563ff37 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Dec 2023 08:43:41 +0100 Subject: [PATCH 63/67] fix[next][dace]: Fix memlet for array slicing (#1399) Implementation of array slicing in DaCe backend changed to a mapped tasklet. Tested on GPU. CUDA code generation did not support the previous implementation, based on memlet in nested-SDFG. --- .../runners/dace_iterator/itir_to_tasklet.py | 66 ++++++------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d08476847f..4c202b1fe8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -18,7 +18,6 @@ import dace import numpy as np -from dace import subsets from dace.transformation.dataflow import MapFusion import gt4py.eve.codegen @@ -754,52 +753,29 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:] ] - # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset - deref_sdfg = dace.SDFG("deref") - deref_sdfg.add_array( - "_inp", field_array.shape, iterator.dtype, strides=field_array.strides - ) - for connector in deref_connectors[1:]: - deref_sdfg.add_scalar(connector, _INDEX_DTYPE) - deref_sdfg.add_array("_out", result_shape, iterator.dtype) - deref_init_state = deref_sdfg.add_state("init", True) - deref_access_state = deref_sdfg.add_state("access") - deref_sdfg.add_edge( - deref_init_state, - deref_access_state, - dace.InterstateEdge( - assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]} - ), - ) - # we access the size in source field shape as symbols set on the nested sdfg - source_subset = tuple( - f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}" + # we create a mapped tasklet for array slicing + map_ranges = { + f"_i_{dim}": f"0:{size}" for dim, size in zip(sorted_dims, field_array.shape) + if dim not in iterator.indices + } + src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims]) + dst_subset = ",".join( + [f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices] ) - deref_access_state.add_nedge( - deref_access_state.add_access("_inp"), - deref_access_state.add_access("_out"), - dace.Memlet( - data="_out", - subset=subsets.Range.from_array(result_array), - other_subset=",".join(source_subset), - ), - ) - - deref_node = self.context.state.add_nested_sdfg( - deref_sdfg, - self.context.body, - inputs=set(deref_connectors), - outputs={"_out"}, - ) - for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets): - self.context.state.add_edge(node, None, deref_node, connector, memlet) - self.context.state.add_edge( - deref_node, - "_out", - result_node, - None, - dace.Memlet.from_array(result_name, result_array), + self.context.state.add_mapped_tasklet( + "deref", + map_ranges, + inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, + outputs={ + "_out": dace.Memlet.from_array(result_name, result_array), + }, + code=f"_out[{dst_subset}] = _inp[{src_subset}]", + external_edges=True, + input_nodes={node.data: node for node in deref_nodes}, + output_nodes={ + result_name: result_node, + }, ) return [ValueExpr(result_node, iterator.dtype)] From af33e21fab16fb3de13ec5721b050dada63e220c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:31:23 +0100 Subject: [PATCH 64/67] fix[dace]: Fixed SDFG args (#1400) Modified how the SDFG arguments are computed. It was noticed that some transformations, especially the `SDFG.apply_gpu_transformation()`, to the SDFG, added new arguments to the SDFG. But, since a lot of functions build on the `SDFG.arg_names` member and this member was populated before the transformation, an error occurred. Thus it was changed such that `SDFG.arg_names` was only populated with the arguments also known to the Fencil. --- .../runners/dace_iterator/__init__.py | 17 ++++++++--------- .../runners/dace_iterator/itir_to_sdfg.py | 11 +++-------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 97dd90eb54..7fd4794e57 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -207,6 +207,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: neighbor_tables = filter_neighbor_tables(offset_provider) device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + sdfg_sig = sdfg.signature_arglist(with_types=False) dace_args = get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_connectivity_args(neighbor_tables, device) @@ -224,11 +225,8 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: **dace_conn_strides, **dace_offsets, } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = {key: all_args[key] for key in sdfg_sig} + return expected_args @@ -258,21 +256,22 @@ def build_sdfg_from_itir( # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force # `lift_more` to `FORCE_INLINE` mode. lift_mode = itir_transforms.LiftMode.FORCE_INLINE - arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) + # TODO: According to Lex one should build the SDFG first in a general mannor. + # Generalisation to a particular device should happen only at the end. sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() # run DaCe auto-optimization heuristics if auto_optimize: - # TODO Investigate how symbol definitions improve autoopt transformations, - # in which case the cache table should take the symbols map into account. + # TODO: Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. symbols: dict[str, int] = {} + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index b3e6662623..e3b5ddf2ac 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -209,14 +209,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) # Create the call signature for the SDFG. - # All arguments required by the SDFG, regardless if explicit and implicit, are added - # as positional arguments. In the front are all arguments to the Fencil, in that - # order, they are followed by the arguments created by the translation process, - arg_list = [str(a) for a in node.params] - sig_list = program_sdfg.signature_arglist(with_types=False) - implicit_args = set(sig_list) - set(arg_list) - call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] - program_sdfg.arg_names = call_params + # Only the arguments requiered by the Fencil, i.e. `node.params` are added as poitional arguments. + # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. + program_sdfg.arg_names = [str(a) for a in node.params] program_sdfg.validate() return program_sdfg From b21dd566bcbd805279d94f36a20c5ea34a300d97 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 19 Dec 2023 12:06:59 +0100 Subject: [PATCH 65/67] feat[next]: Test for local dimension in output (#1392) Currently only supported in field view embedded. --- pyproject.toml | 1 + tests/next_tests/exclusion_matrices.py | 3 +++ .../ffront_tests/test_external_local_field.py | 19 +++++++++++++++++++ 3 files changed, 23 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 2cf4fb12e2..5d7a2f2cb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -342,6 +342,7 @@ markers = [ 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', 'uses_sparse_fields: tests that require backend support for sparse fields', + 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple arguments', 'uses_tuple_returns: tests that require backend support for tuple results', diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 3c42a180dd..f6d2b10a14 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -95,6 +95,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SPARSE_FIELDS = "uses_sparse_fields" +USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output" USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" @@ -119,6 +120,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), @@ -159,4 +161,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], + ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 42938e2f4b..698dce2b5c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -82,3 +82,22 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 out=cases.allocate(unstructured_case, testee, cases.RETURN)(), ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), ) + + +@pytest.mark.uses_sparse_fields_as_output +def test_write_local_field(unstructured_case): + @gtx.field_operator + def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: + return inp(V2E) + + out = unstructured_case.as_field( + [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + cases.verify( + unstructured_case, + testee, + inp, + out=out, + ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table], + ) From 100bc7fee17e9235da070e1bbf0fedd615de541f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 3 Jan 2024 12:09:23 +0100 Subject: [PATCH 66/67] Add missing grid_type argument to scan operator decorator (#1404) --- src/gt4py/next/ffront/decorator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 4abd8f156a..53159008f0 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -775,6 +775,7 @@ def scan_operator( forward: bool, init: core_defs.Scalar, backend: Optional[str], + grid_type: GridType, ) -> FieldOperator[foast.ScanOperator]: ... @@ -786,6 +787,7 @@ def scan_operator( forward: bool, init: core_defs.Scalar, backend: Optional[str], + grid_type: GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... @@ -797,6 +799,7 @@ def scan_operator( forward: bool = True, init: core_defs.Scalar = 0.0, backend=None, + grid_type: GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] @@ -834,6 +837,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, backend, + grid_type, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) From 7a9489f73ddddd6aff219fc3890bed23e791a9a8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 4 Jan 2024 00:47:33 +0100 Subject: [PATCH 67/67] Fix size check in CollapseTuple pass (#1405) --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 3 +++ src/gt4py/next/iterator/type_inference.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7d710fc919..30457f2246 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -41,6 +41,9 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t ): return UnknownLength + if not type_.dtype.has_known_length: + return UnknownLength + return len(type_.dtype) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 2375118cd1..68627cfd89 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -77,6 +77,12 @@ def __iter__(self) -> abc.Iterator[Type]: raise ValueError(f"Can not iterate over partially defined tuple '{self}'.") yield from self.others + @property + def has_known_length(self): + return isinstance(self.others, EmptyTuple) or ( + isinstance(self.others, Tuple) and self.others.has_known_length + ) + def __len__(self) -> int: return sum(1 for _ in self)