From 048d4c0b68b74a388fe2e69604cb726bfe4f4c83 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Wed, 4 Oct 2023 18:59:07 +0200 Subject: [PATCH 01/16] Extend astype() for tuples --- src/gt4py/next/ffront/fbuiltins.py | 6 ++++- .../ffront/foast_passes/type_deduction.py | 21 +++++++++++++---- src/gt4py/next/ffront/foast_to_itir.py | 22 +++++++++++++----- .../test_type_alias_replacement.py | 23 +++++++++++++++++++ 4 files changed, 61 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 52aae34b3f..ea79f3d8fd 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -175,7 +175,11 @@ def where( @builtin_function -def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: +def astype( + field: Field | gt4py_defs.ScalarT | Tuple[Field, ...], + type_: type, + /, +) -> Field | Tuple[Field, ...]: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 605b83a5f0..903c871e13 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: + return_type: ts.TupleType | ts.ScalarType | ts.FieldType value, new_type = node.args assert isinstance( - value.type, (ts.FieldType, ts.ScalarType) + value.type, (ts.FieldType, ts.ScalarType, ts.TupleType) ) # already checked using generic mechanism + if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [ kind.name for kind in ts.ScalarKind ]: @@ -835,9 +837,20 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", ) - return_type = with_altered_scalar_kind( - value.type, getattr(ts.ScalarKind, new_type.id.upper()) - ) + if isinstance(value, foast.TupleExpr): + element_types_new = [] + for element in value.elts: + element_types_new.append( + with_altered_scalar_kind( + element.type, getattr(ts.ScalarKind, new_type.id.upper()) + ) + ) + return_type = ts.TupleType(types=cast(list[ts.DataType], element_types_new)) + + else: + return_type = with_altered_scalar_kind( + value.type, getattr(ts.ScalarKind, new_type.id.upper()) + ) return foast.Call( func=node.func, diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 1902d71b3c..3e6d5b3c4e 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -318,12 +318,22 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, dtype = node.args[0], node.args[1].id - - # TODO check that we test astype that results in a itir.map_ operation - return self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), - obj, - ) + if isinstance(obj, foast.TupleExpr): + casted_elements = [] + for _, element in enumerate(obj.elts): + casted_element = self._map( + im.lambda_("it")(im.call("cast_")("it", str(dtype))), element + ) + casted_elements.append(casted_element) + args = [f"__arg{i}" for i in range(len(casted_elements))] + return im.lift(im.lambda_(*args)(im.make_tuple(*[im.deref(arg) for arg in args])))( + *casted_elements + ) + else: + return self._map( + im.lambda_("it")(im.call("cast_")("it", str(dtype))), + obj, + ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args) diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py index e87f869352..3be72bdb33 100644 --- a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py +++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py @@ -19,9 +19,13 @@ import pytest import gt4py.next as gtx +from gt4py.eve import SymbolRef from gt4py.next import float32, float64 from gt4py.next.ffront.fbuiltins import astype +from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.type_system import type_specifications as ts TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. @@ -42,3 +46,22 @@ def fieldop_with_typealias( foast_tree.body.stmts[0].value.left.func.id == expected and foast_tree.body.stmts[0].value.right.args[1].id == expected ) + + +def test_type_alias_replacement_astype_with_tuples(): + def fieldop_with_typealias_with_tuples( + a: gtx.Field[[TDim], vpfloat], b: gtx.Field[[TDim], vpfloat] + ) -> tuple[gtx.Field[[TDim], wpfloat], gtx.Field[[TDim], wpfloat]]: + return astype((a, b), wpfloat) + + parsed = FieldOperatorParser.apply_to_function(fieldop_with_typealias_with_tuples) + lowered = FieldOperatorLowering.apply(parsed) + + # Check that the type of the first arg of "astype" is a tuple + assert isinstance(parsed.body.stmts[0].value.args[0].type, ts.TupleType) + # Check that the return type of "astype" is a tuple + assert isinstance(parsed.body.stmts[0].value.type, ts.TupleType) + # Check inside the lift function that make_tuple is applied to return a tuple + assert lowered.expr.fun.args[0].expr.fun == itir.SymRef(id=SymbolRef("make_tuple")) + # Check that the elements that form the tuple called the cast_ function individually + assert lowered.expr.args[0].fun.args[0].expr.fun.expr.fun == itir.SymRef(id=SymbolRef("cast_")) From 0898115e8a03ed8574c26ebb27f3a27ab4a0c2e9 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Thu, 19 Oct 2023 22:57:35 +0200 Subject: [PATCH 02/16] Adapt existing test for arg types of astype() --- .../feature_tests/ffront_tests/test_type_deduction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 7800a30e41..30465137e1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]): def test_astype_wrong_value_type(): def simple_astype(a: Field[[TDim], float64]): - # we just use a tuple here but anything that is not a field or scalar works - return astype((1, 2), bool) + # we just use broadcast here but anything that is not a field, scalar or tuple works + return astype(broadcast, bool) with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(simple_astype) From 30174f63795edfbbad2cb17592adc52874af0919 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 20 Oct 2023 12:00:36 +0200 Subject: [PATCH 03/16] Adress requested style change --- src/gt4py/next/ffront/foast_to_itir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 3e6d5b3c4e..ffc6b8f0f8 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -320,7 +320,7 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: obj, dtype = node.args[0], node.args[1].id if isinstance(obj, foast.TupleExpr): casted_elements = [] - for _, element in enumerate(obj.elts): + for element in obj.elts: casted_element = self._map( im.lambda_("it")(im.call("cast_")("it", str(dtype))), element ) From 1528b9c1247b7dfe624dd2f66b432ddd59dab41e Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 20 Oct 2023 13:05:42 +0200 Subject: [PATCH 04/16] Add extra type check --- src/gt4py/next/ffront/foast_passes/type_deduction.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 903c871e13..c18320a3cf 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -840,6 +840,7 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: if isinstance(value, foast.TupleExpr): element_types_new = [] for element in value.elts: + assert isinstance(element.type, (ts.FieldType, ts.ScalarType)) element_types_new.append( with_altered_scalar_kind( element.type, getattr(ts.ScalarKind, new_type.id.upper()) From ecb597794eded1fd8f765af99f48d907db5979e9 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 20 Oct 2023 22:16:25 +0200 Subject: [PATCH 05/16] Use apply_to_primitive_constituents function on (nested) tuples --- .../next/ffront/foast_passes/type_deduction.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index c18320a3cf..05151f7979 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -838,15 +838,12 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: ) if isinstance(value, foast.TupleExpr): - element_types_new = [] - for element in value.elts: - assert isinstance(element.type, (ts.FieldType, ts.ScalarType)) - element_types_new.append( - with_altered_scalar_kind( - element.type, getattr(ts.ScalarKind, new_type.id.upper()) - ) - ) - return_type = ts.TupleType(types=cast(list[ts.DataType], element_types_new)) + return_type = type_info.apply_to_primitive_constituents( + value.type, + lambda primitive_type: with_altered_scalar_kind( + primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) + ), + ) else: return_type = with_altered_scalar_kind( From fe81009cd463baec7bc8361609ec27ffba47b61b Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 20 Oct 2023 22:19:47 +0200 Subject: [PATCH 06/16] Adress 'nitpicking' change --- .../feature_tests/ffront_tests/test_type_deduction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 30465137e1..dfa710e038 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -785,7 +785,7 @@ def simple_astype(a: Field[[TDim], float64]): def test_astype_wrong_value_type(): def simple_astype(a: Field[[TDim], float64]): - # we just use broadcast here but anything that is not a field, scalar or tuple works + # we just use broadcast here but anything that is not a field, scalar or tuple thereof works return astype(broadcast, bool) with pytest.raises(errors.DSLError) as exc_info: From 7bb55eaf6c243bb505956c2b49ee911f4f178984 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 20 Oct 2023 22:21:33 +0200 Subject: [PATCH 07/16] Remove previous test and add integration test for casting (nested) tuples --- .../ffront_tests/test_execution.py | 27 +++++++++++++++++++ .../test_type_alias_replacement.py | 23 ---------------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 865950eeab..ae6e1451bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -327,6 +327,33 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +def test_astype_on_tuples(cartesian_case): + @gtx.field_operator + def testee( + a: cases.IFloatField, b: cases.IFloatField + ) -> tuple[gtx.Field[[IDim], int64], gtx.Field[[IDim], int64]]: + return astype((a, b), int64) + + cases.verify_with_default_data( + cartesian_case, testee, ref=lambda a, b: (a.astype(int64), b.astype(int64)) + ) + + +def test_astype_on_nested_tuples(cartesian_case): + @gtx.field_operator + def cast_nested_tuple( + a: cases.IField, b: cases.IField + ) -> tuple[gtx.Field[[IDim], int64], tuple[gtx.Field[[IDim], int64], gtx.Field[[IDim], int64]]]: + return astype((a, (a, b)), int64) + + @gtx.field_operator + def combine(a: cases.IField, b: cases.IField) -> gtx.Field[[IDim], int64]: + nested_tuple = cast_nested_tuple(a, b) + return nested_tuple[0] + nested_tuple[1][0] + nested_tuple[1][1] + + cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b) + + def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py index 3be72bdb33..e87f869352 100644 --- a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py +++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py @@ -19,13 +19,9 @@ import pytest import gt4py.next as gtx -from gt4py.eve import SymbolRef from gt4py.next import float32, float64 from gt4py.next.ffront.fbuiltins import astype -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir, ir_makers as im -from gt4py.next.type_system import type_specifications as ts TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. @@ -46,22 +42,3 @@ def fieldop_with_typealias( foast_tree.body.stmts[0].value.left.func.id == expected and foast_tree.body.stmts[0].value.right.args[1].id == expected ) - - -def test_type_alias_replacement_astype_with_tuples(): - def fieldop_with_typealias_with_tuples( - a: gtx.Field[[TDim], vpfloat], b: gtx.Field[[TDim], vpfloat] - ) -> tuple[gtx.Field[[TDim], wpfloat], gtx.Field[[TDim], wpfloat]]: - return astype((a, b), wpfloat) - - parsed = FieldOperatorParser.apply_to_function(fieldop_with_typealias_with_tuples) - lowered = FieldOperatorLowering.apply(parsed) - - # Check that the type of the first arg of "astype" is a tuple - assert isinstance(parsed.body.stmts[0].value.args[0].type, ts.TupleType) - # Check that the return type of "astype" is a tuple - assert isinstance(parsed.body.stmts[0].value.type, ts.TupleType) - # Check inside the lift function that make_tuple is applied to return a tuple - assert lowered.expr.fun.args[0].expr.fun == itir.SymRef(id=SymbolRef("make_tuple")) - # Check that the elements that form the tuple called the cast_ function individually - assert lowered.expr.args[0].fun.args[0].expr.fun.expr.fun == itir.SymRef(id=SymbolRef("cast_")) From 926b07efc2eb9b707ae351370b611dfd53afce5b Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 20 Oct 2023 22:23:20 +0200 Subject: [PATCH 08/16] Adapt visit_astype method with recursive func for nested tuples --- src/gt4py/next/ffront/foast_to_itir.py | 31 +++++++++++++------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index ffc6b8f0f8..f7725f9ea4 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -318,22 +318,23 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, dtype = node.args[0], node.args[1].id - if isinstance(obj, foast.TupleExpr): - casted_elements = [] - for element in obj.elts: - casted_element = self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), element + + def recursive_cast(obj, dtype): + if isinstance(obj, foast.TupleExpr): + casted_elements = [] + + for element in obj.elts: + casted_element = recursive_cast(element, dtype) + casted_elements.append(casted_element) + + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *casted_elements ) - casted_elements.append(casted_element) - args = [f"__arg{i}" for i in range(len(casted_elements))] - return im.lift(im.lambda_(*args)(im.make_tuple(*[im.deref(arg) for arg in args])))( - *casted_elements - ) - else: - return self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), - obj, - ) + + else: + return self._map(im.lambda_("it")(im.call("cast_")("it", str(dtype))), obj) + + return recursive_cast(obj, dtype) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args) From 670305226b7af8202bd614713a5b6e2ef466e721 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 20 Oct 2023 23:36:04 +0200 Subject: [PATCH 09/16] Fix integration test --- .../feature_tests/ffront_tests/test_execution.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index ae6e1451bd..f6a1142315 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -329,14 +329,17 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: def test_astype_on_tuples(cartesian_case): @gtx.field_operator - def testee( + def cast_tuple( a: cases.IFloatField, b: cases.IFloatField ) -> tuple[gtx.Field[[IDim], int64], gtx.Field[[IDim], int64]]: return astype((a, b), int64) - cases.verify_with_default_data( - cartesian_case, testee, ref=lambda a, b: (a.astype(int64), b.astype(int64)) - ) + @gtx.field_operator + def combine(a: cases.IFloatField, b: cases.IFloatField) -> gtx.Field[[IDim], int64]: + packed_tuple = cast_tuple(a, b) + return packed_tuple[0] + packed_tuple[1] + + cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + b) def test_astype_on_nested_tuples(cartesian_case): From 522c1a947e4735b422aa30427a2bd046187410f8 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Thu, 26 Oct 2023 15:21:04 +0200 Subject: [PATCH 10/16] Call 'with_altered_scalar_kind' only once --- .../next/ffront/foast_passes/type_deduction.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 05151f7979..95c9128f87 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -837,18 +837,12 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", ) - if isinstance(value, foast.TupleExpr): - return_type = type_info.apply_to_primitive_constituents( - value.type, - lambda primitive_type: with_altered_scalar_kind( - primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) - ), - ) - - else: - return_type = with_altered_scalar_kind( - value.type, getattr(ts.ScalarKind, new_type.id.upper()) - ) + return_type = type_info.apply_to_primitive_constituents( + value.type, + lambda primitive_type: with_altered_scalar_kind( + primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) + ), + ) return foast.Call( func=node.func, From 61d20765c27de0b761448ef30721fd315fbf6f2b Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Thu, 26 Oct 2023 15:23:56 +0200 Subject: [PATCH 11/16] Recursive 'process_elements' func to apply a func on the elts of a tuple --- src/gt4py/next/ffront/foast_to_itir.py | 44 ++++++++++++++++---------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index f7725f9ea4..e5280d1eac 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -319,22 +319,7 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, dtype = node.args[0], node.args[1].id - def recursive_cast(obj, dtype): - if isinstance(obj, foast.TupleExpr): - casted_elements = [] - - for element in obj.elts: - casted_element = recursive_cast(element, dtype) - casted_elements.append(casted_element) - - return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( - *casted_elements - ) - - else: - return self._map(im.lambda_("it")(im.call("cast_")("it", str(dtype))), obj) - - return recursive_cast(obj, dtype) + return self._process_elements(obj, lambda x: im.call("cast_")(x, str(dtype)), **kwargs) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args) @@ -414,6 +399,33 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + def _process_elements(self, obj, process_func, **kwargs): + """ + Recursively applies a processing function to the elements of a structured object, preserving its structure. + + Args: + obj: The structured object to be processed. + process_func: A function to apply to each element. + **kwargs: Additional keyword arguments for the processing function. + + Example: + result = process_elements(obj, lambda x: im.call("cast_")(x, str(dtype)), **kwargs) + + Returns: + Structured object with the processing function applied to its elements. + """ + if isinstance(obj.type, ts.TupleType): + if isinstance(obj, foast.Name) or isinstance(obj, foast.Call): + return self.visit(obj) + processed_elements = [ + self._process_elements(el, process_func, **kwargs) for el in obj.elts + ] + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[el for el in processed_elements] + ) + else: + return self._map(im.lambda_("it")(process_func("it")), obj) + class FieldOperatorLoweringError(Exception): ... From bffe3b4c7a7737ec5190600f3aca1a4b3e18929f Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Thu, 26 Oct 2023 16:10:35 +0200 Subject: [PATCH 12/16] Fix execution tests --- .../ffront_tests/test_execution.py | 63 +++++++++++++------ 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f6a1142315..12c4d5416f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -327,34 +327,59 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) -def test_astype_on_tuples(cartesian_case): +@pytest.mark.uses_tuple_returns +def test_astype_on_tuples(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator - def cast_tuple( + def field_op_returning_a_tuple( a: cases.IFloatField, b: cases.IFloatField - ) -> tuple[gtx.Field[[IDim], int64], gtx.Field[[IDim], int64]]: - return astype((a, b), int64) + ) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]: + tup = (a, b) + return tup @gtx.field_operator - def combine(a: cases.IFloatField, b: cases.IFloatField) -> gtx.Field[[IDim], int64]: - packed_tuple = cast_tuple(a, b) - return packed_tuple[0] + packed_tuple[1] - - cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + b) - + def cast_tuple( + a: cases.IFloatField, b: cases.IFloatField + ) -> tuple[gtx.Field[[IDim], int32], gtx.Field[[IDim], int32]]: + return astype(field_op_returning_a_tuple(a, b), int32) -def test_astype_on_nested_tuples(cartesian_case): @gtx.field_operator def cast_nested_tuple( - a: cases.IField, b: cases.IField - ) -> tuple[gtx.Field[[IDim], int64], tuple[gtx.Field[[IDim], int64], gtx.Field[[IDim], int64]]]: - return astype((a, (a, b)), int64) + a: cases.IFloatField, b: cases.IFloatField + ) -> tuple[gtx.Field[[IDim], int32], tuple[gtx.Field[[IDim], int32], gtx.Field[[IDim], int32]]]: + return astype((a, field_op_returning_a_tuple(a, b)), int32) + + a = cases.allocate(cartesian_case, cast_tuple, "a")() + b = cases.allocate(cartesian_case, cast_tuple, "b")() + out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() + out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() + + def unpack_and_compare(ref, out): + if isinstance(ref, tuple) and isinstance(out, tuple): + return all( + unpack_and_compare(ref_item, out_item) for ref_item, out_item in zip(ref, out) + ) + else: + return ref.dtype == out.dtype and np.array_equal(ref, out) - @gtx.field_operator - def combine(a: cases.IField, b: cases.IField) -> gtx.Field[[IDim], int64]: - nested_tuple = cast_nested_tuple(a, b) - return nested_tuple[0] + nested_tuple[1][0] + nested_tuple[1][1] + cases.verify( + cartesian_case, + cast_tuple, + a, + b, + out=out_tuple, + ref=(int32(a), int32(b)), + comparison=lambda ref, out: unpack_and_compare(ref, out), + ) - cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b) + cases.verify( + cartesian_case, + cast_nested_tuple, + a, + b, + out=out_nested_tuple, + ref=(int32(a), (int32(a), int32(b))), + comparison=lambda ref, out: unpack_and_compare(ref, out), + ) def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures From 726d9a7b8a47a71cce83cfcb6c6ea2dea9c976b7 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 10 Nov 2023 10:31:05 +0100 Subject: [PATCH 13/16] Adapt visit_astype for foast.Call and foast.Name --- src/gt4py/next/ffront/foast_to_itir.py | 50 ++++++++++++-------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index e5280d1eac..ebe16c39e7 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -317,9 +317,10 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, dtype = node.args[0], node.args[1].id - - return self._process_elements(obj, lambda x: im.call("cast_")(x, str(dtype)), **kwargs) + obj, new_type = node.args[0], node.args[1].id + return self._process_elements( + obj, obj.type, lambda x: im.call("cast_")(x, str(new_type)), "expr", **kwargs + ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args) @@ -399,32 +400,27 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - def _process_elements(self, obj, process_func, **kwargs): - """ - Recursively applies a processing function to the elements of a structured object, preserving its structure. - - Args: - obj: The structured object to be processed. - process_func: A function to apply to each element. - **kwargs: Additional keyword arguments for the processing function. - - Example: - result = process_elements(obj, lambda x: im.call("cast_")(x, str(dtype)), **kwargs) - - Returns: - Structured object with the processing function applied to its elements. - """ - if isinstance(obj.type, ts.TupleType): + def _process_elements(self, obj, obj_type, process_func, expr, **kwargs): + """Recursively applies a processing function to the elements of a structured object, preserving its structure.""" + if isinstance(obj_type, ts.TupleType): if isinstance(obj, foast.Name) or isinstance(obj, foast.Call): - return self.visit(obj) - processed_elements = [ - self._process_elements(el, process_func, **kwargs) for el in obj.elts - ] - return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( - *[el for el in processed_elements] - ) + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[ + self._process_elements( + obj, obj_type.types[i], process_func, im.tuple_get(i, expr) + ) + for i in range(len(obj.type.types)) + ] + ) + elif isinstance(obj, foast.TupleExpr): + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[ + self._process_elements(el, el.type, process_func, expr, **kwargs) + for el in obj.elts + ] + ) else: - return self._map(im.lambda_("it")(process_func("it")), obj) + return self._map(im.lambda_("expr")(process_func(expr)), obj) class FieldOperatorLoweringError(Exception): From bf163e0fc91eadf329269e30c23f66999a32d0cc Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Fri, 10 Nov 2023 10:31:38 +0100 Subject: [PATCH 14/16] Fix tests --- .../ffront_tests/test_execution.py | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f934994aae..97729c0f30 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -336,37 +336,51 @@ def field_op_returning_a_tuple( @gtx.field_operator def cast_tuple( - a: cases.IFloatField, b: cases.IFloatField - ) -> tuple[gtx.Field[[IDim], int32], gtx.Field[[IDim], int32]]: - return astype(field_op_returning_a_tuple(a, b), int32) + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype(field_op_returning_a_tuple(a, b), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1] == b_casted_to_int_outside_of_gt4py, + ) @gtx.field_operator def cast_nested_tuple( - a: cases.IFloatField, b: cases.IFloatField - ) -> tuple[gtx.Field[[IDim], int32], tuple[gtx.Field[[IDim], int32], gtx.Field[[IDim], int32]]]: - return astype((a, field_op_returning_a_tuple(a, b)), int32) + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype((a, field_op_returning_a_tuple(a, b)), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1][0] == a_casted_to_int_outside_of_gt4py, + result[1][1] == b_casted_to_int_outside_of_gt4py, + ) a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() + a_casted_to_int_outside_of_gt4py = cases.allocate( + cartesian_case, cast_tuple, "a", dtype=int32 + )() + b_casted_to_int_outside_of_gt4py = cases.allocate( + cartesian_case, cast_tuple, "b", dtype=int32 + )() out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() - def unpack_and_compare(ref, out): - if isinstance(ref, tuple) and isinstance(out, tuple): - return all( - unpack_and_compare(ref_item, out_item) for ref_item, out_item in zip(ref, out) - ) - else: - return ref.dtype == out.dtype and np.array_equal(ref, out) - cases.verify( cartesian_case, cast_tuple, a, b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, out=out_tuple, - ref=(int32(a), int32(b)), - comparison=lambda ref, out: unpack_and_compare(ref, out), + ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), ) cases.verify( @@ -374,9 +388,14 @@ def unpack_and_compare(ref, out): cast_nested_tuple, a, b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, out=out_nested_tuple, - ref=(int32(a), (int32(a), int32(b))), - comparison=lambda ref, out: unpack_and_compare(ref, out), + ref=( + np.full_like(a, True, dtype=bool), + np.full_like(a, True, dtype=bool), + np.full_like(b, True, dtype=bool), + ), ) From b3da91f191fa42b30a740bc2b2f6e7d03e324ac5 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Thu, 16 Nov 2023 19:51:58 +0100 Subject: [PATCH 15/16] Rename args and refactor 'process_elements' --- src/gt4py/next/ffront/foast_to_itir.py | 46 ++++++++++++++------------ 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index ebe16c39e7..816b8581f1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -319,7 +319,7 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = node.args[0], node.args[1].id return self._process_elements( - obj, obj.type, lambda x: im.call("cast_")(x, str(new_type)), "expr", **kwargs + lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -400,27 +400,31 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - def _process_elements(self, obj, obj_type, process_func, expr, **kwargs): - """Recursively applies a processing function to the elements of a structured object, preserving its structure.""" - if isinstance(obj_type, ts.TupleType): - if isinstance(obj, foast.Name) or isinstance(obj, foast.Call): - return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( - *[ - self._process_elements( - obj, obj_type.types[i], process_func, im.tuple_get(i, expr) - ) - for i in range(len(obj.type.types)) - ] - ) - elif isinstance(obj, foast.TupleExpr): - return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( - *[ - self._process_elements(el, el.type, process_func, expr, **kwargs) - for el in obj.elts - ] - ) + def _process_elements( + self, + process_func: Callable[[itir.Expr], itir.Expr], + obj: foast.Expr, + current_el_type: ts.TypeSpec, + current_el_expr: itir.Expr = im.ref("expr"), + ): + """Recursively applies a processing function to all primitive constituents of a tuple.""" + if isinstance(current_el_type, ts.TupleType): + # TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element. + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[ + self._process_elements( + process_func, + obj, + current_el_type.types[i], + im.tuple_get(i, current_el_expr), + ) + for i in range(len(current_el_type.types)) + ] + ) + elif type_info.contains_local_field(current_el_type): + raise NotImplementedError("Processing fields with local dimension is not implemented.") else: - return self._map(im.lambda_("expr")(process_func(expr)), obj) + return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) class FieldOperatorLoweringError(Exception): From 4a11354a862cefcb7f643790dbfd2551dccbb2e5 Mon Sep 17 00:00:00 2001 From: Nina Burgdorfer Date: Thu, 16 Nov 2023 19:52:19 +0100 Subject: [PATCH 16/16] Fix tests --- .../feature_tests/ffront_tests/test_execution.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 29eb95faeb..58181fd7a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -363,12 +363,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_casted_to_int_outside_of_gt4py = cases.allocate( - cartesian_case, cast_tuple, "a", dtype=int32 - )() - b_casted_to_int_outside_of_gt4py = cases.allocate( - cartesian_case, cast_tuple, "b", dtype=int32 - )() + a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) + b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)()