Skip to content

Commit

Permalink
Merge branch 'main' into romanc/cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
romanc authored Nov 13, 2024
2 parents b59c7b7 + 5ce0d9d commit 7b47e97
Show file tree
Hide file tree
Showing 30 changed files with 451 additions and 218 deletions.
11 changes: 6 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ repos:
## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1]
## print(f"rev: v{version}")
##]]]
rev: v0.7.2
rev: v0.7.3
##[[[end]]]
hooks:
# Run the linter.
Expand Down Expand Up @@ -94,24 +94,25 @@ repos:
- astunparse==1.6.3
- attrs==24.2.0
- black==24.8.0
- boltons==24.0.0
- boltons==24.1.0
- cached-property==2.0.1
- click==8.1.7
- cmake==3.30.5
- cmake==3.31.0.1
- cytoolz==1.0.0
- deepdiff==8.0.1
- devtools==0.12.2
- diskcache==5.6.3
- factory-boy==3.3.1
- frozendict==2.4.6
- gridtools-cpp==2.3.6
- gridtools-cpp==2.3.7
- importlib-resources==6.4.5
- jinja2==3.1.4
- lark==1.2.2
- mako==1.3.6
- nanobind==2.2.0
- ninja==1.11.1.1
- numpy==1.24.4
- packaging==24.1
- packaging==24.2
- pybind11==2.13.6
- setuptools==75.3.0
- tabulate==0.9.0
Expand Down
23 changes: 12 additions & 11 deletions constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema,
babel==2.16.0 # via sphinx
backcall==0.2.0 # via ipython
black==24.8.0 # via gt4py (pyproject.toml)
boltons==24.0.0 # via gt4py (pyproject.toml)
boltons==24.1.0 # via gt4py (pyproject.toml)
bracex==2.5.post1 # via wcmatch
build==1.2.2.post1 # via pip-tools
bump-my-version==0.28.0 # via -r requirements-dev.in
bump-my-version==0.28.1 # via -r requirements-dev.in
cached-property==2.0.1 # via gt4py (pyproject.toml)
cachetools==5.5.0 # via tox
certifi==2024.8.30 # via requests
Expand All @@ -25,7 +25,7 @@ chardet==5.2.0 # via tox
charset-normalizer==3.4.0 # via requests
clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml)
click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click
cmake==3.30.5 # via gt4py (pyproject.toml)
cmake==3.31.0.1 # via gt4py (pyproject.toml)
cogapp==3.4.1 # via -r requirements-dev.in
colorama==0.4.6 # via tox
comm==0.2.2 # via ipykernel
Expand All @@ -35,11 +35,12 @@ cycler==0.12.1 # via matplotlib
cytoolz==1.0.0 # via gt4py (pyproject.toml)
dace==0.16.1 # via gt4py (pyproject.toml)
darglint==1.8.1 # via -r requirements-dev.in
debugpy==1.8.7 # via ipykernel
debugpy==1.8.8 # via ipykernel
decorator==5.1.1 # via ipython
deepdiff==8.0.1 # via gt4py (pyproject.toml)
devtools==0.12.2 # via gt4py (pyproject.toml)
dill==0.3.9 # via dace
diskcache==5.6.3 # via gt4py (pyproject.toml)
distlib==0.3.9 # via virtualenv
docutils==0.20.1 # via sphinx, sphinx-rtd-theme
exceptiongroup==1.2.2 # via hypothesis, pytest
Expand All @@ -54,7 +55,7 @@ fparser==0.1.4 # via dace
frozendict==2.4.6 # via gt4py (pyproject.toml)
gitdb==4.0.11 # via gitpython
gitpython==3.1.43 # via tach
gridtools-cpp==2.3.6 # via gt4py (pyproject.toml)
gridtools-cpp==2.3.7 # via gt4py (pyproject.toml)
hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml)
identify==2.6.1 # via pre-commit
idna==3.10 # via requests
Expand All @@ -65,7 +66,7 @@ inflection==0.5.1 # via pytest-factoryboy
iniconfig==2.0.0 # via pytest
ipykernel==6.29.5 # via nbmake
ipython==8.12.3 # via ipykernel
jedi==0.19.1 # via ipython
jedi==0.19.2 # via ipython
jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx
jsonschema==4.23.0 # via nbformat
jsonschema-specifications==2023.12.1 # via jsonschema
Expand Down Expand Up @@ -94,7 +95,7 @@ ninja==1.11.1.1 # via gt4py (pyproject.toml)
nodeenv==1.9.1 # via pre-commit
numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy
orderly-set==5.2.2 # via deepdiff
packaging==24.1 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox
packaging==24.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox
parso==0.8.4 # via jedi
pathspec==0.12.1 # via black
pexpect==4.9.0 # via ipython
Expand Down Expand Up @@ -135,10 +136,10 @@ pyzmq==26.2.0 # via ipykernel, jupyter-client
questionary==2.0.1 # via bump-my-version
referencing==0.35.1 # via jsonschema, jsonschema-specifications
requests==2.32.3 # via sphinx
rich==13.9.3 # via bump-my-version, rich-click, tach
rich==13.9.4 # via bump-my-version, rich-click, tach
rich-click==1.8.3 # via bump-my-version
rpds-py==0.20.1 # via jsonschema, referencing
ruff==0.7.2 # via -r requirements-dev.in
ruff==0.7.3 # via -r requirements-dev.in
scipy==1.10.1 # via gt4py (pyproject.toml)
setuptools-scm==8.1.0 # via fparser
six==1.16.0 # via asttokens, astunparse, python-dateutil
Expand All @@ -158,7 +159,7 @@ stack-data==0.6.3 # via ipython
stdlib-list==0.10.0 # via tach
sympy==1.12.1 # via dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via gt4py (pyproject.toml)
tach==0.14.1 # via -r requirements-dev.in
tach==0.14.3 # via -r requirements-dev.in
tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
tomli-w==1.0.0 # via tach
tomlkit==0.13.2 # via bump-my-version
Expand All @@ -173,7 +174,7 @@ virtualenv==20.27.1 # via pre-commit, tox
wcmatch==10.0 # via bump-my-version
wcwidth==0.2.13 # via prompt-toolkit
websockets==13.1 # via dace
wheel==0.44.0 # via astunparse, pip-tools
wheel==0.45.0 # via astunparse, pip-tools
xxhash==3.0.0 # via gt4py (pyproject.toml)
zipp==3.20.2 # via importlib-metadata, importlib-resources

Expand Down
3 changes: 2 additions & 1 deletion min-extra-requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ dace==0.16.1
darglint==1.6
deepdiff==5.6.0
devtools==0.6
diskcache==5.6.3
factory-boy==3.3.0
frozendict==2.3
gridtools-cpp==2.3.6
gridtools-cpp==2.3.7
hypothesis==6.0.0
importlib-resources==5.0; python_version < "3.9"
jax[cpu]==0.4.18; python_version >= "3.10"
Expand Down
3 changes: 2 additions & 1 deletion min-requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ cytoolz==0.12.1
darglint==1.6
deepdiff==5.6.0
devtools==0.6
diskcache==5.6.3
factory-boy==3.3.0
frozendict==2.3
gridtools-cpp==2.3.6
gridtools-cpp==2.3.7
hypothesis==6.0.0
importlib-resources==5.0; python_version < "3.9"
jinja2==3.0.0
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ dependencies = [
'cytoolz>=0.12.1',
'deepdiff>=5.6.0',
'devtools>=0.6',
'diskcache>=5.6.3',
'factory-boy>=3.3.0',
'frozendict>=2.3',
'gridtools-cpp>=2.3.6,==2.*',
'gridtools-cpp>=2.3.7,==2.*',
"importlib-resources>=5.0;python_version<'3.9'",
'jinja2>=3.0.0',
'lark>=1.1.2',
Expand Down
23 changes: 12 additions & 11 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypo
babel==2.16.0 # via -c constraints.txt, sphinx
backcall==0.2.0 # via -c constraints.txt, ipython
black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml)
boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml)
boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml)
bracex==2.5.post1 # via -c constraints.txt, wcmatch
build==1.2.2.post1 # via -c constraints.txt, pip-tools
bump-my-version==0.28.0 # via -c constraints.txt, -r requirements-dev.in
bump-my-version==0.28.1 # via -c constraints.txt, -r requirements-dev.in
cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml)
cachetools==5.5.0 # via -c constraints.txt, tox
certifi==2024.8.30 # via -c constraints.txt, requests
Expand All @@ -25,7 +25,7 @@ chardet==5.2.0 # via -c constraints.txt, tox
charset-normalizer==3.4.0 # via -c constraints.txt, requests
clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml)
click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click
cmake==3.30.5 # via -c constraints.txt, gt4py (pyproject.toml)
cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml)
cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in
colorama==0.4.6 # via -c constraints.txt, tox
comm==0.2.2 # via -c constraints.txt, ipykernel
Expand All @@ -35,11 +35,12 @@ cycler==0.12.1 # via -c constraints.txt, matplotlib
cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml)
dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml)
darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in
debugpy==1.8.7 # via -c constraints.txt, ipykernel
debugpy==1.8.8 # via -c constraints.txt, ipykernel
decorator==5.1.1 # via -c constraints.txt, ipython
deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml)
devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml)
dill==0.3.9 # via -c constraints.txt, dace
diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml)
distlib==0.3.9 # via -c constraints.txt, virtualenv
docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme
exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest
Expand All @@ -54,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace
frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml)
gitdb==4.0.11 # via -c constraints.txt, gitpython
gitpython==3.1.43 # via -c constraints.txt, tach
gridtools-cpp==2.3.6 # via -c constraints.txt, gt4py (pyproject.toml)
gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml)
hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml)
identify==2.6.1 # via -c constraints.txt, pre-commit
idna==3.10 # via -c constraints.txt, requests
Expand All @@ -65,7 +66,7 @@ inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy
iniconfig==2.0.0 # via -c constraints.txt, pytest
ipykernel==6.29.5 # via -c constraints.txt, nbmake
ipython==8.12.3 # via -c constraints.txt, ipykernel
jedi==0.19.1 # via -c constraints.txt, ipython
jedi==0.19.2 # via -c constraints.txt, ipython
jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx
jsonschema==4.23.0 # via -c constraints.txt, nbformat
jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema
Expand Down Expand Up @@ -94,7 +95,7 @@ ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml)
nodeenv==1.9.1 # via -c constraints.txt, pre-commit
numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib
orderly-set==5.2.2 # via -c constraints.txt, deepdiff
packaging==24.1 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox
packaging==24.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox
parso==0.8.4 # via -c constraints.txt, jedi
pathspec==0.12.1 # via -c constraints.txt, black
pexpect==4.9.0 # via -c constraints.txt, ipython
Expand Down Expand Up @@ -135,10 +136,10 @@ pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client
questionary==2.0.1 # via -c constraints.txt, bump-my-version
referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications
requests==2.32.3 # via -c constraints.txt, sphinx
rich==13.9.3 # via -c constraints.txt, bump-my-version, rich-click, tach
rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach
rich-click==1.8.3 # via -c constraints.txt, bump-my-version
rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing
ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in
ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in
setuptools-scm==8.1.0 # via -c constraints.txt, fparser
six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil
smmap==5.0.1 # via -c constraints.txt, gitdb
Expand All @@ -157,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython
stdlib-list==0.10.0 # via -c constraints.txt, tach
sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml)
tach==0.14.1 # via -c constraints.txt, -r requirements-dev.in
tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in
tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
tomli-w==1.0.0 # via -c constraints.txt, tach
tomlkit==0.13.2 # via -c constraints.txt, bump-my-version
Expand All @@ -172,7 +173,7 @@ virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox
wcmatch==10.0 # via -c constraints.txt, bump-my-version
wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit
websockets==13.1 # via -c constraints.txt, dace
wheel==0.44.0 # via -c constraints.txt, astunparse, pip-tools
wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools
xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml)
zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources

Expand Down
53 changes: 11 additions & 42 deletions src/gt4py/cartesian/gtc/gtir_to_oir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass, field
from typing import Any, List, Optional, Set, Union
from typing import Any, List, Set, Union

from gt4py import eve
from gt4py.cartesian.gtc import common, gtir, oir, utils
from gt4py.cartesian.gtc.common import CartesianOffset, DataType, LogicalOperator, UnaryOperator
from gt4py.cartesian.gtc import gtir, oir, utils
from gt4py.cartesian.gtc.common import CartesianOffset, DataType, UnaryOperator
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_fields_extents


Expand Down Expand Up @@ -118,15 +118,8 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall:
)

# --- Statements ---
def visit_ParAssignStmt(
self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any
) -> Union[oir.AssignStmt, oir.MaskStmt]:
statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))
if mask is None:
return statement

# Wrap inside MaskStmt
return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc)
def visit_ParAssignStmt(self, node: gtir.ParAssignStmt, **kwargs: Any) -> oir.AssignStmt:
return oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))

def visit_HorizontalRestriction(
self, node: gtir.HorizontalRestriction, **kwargs: Any
Expand All @@ -138,24 +131,19 @@ def visit_HorizontalRestriction(

return oir.HorizontalRestriction(mask=node.mask, body=body)

def visit_While(
self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any
) -> oir.While:
def visit_While(self, node: gtir.While, **kwargs: Any) -> oir.While:
body: List[oir.Stmt] = []
for statement in node.body:
oir_statement = self.visit(statement, **kwargs)
body.extend(utils.flatten_list(utils.listify(oir_statement)))

condition: oir.Expr = self.visit(node.cond)
if mask:
condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition)
return oir.While(cond=condition, body=body, loc=node.loc)

def visit_FieldIfStmt(
self,
node: gtir.FieldIfStmt,
*,
mask: Optional[oir.Expr] = None,
ctx: Context,
**kwargs: Any,
) -> List[Union[oir.AssignStmt, oir.MaskStmt]]:
Expand All @@ -182,26 +170,17 @@ def visit_FieldIfStmt(
loc=node.loc,
)

combined_mask: oir.Expr = condition
if mask:
combined_mask = oir.BinaryOp(
op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
)
body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body]
)
statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc))
statements.append(oir.MaskStmt(body=body, mask=condition, loc=node.loc))

if node.false_branch:
combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition)
if mask:
combined_mask = oir.BinaryOp(
op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
)
negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc)
body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body]
)
statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc))
statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc))

return statements

Expand All @@ -211,31 +190,21 @@ def visit_ScalarIfStmt(
self,
node: gtir.ScalarIfStmt,
*,
mask: Optional[oir.Expr] = None,
ctx: Context,
**kwargs: Any,
) -> List[oir.MaskStmt]:
condition = self.visit(node.cond)
combined_mask = condition
if mask:
combined_mask = oir.BinaryOp(
op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc
)

body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body]
)
statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)]

if node.false_branch:
combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc)
if mask:
combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask)

negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc)
body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body]
)
statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc))
statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc))

return statements

Expand Down
Loading

0 comments on commit 7b47e97

Please sign in to comment.