Skip to content

Commit

Permalink
fix[cartesian]: While loops inside conditions (#1712)
Browse files Browse the repository at this point in the history
This is a follow-up from PR
#1630. It combines two things

1. In `oir_to_npir`, we fix conditional `while` loops in `numpy`
backend. After PR 1630 these were stuck under certain conditions. Tests
coverage extended.
2. In `gtir_to_oir`, we cleaned up the now unused `mask` parameter,
which was pre-PR 1630 needed to pass down the mask information. With PR
1630 we actually removed the need for that parameter to be passed along
because we properly nest the nested statements.
  • Loading branch information
romanc authored Nov 5, 2024
1 parent 604e377 commit 60bb7b1
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 47 deletions.
53 changes: 11 additions & 42 deletions src/gt4py/cartesian/gtc/gtir_to_oir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass, field
from typing import Any, List, Optional, Set, Union
from typing import Any, List, Set, Union

from gt4py import eve
from gt4py.cartesian.gtc import common, gtir, oir, utils
from gt4py.cartesian.gtc.common import CartesianOffset, DataType, LogicalOperator, UnaryOperator
from gt4py.cartesian.gtc import gtir, oir, utils
from gt4py.cartesian.gtc.common import CartesianOffset, DataType, UnaryOperator
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_fields_extents


Expand Down Expand Up @@ -118,15 +118,8 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall:
)

# --- Statements ---
def visit_ParAssignStmt(
self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any
) -> Union[oir.AssignStmt, oir.MaskStmt]:
statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))
if mask is None:
return statement

# Wrap inside MaskStmt
return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc)
def visit_ParAssignStmt(self, node: gtir.ParAssignStmt, **kwargs: Any) -> oir.AssignStmt:
return oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))

def visit_HorizontalRestriction(
self, node: gtir.HorizontalRestriction, **kwargs: Any
Expand All @@ -138,24 +131,19 @@ def visit_HorizontalRestriction(

return oir.HorizontalRestriction(mask=node.mask, body=body)

def visit_While(
self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any
) -> oir.While:
def visit_While(self, node: gtir.While, **kwargs: Any) -> oir.While:
body: List[oir.Stmt] = []
for statement in node.body:
oir_statement = self.visit(statement, **kwargs)
body.extend(utils.flatten_list(utils.listify(oir_statement)))

condition: oir.Expr = self.visit(node.cond)
if mask:
condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition)
return oir.While(cond=condition, body=body, loc=node.loc)

def visit_FieldIfStmt(
self,
node: gtir.FieldIfStmt,
*,
mask: Optional[oir.Expr] = None,
ctx: Context,
**kwargs: Any,
) -> List[Union[oir.AssignStmt, oir.MaskStmt]]:
Expand All @@ -182,26 +170,17 @@ def visit_FieldIfStmt(
loc=node.loc,
)

combined_mask: oir.Expr = condition
if mask:
combined_mask = oir.BinaryOp(
op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
)
body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body]
)
statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc))
statements.append(oir.MaskStmt(body=body, mask=condition, loc=node.loc))

if node.false_branch:
combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition)
if mask:
combined_mask = oir.BinaryOp(
op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
)
negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc)
body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body]
)
statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc))
statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc))

return statements

Expand All @@ -211,31 +190,21 @@ def visit_ScalarIfStmt(
self,
node: gtir.ScalarIfStmt,
*,
mask: Optional[oir.Expr] = None,
ctx: Context,
**kwargs: Any,
) -> List[oir.MaskStmt]:
condition = self.visit(node.cond)
combined_mask = condition
if mask:
combined_mask = oir.BinaryOp(
op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc
)

body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body]
)
statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)]

if node.false_branch:
combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc)
if mask:
combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask)

negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc)
body = utils.flatten_list(
[self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body]
)
statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc))
statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc))

return statements

Expand Down
9 changes: 4 additions & 5 deletions src/gt4py/cartesian/gtc/numpy/oir_to_npir.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,12 @@ def visit_AssignStmt(
def visit_While(
self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any
) -> npir.While:
cond = self.visit(node.cond, mask=mask, **kwargs)
cond_expr = self.visit(node.cond, **kwargs)
if mask:
mask = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond)
else:
mask = cond
cond_expr = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond_expr)

return npir.While(
cond=cond, body=utils.flatten_list(self.visit(node.body, mask=mask, **kwargs))
cond=cond_expr, body=utils.flatten_list(self.visit(node.body, mask=cond_expr, **kwargs))
)

def visit_HorizontalRestriction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,36 @@ def validation(field_a, field_b, field_c, *, factor, domain, origin, **kwargs):
field_a += 1


class TestRuntimeIfNestedWhile(gt_testing.StencilTestSuite):
"""Test conditional while statements."""

dtypes = (np.float_,)
domain_range = [(1, 15), (1, 15), (1, 15)]
backends = ALL_BACKENDS
symbols = dict(
infield=gt_testing.field(in_range=(-1, 1), boundary=[(0, 0), (0, 0), (0, 0)]),
outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]),
)

def definition(infield, outfield):
with computation(PARALLEL), interval(...):
if infield < 10:
outfield = 1
done = False
while not done:
outfield = 2
done = True
else:
condition = True
while condition:
outfield = 4
condition = False
outfield = 3

def validation(infield, outfield, *, domain, origin, **kwargs):
outfield[...] = 2


class TestTernaryOp(gt_testing.StencilTestSuite):
dtypes = (np.float_,)
domain_range = [(1, 15), (2, 15), (1, 15)]
Expand Down
13 changes: 13 additions & 0 deletions tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
StencilFactory,
VerticalLoopFactory,
VerticalLoopSectionFactory,
WhileFactory,
)


Expand Down Expand Up @@ -78,6 +79,18 @@ def test_mask_stmt_to_assigns() -> None:
assert len(assign_stmts) == 1


def test_mask_stmt_to_while() -> None:
mask_oir = MaskStmtFactory(body=[WhileFactory()])
statements = OirToNpir().visit(mask_oir, extent=Extent.zeros(ndims=2))
assert len(statements) == 1
assert isinstance(statements[0], npir.While)
condition = statements[0].cond
assert isinstance(condition, npir.VectorLogic)
assert condition.op == common.LogicalOperator.AND
mask_npir = OirToNpir().visit(mask_oir.mask)
assert condition.left == mask_npir or condition.right == mask_npir


def test_mask_propagation() -> None:
mask_stmt = MaskStmtFactory()
assign_stmts = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))
Expand Down

0 comments on commit 60bb7b1

Please sign in to comment.