diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 880a422160..93ea4685f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.2 + rev: v0.7.3 ##[[[end]]] hooks: # Run the linter. @@ -94,16 +94,17 @@ repos: - astunparse==1.6.3 - attrs==24.2.0 - black==24.8.0 - - boltons==24.0.0 + - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - - cmake==3.30.5 + - cmake==3.31.0.1 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 + - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - - gridtools-cpp==2.3.6 + - gridtools-cpp==2.3.7 - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 @@ -111,7 +112,7 @@ repos: - nanobind==2.2.0 - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==24.1 + - packaging==24.2 - pybind11==2.13.6 - setuptools==75.3.0 - tabulate==0.9.0 diff --git a/constraints.txt b/constraints.txt index e846d4126c..4aca6645d5 100644 --- a/constraints.txt +++ b/constraints.txt @@ -13,10 +13,10 @@ attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema, babel==2.16.0 # via sphinx backcall==0.2.0 # via ipython black==24.8.0 # via gt4py (pyproject.toml) -boltons==24.0.0 # via gt4py (pyproject.toml) +boltons==24.1.0 # via gt4py (pyproject.toml) bracex==2.5.post1 # via wcmatch build==1.2.2.post1 # via pip-tools -bump-my-version==0.28.0 # via -r requirements-dev.in +bump-my-version==0.28.1 # via -r requirements-dev.in cached-property==2.0.1 # via gt4py (pyproject.toml) cachetools==5.5.0 # via tox certifi==2024.8.30 # via requests @@ -25,7 +25,7 @@ chardet==5.2.0 # via tox charset-normalizer==3.4.0 # via requests clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via gt4py (pyproject.toml) +cmake==3.31.0.1 # via gt4py (pyproject.toml) 
cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel @@ -35,11 +35,12 @@ cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==0.16.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.7 # via ipykernel +debugpy==1.8.8 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.9 # via dace +diskcache==5.6.3 # via gt4py (pyproject.toml) distlib==0.3.9 # via virtualenv docutils==0.20.1 # via sphinx, sphinx-rtd-theme exceptiongroup==1.2.2 # via hypothesis, pytest @@ -54,7 +55,7 @@ fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach -gridtools-cpp==2.3.6 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via gt4py (pyproject.toml) hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via pre-commit idna==3.10 # via requests @@ -65,7 +66,7 @@ inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest ipykernel==6.29.5 # via nbmake ipython==8.12.3 # via ipykernel -jedi==0.19.1 # via ipython +jedi==0.19.2 # via ipython jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat jsonschema-specifications==2023.12.1 # via jsonschema @@ -94,7 +95,7 @@ ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff -packaging==24.1 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via jedi 
pathspec==0.12.1 # via black pexpect==4.9.0 # via ipython @@ -135,10 +136,10 @@ pyzmq==26.2.0 # via ipykernel, jupyter-client questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx -rich==13.9.3 # via bump-my-version, rich-click, tach +rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.3 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.2 # via -r requirements-dev.in +ruff==0.7.3 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil @@ -158,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.1 # via -r requirements-dev.in +tach==0.14.3 # via -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version @@ -173,7 +174,7 @@ virtualenv==20.27.1 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit websockets==13.1 # via dace -wheel==0.44.0 # via astunparse, pip-tools +wheel==0.45.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 7fea11bc3d..6fd3d1af55 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -65,9 +65,10 @@ dace==0.16.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 +diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jax[cpu]==0.4.18; python_version >= "3.10" 
diff --git a/min-requirements-test.txt b/min-requirements-test.txt index c20883e25e..b8779096c0 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -61,9 +61,10 @@ cytoolz==0.12.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 +diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 diff --git a/pyproject.toml b/pyproject.toml index 64f08e671e..7d63f70f15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,9 +34,10 @@ dependencies = [ 'cytoolz>=0.12.1', 'deepdiff>=5.6.0', 'devtools>=0.6', + 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.6,==2.*', + 'gridtools-cpp>=2.3.7,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index eb757e0afd..8892620786 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,10 +13,10 @@ attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypo babel==2.16.0 # via -c constraints.txt, sphinx backcall==0.2.0 # via -c constraints.txt, ipython black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml) -boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml) +boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml) bracex==2.5.post1 # via -c constraints.txt, wcmatch build==1.2.2.post1 # via -c constraints.txt, pip-tools -bump-my-version==0.28.0 # via -c constraints.txt, -r requirements-dev.in +bump-my-version==0.28.1 # via -c constraints.txt, -r requirements-dev.in cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cachetools==5.5.0 # via -c constraints.txt, tox certifi==2024.8.30 # via -c constraints.txt, requests @@ -25,7 +25,7 @@ chardet==5.2.0 # via -c constraints.txt, tox charset-normalizer==3.4.0 # via -c constraints.txt, requests clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, 
gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel @@ -35,11 +35,12 @@ cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.7 # via -c constraints.txt, ipykernel +debugpy==1.8.8 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) dill==0.3.9 # via -c constraints.txt, dace +diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml) distlib==0.3.9 # via -c constraints.txt, virtualenv docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest @@ -54,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests @@ -65,7 +66,7 @@ inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy iniconfig==2.0.0 # via -c constraints.txt, pytest ipykernel==6.29.5 # via -c 
constraints.txt, nbmake ipython==8.12.3 # via -c constraints.txt, ipykernel -jedi==0.19.1 # via -c constraints.txt, ipython +jedi==0.19.2 # via -c constraints.txt, ipython jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema @@ -94,7 +95,7 @@ ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff -packaging==24.1 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via -c constraints.txt, jedi pathspec==0.12.1 # via -c constraints.txt, black pexpect==4.9.0 # via -c constraints.txt, ipython @@ -135,10 +136,10 @@ pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx -rich==13.9.3 # via -c constraints.txt, bump-my-version, rich-click, tach +rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.3 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in +ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.16.0 # via -c constraints.txt, asttokens, astunparse, 
python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb @@ -157,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.1 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version @@ -172,7 +173,7 @@ virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit websockets==13.1 # via -c constraints.txt, dace -wheel==0.44.0 # via -c constraints.txt, astunparse, pip-tools +wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 560cbf96cf..d36c2e5c4a 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -7,11 +7,11 @@ # SPDX-License-Identifier: BSD-3-Clause from dataclasses import dataclass, field -from typing import Any, List, Optional, Set, Union +from typing import Any, List, Set, Union from gt4py import eve -from gt4py.cartesian.gtc import common, gtir, oir, utils -from gt4py.cartesian.gtc.common import CartesianOffset, DataType, LogicalOperator, UnaryOperator +from gt4py.cartesian.gtc import gtir, oir, utils +from gt4py.cartesian.gtc.common import CartesianOffset, DataType, UnaryOperator from 
gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_fields_extents @@ -118,15 +118,8 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: ) # --- Statements --- - def visit_ParAssignStmt( - self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> Union[oir.AssignStmt, oir.MaskStmt]: - statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) - if mask is None: - return statement - - # Wrap inside MaskStmt - return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc) + def visit_ParAssignStmt(self, node: gtir.ParAssignStmt, **kwargs: Any) -> oir.AssignStmt: + return oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) def visit_HorizontalRestriction( self, node: gtir.HorizontalRestriction, **kwargs: Any @@ -138,24 +131,19 @@ def visit_HorizontalRestriction( return oir.HorizontalRestriction(mask=node.mask, body=body) - def visit_While( - self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> oir.While: + def visit_While(self, node: gtir.While, **kwargs: Any) -> oir.While: body: List[oir.Stmt] = [] for statement in node.body: oir_statement = self.visit(statement, **kwargs) body.extend(utils.flatten_list(utils.listify(oir_statement))) condition: oir.Expr = self.visit(node.cond) - if mask: - condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition) return oir.While(cond=condition, body=body, loc=node.loc) def visit_FieldIfStmt( self, node: gtir.FieldIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[Union[oir.AssignStmt, oir.MaskStmt]]: @@ -182,26 +170,17 @@ def visit_FieldIfStmt( loc=node.loc, ) - combined_mask: oir.Expr = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) - 
statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=condition, loc=node.loc)) if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition) - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) + negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements @@ -211,31 +190,21 @@ def visit_ScalarIfStmt( self, node: gtir.ScalarIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[oir.MaskStmt]: condition = self.visit(node.cond) - combined_mask = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc - ) - body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)] if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) - if mask: - combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask) - + negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements diff --git a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py index ed573ebfff..b6aeb49823 100644 --- 
a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py +++ b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py @@ -157,13 +157,12 @@ def visit_AssignStmt( def visit_While( self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any ) -> npir.While: - cond = self.visit(node.cond, mask=mask, **kwargs) + cond_expr = self.visit(node.cond, **kwargs) if mask: - mask = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond) - else: - mask = cond + cond_expr = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond_expr) + return npir.While( - cond=cond, body=utils.flatten_list(self.visit(node.body, mask=mask, **kwargs)) + cond=cond_expr, body=utils.flatten_list(self.visit(node.body, mask=cond_expr, **kwargs)) ) def visit_HorizontalRestriction( diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index f415c95b63..09f53be600 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG: ) -def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]: +def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]: """ Wrap `func_to_past` in a chainable and optionally cached workflow step. 
diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index bf3bee4b56..834536ff59 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No @add_content_to_fingerprint.register(FieldOperatorDefinition) @add_content_to_fingerprint.register(FoastOperatorDefinition) +@add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) @add_content_to_fingerprint.register(toolchain.CompilableProgram) @add_content_to_fingerprint.register(arguments.CompileTimeArgs) @@ -121,10 +122,14 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo for item in sourcedef: add_content_to_fingerprint(item, hasher) + closure_vars = source_utils.get_closure_vars_from_function(obj) + for item in sorted(closure_vars.items(), key=lambda x: x[0]): + add_content_to_fingerprint(item, hasher) + @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in obj.items(): + for key, value in sorted(obj.items()): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) @@ -148,4 +153,3 @@ def add_foast_located_node_to_fingerprint( ) -> None: add_content_to_fingerprint(obj.location, hasher) add_content_to_fingerprint(str(obj), hasher) - add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b6f543e9d1..f50d8080eb 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -208,7 +208,9 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait): closures: List[StencilClosure] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(BUILTINS) + ] # sorted for 
serialization stability class Stmt(Node): ... diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index a63801c97e..ef3a4083b9 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -12,6 +12,7 @@ import dataclasses import functools import typing +from collections.abc import MutableMapping from typing import Any, Callable, Generic, Protocol, TypeVar from typing_extensions import Self @@ -253,16 +254,15 @@ class CachedStep( step: Workflow[StartT, EndT] hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] - - _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) + cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict) def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" hash_ = self.hash_function(inp) try: - result = self._cache[hash_] + result = self.cache[hash_] except KeyError: - result = self._cache[hash_] = self.step(inp) + result = self.cache[hash_] = self.step(inp) return result diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 07eec0b64b..66d74d53cc 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -213,6 +213,7 @@ def generate_stencil_source( generated_code = GTFNIMCodegen.apply(gtfn_im_ir) else: generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") def __call__( diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 2db8e98804..9a45b6a29a 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -8,8 +8,8 @@ import factory -from gt4py.next import 
allocators as next_allocators, backend -from gt4py.next.ffront import foast_to_gtir, past_to_itir +from gt4py.next import backend +from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory @@ -25,12 +25,12 @@ class Params: ), ) auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, name_temps="_opt" + otf_workflow__translation__auto_optimize=True, name_postfix="_opt" ) use_field_canonical_representation: bool = False name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" + lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" ) transforms = backend.DEFAULT_TRANSFORMS @@ -45,12 +45,28 @@ class Params: itir_cpu = run_dace_cpu itir_gpu = run_dace_gpu -gtir_cpu = backend.Backend( - name="dace.gtir.cpu", - executor=dace_fieldview_workflow.DaCeWorkflowFactory(), - allocator=next_allocators.StandardCPUFieldBufferAllocator(), - transforms=backend.Transforms( + +class DaCeFieldviewBackendFactory(GTFNBackendFactory): + class Params: + otf_workflow = factory.SubFactory( + dace_fieldview_workflow.DaCeWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), + ) + auto_optimize = factory.Trait(name_postfix="_opt") + + name = factory.LazyAttribute( + lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" + ) + + transforms = backend.Transforms( past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), - ), -) + foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(), + 
field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() + ), + ) + + +gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) +gtir_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 5d3cc7a358..bbf45a822c 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -32,7 +32,7 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: # Note that 'ndarray.item()' always transforms the numpy scalar to a python scalar, # which may change its precision. To avoid this, we use here the empty tuple as index # for 'ndarray.__getitem__()'. - return arg.ndarray[()] + return arg.asnumpy()[()] # field domain offsets are not supported non_zero_offsets = [ (dim, dim_range) @@ -88,10 +88,19 @@ def _get_shape_args( for name, value in args.items(): for sym, size in zip(arrays[name].shape, value.shape, strict=True): if isinstance(sym, dace.symbol): - assert sym.name not in shape_args - shape_args[sym.name] = size + if sym.name not in shape_args: + shape_args[sym.name] = size + elif shape_args[sym.name] != size: + # The same shape symbol is used by all fields of a tuple, because the current assumption is that all fields + # in a tuple have the same dimensions and sizes. Therefore, this if-branch only exists to ensure that array + # size (i.e. the value assigned to the shape symbol) is the same for all fields in a tuple. + # TODO(edopao): change to `assert sym.name not in shape_args` to ensure that shape symbols are unique, + # once the assumption on tuples is removed. + raise ValueError( + f"Expected array size {sym.name} for arg {name} to be {shape_args[sym.name]}, got {size}." 
+ ) elif sym != size: - raise RuntimeError( + raise ValueError( f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}." ) return shape_args @@ -109,10 +118,17 @@ def _get_stride_args( f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)." ) if isinstance(sym, dace.symbol): - assert sym.name not in stride_args - stride_args[str(sym)] = stride + if sym.name not in stride_args: + stride_args[str(sym)] = stride + elif stride_args[sym.name] != stride: + # See above comment in `_get_shape_args`, same for stride symbols of fields in a tuple. + # TODO(edopao): change to `assert sym.name not in stride_args` to ensure that stride symbols are unique, + # once the assumption on tuples is removed. + raise ValueError( + f"Expected array stride {sym.name} for arg {name} to be {stride_args[sym.name]}, got {stride}." + ) elif sym != stride: - raise RuntimeError( + raise ValueError( f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}." 
) return stride_args diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index ae0a24605d..91e83dba9d 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -17,7 +17,7 @@ from dace.codegen.compiled_sdfg import _array_interface_ptr as get_array_interface_ptr from gt4py._core import definitions as core_defs -from gt4py.next import common, config +from gt4py.next import common, config, utils as gtx_utils from gt4py.next.otf import arguments, languages, stages, step_types, workflow from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils @@ -116,7 +116,7 @@ def decorated_program( args = (*args, *arguments.iter_size_args(args)) if sdfg_program._lastargs: - kwargs = dict(zip(sdfg.arg_names, args, strict=True)) + kwargs = dict(zip(sdfg.arg_names, gtx_utils.flatten_nested_tuple(args), strict=True)) kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) use_fast_call = True diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 48c666a363..da940e883c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -217,6 +217,7 @@ def _add_storage( name: str, gt_type: ts.DataType, transient: bool = True, + tuple_name: Optional[str] = None, ) -> list[tuple[str, ts.DataType]]: """ Add storage in the SDFG for a given GT4Py data symbol. @@ -236,6 +237,7 @@ def _add_storage( name: Symbol Name to be allocated. gt_type: GT4Py symbol type. transient: True when the data symbol has to be allocated as internal storage. 
+ tuple_name: Must be set for tuple fields in order to use the same array shape and strides symbols. Returns: List of tuples '(data_name, gt_type)' where 'data_name' is the name of @@ -250,7 +252,9 @@ def _add_storage( name, gt_type, flatten=True ): tuple_fields.extend( - self._add_storage(sdfg, symbolic_arguments, tname, tsymbol_type, transient) + self._add_storage( + sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name + ) ) return tuple_fields @@ -260,16 +264,23 @@ def _add_storage( return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) # handle default case: field with one or more dimensions dc_dtype = dace_utils.as_dace_type(gt_type.dtype) - # use symbolic shape, which allows to invoke the program with fields of different size; - # and symbolic strides, which enables decoupling the memory layout from generated code. - sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) + if tuple_name is None: + # Use symbolic shape, which allows to invoke the program with fields of different size; + # and symbolic strides, which enables decoupling the memory layout from generated code. + sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) + else: + # All fields in a tuple must have the same dims and sizes, + # therefore we use the same shape and strides symbols based on 'tuple_name'. 
+ sym_shape, sym_strides = self._make_array_shape_and_strides( + tuple_name, gt_type.dims + ) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): dc_dtype = dace_utils.as_dace_type(gt_type) - if name in symbolic_arguments: + if dace_utils.is_field_symbol(name) or name in symbolic_arguments: if name in sdfg.symbols: # Sometimes, when the field domain is implicitly derived from the # field domain, the gt4py lowering adds the field size as a scalar @@ -698,49 +709,61 @@ def _flatten_tuples( head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) - def make_temps( - output_data: gtir_builtin_translators.FieldopData, + def construct_output_for_nested_sdfg( + inner_data: gtir_builtin_translators.FieldopData, ) -> gtir_builtin_translators.FieldopData: """ - This function will be called while traversing the result of the lambda - dataflow to setup the intermediate data nodes in the parent SDFG and - the data edges from the nested-SDFG output connectors. + This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, + available in the parent SDFG. + In order to achieve this, the data container inside the nested SDFG is marked as non-transient + (in other words, externally allocated - a requirement of the SDFG IR) and a new data container + is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` + but appropriatly remapped using the symbol mapping table. + For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped + to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. + The same happens to symbols available in the lambda context but not explicitly passed as lambda + arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. 
""" - desc = output_data.dc_node.desc(nsdfg) - if desc.transient: - # Transient nodes actually contain some result produced by the dataflow - # itself, therefore these nodes are changed to non-transient and an output - # edge will write the result from the nested-SDFG to a new intermediate - # data node in the parent context. - desc.transient = False - temp, _ = sdfg.add_temp_transient_like(desc) - connector = output_data.dc_node.data - dst_node = head_state.add_access(temp) + inner_desc = inner_data.dc_node.desc(nsdfg) + if inner_desc.transient: + # Transient data nodes only exist within the nested SDFG. In order to return some result data, + # the corresponding data container inside the nested SDFG has to be changed to non-transient, + # that is externally allocated, as required by the SDFG IR. An output edge will write the result + # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. + inner_desc.transient = False + outer, outer_desc = sdfg.add_temp_transient_like(inner_desc) + # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. + dace.symbolic.safe_replace( + nsdfg_symbols_mapping, + lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + ) + connector = inner_data.dc_node.data + outer_node = head_state.add_access(outer) head_state.add_edge( - nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp) + nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - temp_field = gtir_builtin_translators.FieldopData( - dst_node, output_data.gt_dtype, output_data.local_offset + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_dtype, inner_data.local_offset ) - elif output_data.dc_node.data in lambda_arg_nodes: + elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. 
# Non-transient nodes are just input nodes that are immediately returned # by the lambda expression. Therefore, these nodes are already available # in the parent context and can be directly accessed there. - temp_field = lambda_arg_nodes[output_data.dc_node.data] + outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: - dc_node = head_state.add_access(output_data.dc_node.data) - temp_field = gtir_builtin_translators.FieldopData( - dc_node, output_data.gt_dtype, output_data.local_offset + outer_node = head_state.add_access(inner_data.dc_node.data) + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_dtype, inner_data.local_offset ) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. - if nstate.degree(output_data.dc_node) == 0: - nstate.remove_node(output_data.dc_node) - return temp_field + if nstate.degree(inner_data.dc_node) == 0: + nstate.remove_node(inner_data.dc_node) + return outer_data - return gtx_utils.tree_map(make_temps)(lambda_result) + return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) def visit_Literal( self, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index f2953eb05f..85ae95c432 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -16,14 +16,16 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config +from gt4py.next import allocators as gtx_allocators, common, config from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from 
gt4py.next.otf.languages import LanguageSettings from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg -from gt4py.next.type_system import type_translation as tt +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_sdfg, + transformations as gtx_transformations, +) @dataclasses.dataclass(frozen=True) @@ -33,7 +35,8 @@ class DaCeTranslator( ], step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + device_type: core_defs.DeviceType + auto_optimize: bool def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( @@ -45,9 +48,18 @@ def generate_sdfg( ir: itir.Program, offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], + auto_opt: bool, + on_gpu: bool, ) -> dace.SDFG: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) - return gtir_sdfg.build_sdfg_from_gtir(ir=ir, offset_provider=offset_provider) + sdfg = gtir_sdfg.build_sdfg_from_gtir(ir, offset_provider=offset_provider) + + if auto_opt: + gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) + elif on_gpu: + gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=False) + + return sdfg def __call__( self, inp: stages.CompilableProgram @@ -60,11 +72,13 @@ def __call__( program, inp.args.offset_provider, inp.args.column_axis, + auto_opt=self.auto_optimize, + on_gpu=(self.device_type == gtx_allocators.CUPY_DEVICE), ) param_types = tuple( - interface.Parameter(param, tt.from_value(arg)) - for param, arg in zip(sdfg.arg_names, inp.args.args) + interface.Parameter(param, arg_type) + for param, arg_type in zip(sdfg.arg_names, inp.args.args) ) module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( @@ -98,10 +112,12 @@ class Params: cmake_build_type: config.CMakeBuildType = 
factory.LazyFunction( lambda: config.CMAKE_BUILD_TYPE ) + auto_optimize: bool = False translation = factory.SubFactory( DaCeTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), ) bindings = _no_bindings compilation = factory.SubFactory( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a824760ce4..a0f4b83d35 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -15,6 +15,7 @@ import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind from gt4py.next.common import Connectivity +from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -103,7 +104,7 @@ def _make_array_shape_and_strides( tuple(shape, strides) The output tuple fields are arrays of dace symbolic expressions. 
""" - dtype = dace.int32 + dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) neighbor_tables = dace_utils.filter_connectivities(offset_provider) shape = [ diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index d808fbfbe1..d367eb0883 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -14,6 +14,7 @@ import gt4py.next.iterator.ir as itir from gt4py import eve from gt4py.next.common import Connectivity +from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -132,9 +133,11 @@ def unique_var_name(): def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: - dtype = dace.int64 - shape = [dace.symbol(unique_name(f"{name}_shape{i}"), dtype) for i in range(ndim)] - strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i in range(ndim)] + dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) + shape = [dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) for i in range(ndim)] + strides = [ + dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i in range(ndim) + ] return shape, strides diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 2275576081..4a788bf40c 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -8,16 +8,19 @@ import functools import warnings -from typing import Any +from typing import Any, Optional +import diskcache import factory import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators +from gt4py.eve 
import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.iterator import transforms +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir, transforms from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -116,6 +119,37 @@ def compilation_hash(otf_closure: stages.CompilableProgram) -> int: ) +def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: + """ + Generates a unique hash string for a stencil source program representing + the program, sorted offset_provider, and column_axis. + """ + program: itir.FencilDefinition | itir.Program = inp.data + offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider + column_axis: Optional[common.Dimension] = inp.args.column_axis + + program_hash = utils.content_hash( + ( + program, + sorted(offset_provider.items(), key=lambda el: el[0]), + column_axis, + ) + ) + + return program_hash + + +class FileCache(diskcache.Cache): + """ + This class extends `diskcache.Cache` to ensure the cache is closed upon deletion, + i.e. it ensures that any resources associated with the cache are properly + released when the instance is garbage collected. 
+ """ + + def __del__(self) -> None: + self.close() + + class GTFNCompileWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFCompileWorkflow @@ -129,10 +163,23 @@ class Params: lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) - translation = factory.SubFactory( - gtfn_module.GTFNTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - ) + cached_translation = factory.Trait( + translation=factory.LazyAttribute( + lambda o: workflow.CachedStep( + o.translation_, + hash_function=fingerprint_compilable_program, + cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), + ) + ), + ) + + translation_ = factory.SubFactory( + gtfn_module.GTFNTranslationStepFactory, + device_type=factory.SelfAttribute("..device_type"), + ) + + translation = factory.LazyAttribute(lambda o: o.translation_) + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source ) @@ -193,7 +240,7 @@ class Params: name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True ) -run_gtfn_cached = GTFNBackendFactory(cached=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index d3a5744389..0312aea7c3 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -444,6 +444,36 @@ def validation(field_a, field_b, field_c, *, factor, domain, origin, **kwargs): field_a += 1 +class TestRuntimeIfNestedWhile(gt_testing.StencilTestSuite): + """Test conditional while statements.""" + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ALL_BACKENDS + symbols = dict( + 
infield=gt_testing.field(in_range=(-1, 1), boundary=[(0, 0), (0, 0), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + with computation(PARALLEL), interval(...): + if infield < 10: + outfield = 1 + done = False + while not done: + outfield = 2 + done = True + else: + condition = True + while condition: + outfield = 4 + condition = False + outfield = 3 + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = 2 + + class TestTernaryOp(gt_testing.StencilTestSuite): dtypes = (np.float_,) domain_range = [(1, 15), (2, 15), (1, 15)] diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py index 4de7f9f5d6..4877a39503 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py @@ -28,6 +28,7 @@ StencilFactory, VerticalLoopFactory, VerticalLoopSectionFactory, + WhileFactory, ) @@ -78,6 +79,18 @@ def test_mask_stmt_to_assigns() -> None: assert len(assign_stmts) == 1 +def test_mask_stmt_to_while() -> None: + mask_oir = MaskStmtFactory(body=[WhileFactory()]) + statements = OirToNpir().visit(mask_oir, extent=Extent.zeros(ndims=2)) + assert len(statements) == 1 + assert isinstance(statements[0], npir.While) + condition = statements[0].cond + assert isinstance(condition, npir.VectorLogic) + assert condition.op == common.LogicalOperator.AND + mask_npir = OirToNpir().visit(mask_oir.mask) + assert condition.left == mask_npir or condition.right == mask_npir + + def test_mask_propagation() -> None: mask_stmt = MaskStmtFactory() assign_stmts = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2)) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 3fef43865b..1bcc3554a7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -71,6 +71,7 @@ 
class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): DACE_CPU = "gt4py.next.program_processors.runners.dace.itir_cpu" DACE_GPU = "gt4py.next.program_processors.runners.dace.itir_gpu" GTIR_DACE_CPU = "gt4py.next.program_processors.runners.dace.gtir_cpu" + GTIR_DACE_GPU = "gt4py.next.program_processors.runners.dace.gtir_gpu" class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): @@ -145,11 +146,14 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = [ - (ALL, SKIP, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), @@ -177,6 +181,11 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST + + [ + # TODO(edopao): Enable when GPU codegen issues related to symbolic domain are fixed. 
+ (ALL, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 306f0034b5..1da34db3c0 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -32,7 +32,10 @@ try: import dace - from gt4py.next.program_processors.runners.dace import run_dace_cpu, run_dace_gpu + from gt4py.next.program_processors.runners.dace import ( + itir_cpu as run_dace_cpu, + itir_gpu as run_dace_gpu, + ) except ImportError: dace: Optional[ModuleType] = None # type:ignore[no-redef] run_dace_cpu: Optional[next_backend.Backend] = None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 333a2dae28..0ed3365969 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -62,12 +62,16 @@ def __gt_allocator__( next_tests.definitions.OptionalProgramBackendId.DACE_CPU, marks=pytest.mark.requires_dace, ), + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), + ), pytest.param( next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_CPU, marks=pytest.mark.requires_dace, ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_GPU, marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), ], diff --git 
a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7540d52fb3..27f94960dc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -7,9 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import reduce - +from gt4py.next.otf import languages, stages, workflow +from gt4py.next.otf.binding import interface import numpy as np import pytest +import diskcache +from gt4py.eve import SymbolName import gt4py.next as gtx from gt4py.next import ( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index db1c2a42aa..f6fd0a48d0 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -88,6 +88,7 @@ def index_program_shift(out, size): ) +@pytest.mark.starts_from_gtir_program @pytest.mark.uses_index_builtin def test_index_builtin_shift(program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e3e0ee474f..e64bd8a57d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -8,13 +8,25 @@ import numpy as np import pytest +import copy +import diskcache + import gt4py.next as gtx from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import 
arguments, languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import KDim + +from next_tests.integration_tests.cases import cartesian_case + +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, +) @pytest.fixture @@ -71,3 +83,103 @@ def test_codegen(fencil_example): assert module.entry_point.name == fencil.id assert any(d.name == "gridtools_cpu" for d in module.library_deps) assert module.language is languages.CPP + + +def test_hash_and_diskcache(fencil_example, tmp_path): + fencil, parameters = fencil_example + compilable_program = stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + hash = gtfn.fingerprint_compilable_program(compilable_program) + + with diskcache.Cache(tmp_path) as cache: + cache[hash] = compilable_program + + # check content of cache file + with diskcache.Cache(tmp_path) as reopened_cache: + assert hash in reopened_cache + compilable_program_from_cache = reopened_cache[hash] + assert compilable_program == compilable_program_from_cache + del reopened_cache[hash] # delete data + + # hash creation is deterministic + assert hash == gtfn.fingerprint_compilable_program(compilable_program) + assert hash == gtfn.fingerprint_compilable_program(compilable_program_from_cache) + + # hash is different if program changes + altered_program_id = copy.deepcopy(compilable_program) + altered_program_id.data.id = "example2" + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_id) + + altered_program_offset_provider = copy.deepcopy(compilable_program) + 
object.__setattr__(altered_program_offset_provider.args, "offset_provider", {"Koff": KDim}) + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_offset_provider) + + altered_program_column_axis = copy.deepcopy(compilable_program) + object.__setattr__(altered_program_column_axis.args, "column_axis", KDim) + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_column_axis) + + +def test_gtfn_file_cache(fencil_example): + fencil, parameters = fencil_example + compilable_program = stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + cached_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=True + ).executor.step.translation + + bare_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=False + ).executor.step.translation + + cache_key = gtfn.fingerprint_compilable_program(compilable_program) + + # ensure the actual cached step in the backend generates the cache item for the test + if cache_key in (translation_cache := cached_gtfn_translation_step.cache): + del translation_cache[cache_key] + cached_gtfn_translation_step(compilable_program) + assert bare_gtfn_translation_step(compilable_program) == cached_gtfn_translation_step( + compilable_program + ) + + assert cache_key in cached_gtfn_translation_step.cache + assert ( + bare_gtfn_translation_step(compilable_program) + == cached_gtfn_translation_step.cache[cache_key] + ) + + +# TODO(egparedes): we should switch to use the cached backend by default and then remove this test +def test_gtfn_file_cache_whole_workflow(cartesian_case): + if cartesian_case.backend != gtfn.run_gtfn: + pytest.skip("Skipping backend.") + cartesian_case.backend = gtfn.GTFNBackendFactory( + gpu=False, 
cached=True, otf_workflow__cached_translation=True + ) + + @gtx.field_operator + def testee(a: cases.IJKField) -> cases.IJKField: + field_tuple = (a, a) + field_0 = field_tuple[0] + field_1 = field_tuple[1] + return field_0 + + # first call: this generates the cache file + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) + # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again + object.__setattr__(cartesian_case.backend.executor, "cache", {}) + # second call: the cache file is used + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 9f5498b4a7..cc72adae4f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -262,16 +262,8 @@ def test_gtir_tuple_args(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) x_fields = (a, a, b) - x_symbols = dict( - __x_0_size_0=FSYMBOLS["__x_size_0"], - __x_0_stride_0=FSYMBOLS["__x_stride_0"], - __x_1_0_size_0=FSYMBOLS["__x_size_0"], - __x_1_0_stride_0=FSYMBOLS["__x_stride_0"], - __x_1_1_size_0=FSYMBOLS["__y_size_0"], - __x_1_1_stride_0=FSYMBOLS["__y_stride_0"], - ) - sdfg(*x_fields, c, **FSYMBOLS, **x_symbols) + sdfg(*x_fields, c, **FSYMBOLS) assert np.allclose(c, a * 2 + b) @@ -432,16 +424,8 @@ def test_gtir_tuple_return(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_0_1_size_0=FSYMBOLS["__x_size_0"], - __z_0_1_stride_0=FSYMBOLS["__x_stride_0"], - 
__z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + sdfg(a, b, *z_fields, **FSYMBOLS) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], a) assert np.allclose(z_fields[2], b) @@ -637,7 +621,7 @@ def test_gtir_cond(): expr=im.op_as_fieldop("plus", domain)( "x", im.if_( - im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")), + im.greater("s1", "s2"), im.op_as_fieldop("plus", domain)("y", "scalar"), im.op_as_fieldop("plus", domain)("w", "scalar"), ), @@ -679,7 +663,7 @@ def test_gtir_cond_with_tuple_return(): expr=im.tuple_get( 0, im.if_( - gtir.SymRef(id="pred"), + "pred", im.make_tuple(im.make_tuple("x", "y"), "w"), im.make_tuple(im.make_tuple("y", "x"), "w"), ), @@ -694,18 +678,11 @@ def test_gtir_cond_with_tuple_return(): b = np.random.rand(N) c = np.random.rand(N) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) for s in [False, True]: z_fields = (np.empty_like(a), np.empty_like(a)) - sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS, **z_symbols) + sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS) assert np.allclose(z_fields[0], a if s else b) assert np.allclose(z_fields[1], b if s else a) @@ -726,10 +703,10 @@ def test_gtir_cond_nested(): body=[ gtir.SetAt( expr=im.if_( - gtir.SymRef(id="pred_1"), + "pred_1", im.op_as_fieldop("plus", domain)("x", 1.0), im.if_( - gtir.SymRef(id="pred_2"), + "pred_2", im.op_as_fieldop("plus", domain)("x", 2.0), im.op_as_fieldop("plus", domain)("x", 3.0), ), @@ -1557,7 +1534,7 @@ def test_gtir_reduce_with_cond_neighbors(): vertex_domain, )( im.if_( - gtir.SymRef(id="pred"), + "pred", im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) @@ -1779,11 
+1756,7 @@ def test_gtir_let_lambda_with_cond(): gtir.SetAt( expr=im.let("x1", "x")( im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.if_( - gtir.SymRef(id="pred"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"), - ) + im.if_("pred", "x1", "x2") ) ), domain=domain, @@ -1833,14 +1806,8 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + sdfg(a, b, *z_fields, **FSYMBOLS) assert np.allclose(z_fields[0], a) assert np.allclose(z_fields[1], b) @@ -1879,16 +1846,8 @@ def test_gtir_let_lambda_with_tuple2(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - __z_2_size_0=FSYMBOLS["__x_size_0"], - __z_2_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + sdfg(a, b, *z_fields, **FSYMBOLS) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], val) assert np.allclose(z_fields[2], b) @@ -1939,13 +1898,9 @@ def test_gtir_if_scalars(): d2 = np.random.randint(0, 1000) sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) - x_symbols = dict( - __x_0_size_0=FSYMBOLS["__x_size_0"], - __x_0_stride_0=FSYMBOLS["__x_stride_0"], - ) for s in [False, True]: - sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS, **x_symbols) + sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS) assert np.allclose(b, (a + d1 if 
s else a + d2))