From 4dadcf47b186783768c448891494f99c2c72ed3c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Mon, 28 Oct 2024 10:45:42 +0100 Subject: [PATCH] WIP: fix numpy while loop --- src/gt4py/cartesian/gtc/numpy/oir_to_npir.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py index ed573ebfff..b6aeb49823 100644 --- a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py +++ b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py @@ -157,13 +157,12 @@ def visit_AssignStmt( def visit_While( self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any ) -> npir.While: - cond = self.visit(node.cond, mask=mask, **kwargs) + cond_expr = self.visit(node.cond, **kwargs) if mask: - mask = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond) - else: - mask = cond + cond_expr = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond_expr) + return npir.While( - cond=cond, body=utils.flatten_list(self.visit(node.body, mask=mask, **kwargs)) + cond=cond_expr, body=utils.flatten_list(self.visit(node.body, mask=cond_expr, **kwargs)) ) def visit_HorizontalRestriction(