diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 560cbf96cf..d36c2e5c4a 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -7,11 +7,11 @@ # SPDX-License-Identifier: BSD-3-Clause from dataclasses import dataclass, field -from typing import Any, List, Optional, Set, Union +from typing import Any, List, Set, Union from gt4py import eve -from gt4py.cartesian.gtc import common, gtir, oir, utils -from gt4py.cartesian.gtc.common import CartesianOffset, DataType, LogicalOperator, UnaryOperator +from gt4py.cartesian.gtc import gtir, oir, utils +from gt4py.cartesian.gtc.common import CartesianOffset, DataType, UnaryOperator from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_fields_extents @@ -118,15 +118,8 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: ) # --- Statements --- - def visit_ParAssignStmt( - self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> Union[oir.AssignStmt, oir.MaskStmt]: - statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) - if mask is None: - return statement - - # Wrap inside MaskStmt - return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc) + def visit_ParAssignStmt(self, node: gtir.ParAssignStmt, **kwargs: Any) -> oir.AssignStmt: + return oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) def visit_HorizontalRestriction( self, node: gtir.HorizontalRestriction, **kwargs: Any @@ -138,24 +131,19 @@ def visit_HorizontalRestriction( return oir.HorizontalRestriction(mask=node.mask, body=body) - def visit_While( - self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> oir.While: + def visit_While(self, node: gtir.While, **kwargs: Any) -> oir.While: body: List[oir.Stmt] = [] for statement in node.body: oir_statement = self.visit(statement, **kwargs) body.extend(utils.flatten_list(utils.listify(oir_statement))) condition: oir.Expr = self.visit(node.cond) - if mask: - condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition) return oir.While(cond=condition, body=body, loc=node.loc) def visit_FieldIfStmt( self, node: gtir.FieldIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[Union[oir.AssignStmt, oir.MaskStmt]]: @@ -182,26 +170,17 @@ def visit_FieldIfStmt( loc=node.loc, ) - combined_mask: oir.Expr = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=condition, loc=node.loc)) if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition) - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) + negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements @@ -211,31 +190,21 @@ def visit_ScalarIfStmt( self, node: gtir.ScalarIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[oir.MaskStmt]: condition = self.visit(node.cond) - combined_mask = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc - ) - body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)] if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) - if mask: - combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask) - + negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements diff --git a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py index ed573ebfff..b6aeb49823 100644 --- a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py +++ b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py @@ -157,13 +157,12 @@ def visit_AssignStmt( def visit_While( self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any ) -> npir.While: - cond = self.visit(node.cond, mask=mask, **kwargs) + cond_expr = self.visit(node.cond, **kwargs) if mask: - mask = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond) - else: - mask = cond + cond_expr = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond_expr) + return npir.While( - cond=cond, body=utils.flatten_list(self.visit(node.body, mask=mask, **kwargs)) + cond=cond_expr, body=utils.flatten_list(self.visit(node.body, mask=cond_expr, **kwargs)) ) def visit_HorizontalRestriction( diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index d3a5744389..0312aea7c3 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -444,6 +444,36 @@ def validation(field_a, field_b, field_c, *, factor, domain, origin, **kwargs): field_a += 1 +class TestRuntimeIfNestedWhile(gt_testing.StencilTestSuite): + """Test conditional while statements.""" + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ALL_BACKENDS + symbols = dict( + infield=gt_testing.field(in_range=(-1, 1), boundary=[(0, 0), (0, 0), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + with computation(PARALLEL), interval(...): + if infield < 10: + outfield = 1 + done = False + while not done: + outfield = 2 + done = True + else: + condition = True + while condition: + outfield = 4 + condition = False + outfield = 3 + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = 2 + + class TestTernaryOp(gt_testing.StencilTestSuite): dtypes = (np.float_,) domain_range = [(1, 15), (2, 15), (1, 15)] diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py index 4de7f9f5d6..4877a39503 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py @@ -28,6 +28,7 @@ StencilFactory, VerticalLoopFactory, VerticalLoopSectionFactory, + WhileFactory, ) @@ -78,6 +79,18 @@ def test_mask_stmt_to_assigns() -> None: assert len(assign_stmts) == 1 +def test_mask_stmt_to_while() -> None: + mask_oir = MaskStmtFactory(body=[WhileFactory()]) + statements = OirToNpir().visit(mask_oir, extent=Extent.zeros(ndims=2)) + assert len(statements) == 1 + assert isinstance(statements[0], npir.While) + condition = statements[0].cond + assert isinstance(condition, npir.VectorLogic) + assert condition.op == common.LogicalOperator.AND + mask_npir = OirToNpir().visit(mask_oir.mask) + assert condition.left == mask_npir or condition.right == mask_npir + + def test_mask_propagation() -> None: mask_stmt = MaskStmtFactory() assign_stmts = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))