diff --git a/pyproject.toml b/pyproject.toml index a200451b2d..041448e17d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -330,7 +330,6 @@ markers = [ 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', 'uses_applied_shifts: tests that require backend support for applied-shifts', - 'uses_can_deref: tests that require backend support for can_deref', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_if_stmts: tests that require backend support for if-statements', diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 7406a9a25c..51e613ef81 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -286,7 +286,7 @@ def _np_cp_setitem( _nd_array_implementations = [np] -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, eq=False) class NumPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = np @@ -299,7 +299,7 @@ class NumPyArrayField(NdArrayField): if cp: _nd_array_implementations.append(cp) - @dataclasses.dataclass(frozen=True) + @dataclasses.dataclass(frozen=True, eq=False) class CuPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = cp @@ -311,7 +311,7 @@ class CuPyArrayField(NdArrayField): if jnp: _nd_array_implementations.append(jnp) - @dataclasses.dataclass(frozen=True) + @dataclasses.dataclass(frozen=True, eq=False) class JaxArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = jnp diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a8a69a0908..107415eb06 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -701,7 +701,7 @@ def __call__( ) else: # "out" -> field_operator called from program in embedded execution - # TODO(egparedes) put offset_provider in ctxt var here when implementing remap + # TODO(egparedes): put offset_provider in ctxt var here when implementing remap domain = kwargs.pop("domain", None) res = self.definition(*args, **kwargs) _tuple_assign_field( diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 733626f398..fb99af92e1 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -214,9 +214,16 @@ def where( @BuiltInFunction -def astype(field: common.Field | core_defs.ScalarT, type_: type, /) -> common.Field: - assert core_defs.is_scalar_type(field) - return type_(field) +def astype( + value: Field | core_defs.ScalarT | Tuple, + type_: type, + /, +) -> Field | core_defs.ScalarT | Tuple: + if isinstance(value, tuple): + return tuple(astype(v, type_) for v in value) + # default implementation for scalars, Fields are handled via dispatch + assert core_defs.is_scalar_type(value) + return core_defs.dtype(type_).scalar_type(value) UNARY_MATH_NUMBER_BUILTIN_NAMES = ["abs"] diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 605b83a5f0..95c9128f87 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: + return_type: ts.TupleType | ts.ScalarType | ts.FieldType value, new_type = node.args assert isinstance( - value.type, (ts.FieldType, ts.ScalarType) + value.type, (ts.FieldType, ts.ScalarType, ts.TupleType) ) # already checked using generic mechanism + if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [ kind.name for kind in ts.ScalarKind ]: @@ -835,8 +837,11 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", ) - return_type = with_altered_scalar_kind( - value.type, getattr(ts.ScalarKind, new_type.id.upper()) + return_type = type_info.apply_to_primitive_constituents( + value.type, + lambda primitive_type: with_altered_scalar_kind( + primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) + ), ) return foast.Call( diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 1902d71b3c..816b8581f1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -317,12 +317,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, dtype = node.args[0], node.args[1].id - - # TODO check that we test astype that results in a itir.map_ operation - return self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), - obj, + obj, new_type = node.args[0], node.args[1].id + return self._process_elements( + lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -403,6 +400,32 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + def _process_elements( + self, + process_func: Callable[[itir.Expr], itir.Expr], + obj: foast.Expr, + current_el_type: ts.TypeSpec, + current_el_expr: itir.Expr = im.ref("expr"), + ): + """Recursively applies a processing function to all primitive constituents of a tuple.""" + if isinstance(current_el_type, ts.TupleType): + # TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element. + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[ + self._process_elements( + process_func, + obj, + current_el_type.types[i], + im.tuple_get(i, current_el_expr), + ) + for i in range(len(current_el_type.types)) + ] + ) + elif type_info.contains_local_field(current_el_type): + raise NotImplementedError("Processing fields with local dimension is not implemented.") + else: + return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) + class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 9f67cb26da..e3fba87571 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -65,14 +65,29 @@ def convert_arg(arg: Any): return arg -def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[str, Any]): - program = apply_common_transforms( +def preprocess_program( + program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode +): + node = apply_common_transforms( program, - offset_provider=offset_provider, - lift_mode=LiftMode.FORCE_INLINE, common_subexpression_elimination=False, + lift_mode=lift_mode, + offset_provider=offset_provider, + unroll_reduce=False, ) - return program + # If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG. + # In this case, just retry with unrolled reductions. + if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): + fencil_definition = node + else: + fencil_definition = apply_common_transforms( + program, + common_subexpression_elimination=False, + lift_mode=lift_mode, + offset_provider=offset_provider, + unroll_reduce=True, + ) + return fencil_definition def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: @@ -156,11 +171,14 @@ def get_cache_id( def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: # build parameters auto_optimize = kwargs.get("auto_optimize", False) + build_cache = kwargs.get("build_cache", None) build_type = kwargs.get("build_type", "RelWithDebInfo") run_on_gpu = kwargs.get("run_on_gpu", False) - build_cache = kwargs.get("build_cache", None) # ITIR parameters column_axis = kwargs.get("column_axis", None) + lift_mode = ( + LiftMode.FORCE_INLINE + ) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] @@ -173,7 +191,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: sdfg = sdfg_program.sdfg else: # visit ITIR and generate SDFG - program = preprocess_program(program, offset_provider) + program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 5d47cad909..5b240ea2b7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -265,6 +265,29 @@ def builtin_neighbors( return [ValueExpr(result_access, iterator.dtype)] +def builtin_can_deref( + transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] +) -> list[ValueExpr]: + # first visit shift, to get set of indices for deref + can_deref_callable = node_args[0] + assert isinstance(can_deref_callable, itir.FunCall) + shift_callable = can_deref_callable.fun + assert isinstance(shift_callable, itir.FunCall) + assert isinstance(shift_callable.fun, itir.SymRef) + assert shift_callable.fun.id == "shift" + iterator = transformer._visit_shift(can_deref_callable) + + # create tasklet to check that field indices are non-negative (-1 is invalid) + args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions] + internals = [f"{arg.value.data}_v" for arg in args] + expr_code = " && ".join([f"{v} >= 0" for v in internals]) + + # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution + return transformer.add_expr_tasklet( + list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref" + ) + + def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -318,11 +341,12 @@ def builtin_undefined(*args: Any) -> Any: _GENERAL_BUILTIN_MAPPING: dict[ str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]] ] = { - "make_tuple": builtin_make_tuple, - "tuple_get": builtin_tuple_get, - "if_": builtin_if, + "can_deref": builtin_can_deref, "cast_": builtin_cast, + "if_": builtin_if, + "make_tuple": builtin_make_tuple, "neighbors": builtin_neighbors, + "tuple_get": builtin_tuple_get, } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 22154da9a7..e391727996 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -333,6 +333,76 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +@pytest.mark.uses_tuple_returns +def test_astype_on_tuples(cartesian_case): # noqa: F811 # fixtures + @gtx.field_operator + def field_op_returning_a_tuple( + a: cases.IFloatField, b: cases.IFloatField + ) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]: + tup = (a, b) + return tup + + @gtx.field_operator + def cast_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_asint: cases.IField, + b_asint: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype(field_op_returning_a_tuple(a, b), int32) + return ( + result[0] == a_asint, + result[1] == b_asint, + ) + + @gtx.field_operator + def cast_nested_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_asint: cases.IField, + b_asint: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype((a, field_op_returning_a_tuple(a, b)), int32) + return ( + result[0] == a_asint, + result[1][0] == a_asint, + result[1][1] == b_asint, + ) + + a = cases.allocate(cartesian_case, cast_tuple, "a")() + b = cases.allocate(cartesian_case, cast_tuple, "b")() + a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32)) + b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32)) + out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() + out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() + + cases.verify( + cartesian_case, + cast_tuple, + a, + b, + a_asint, + b_asint, + out=out_tuple, + ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), + ) + + cases.verify( + cartesian_case, + cast_nested_tuple, + a, + b, + a_asint, + b_asint, + out=out_nested_tuple, + ref=( + np.full_like(a, True, dtype=bool), + np.full_like(a, True, dtype=bool), + np.full_like(b, True, dtype=bool), + ), + ) + + def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 7800a30e41..dfa710e038 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]): def test_astype_wrong_value_type(): def simple_astype(a: Field[[TDim], float64]): - # we just use a tuple here but anything that is not a field or scalar works - return astype((1, 2), bool) + # we just use broadcast here but anything that is not a field, scalar or tuple thereof works + return astype(broadcast, bool) with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(simple_astype) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index d5d57c9024..2bcd0f8367 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -250,7 +250,6 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) -@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor