diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 527197e0bc..ea88948841 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,6 +15,8 @@ from __future__ import annotations import dataclasses +import functools +import operator from collections.abc import Callable, Sequence from types import ModuleType from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar @@ -39,40 +41,38 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] -def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_unary_op(a: NdArrayField) -> common.Field: - xp = a.__class__.array_ns +def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]: + def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: + first = fields[0] + assert isinstance(first, NdArrayField) + xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) - new_data = op(a.ndarray) - return a.__class__.from_array(new_data, domain=a.domain) - - _builtin_unary_op.__name__ = builtin_name - return _builtin_unary_op - - -def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_binary_op(a: NdArrayField, b: common.Field) -> common.Field: - xp = a.__class__.array_ns - op = getattr(xp, array_builtin_name) - if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): - if not a.domain == b.domain: - domain_intersection = a.domain & b.domain - a_broadcasted = _broadcast(a, domain_intersection.dims) - b_broadcasted = _broadcast(b, domain_intersection.dims) - a_slices = _get_slices_from_domain_slice(a_broadcasted.domain, domain_intersection) - b_slices = _get_slices_from_domain_slice(b_broadcasted.domain, domain_intersection) - new_data = op(a_broadcasted.ndarray[a_slices], b_broadcasted.ndarray[b_slices]) - return a.__class__.from_array(new_data, domain=domain_intersection) - new_data = op(a.ndarray, xp.asarray(b.ndarray)) - else: - assert isinstance(b, core_defs.SCALAR_TYPES) - new_data = op(a.ndarray, b) - - return a.__class__.from_array(new_data, domain=a.domain) - - _builtin_binary_op.__name__ = builtin_name - return _builtin_binary_op + domain_intersection = functools.reduce( + operator.and_, + [f.domain for f in fields if common.is_field(f)], + common.Domain(dims=tuple(), ranges=tuple()), + ) + transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] + for f in fields: + if common.is_field(f): + if f.domain == domain_intersection: + transformed.append(xp.asarray(f.ndarray)) + else: + f_broadcasted = _broadcast(f, domain_intersection.dims) + f_slices = _get_slices_from_domain_slice( + f_broadcasted.domain, domain_intersection + ) + transformed.append(xp.asarray(f_broadcasted.ndarray[f_slices])) + else: + assert core_defs.is_scalar_type(f) + transformed.append(f) + + new_data = op(*transformed) + return first.__class__.from_array(new_data, domain=domain_intersection) + + _builtin_op.__name__ = builtin_name + return _builtin_op _Value: TypeAlias = common.Field | core_defs.ScalarT @@ -174,56 +174,50 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __call__ = None # type: ignore[assignment] # TODO: remap - __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") + __abs__ = _make_builtin("abs", "abs") - __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") + __neg__ = _make_builtin("neg", "negative") - __pos__ = _make_unary_array_field_intrinsic_func("pos", "positive") + __add__ = __radd__ = _make_builtin("add", "add") - __add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add") + __pos__ = _make_builtin("pos", "positive") - __sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract") + __sub__ = __rsub__ = _make_builtin("sub", "subtract") - __mul__ = __rmul__ = _make_binary_array_field_intrinsic_func("mul", "multiply") + __mul__ = __rmul__ = _make_builtin("mul", "multiply") - __truediv__ = __rtruediv__ = _make_binary_array_field_intrinsic_func("div", "divide") + __truediv__ = __rtruediv__ = _make_builtin("div", "divide") - __floordiv__ = __rfloordiv__ = _make_binary_array_field_intrinsic_func( - "floordiv", "floor_divide" - ) + __floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide") - __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") + __pow__ = _make_builtin("pow", "power") - __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") + __mod__ = __rmod__ = _make_builtin("mod", "mod") def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( - self, other - ) + return _make_builtin("logical_and", "logical_and")(self, other) raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") __rand__ = __and__ def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) + return _make_builtin("logical_or", "logical_or")(self, other) raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") __ror__ = __or__ def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( - self, other - ) + return _make_builtin("logical_xor", "logical_xor")(self, other) raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") __rxor__ = __xor__ def __invert__(self) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_unary_array_field_intrinsic_func("invert", "invert")(self) + return _make_builtin("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") def _slice( @@ -241,7 +235,7 @@ def _slice( return new_domain, slice_ -# -- Specialized implementations for intrinsic operations on array fields -- +# -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined] @@ -254,19 +248,18 @@ def _slice( ): if name in ["abs", "power", "gamma"]: continue - NdArrayField.register_builtin_func( - getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name) - ) + NdArrayField.register_builtin_func(getattr(fbuiltins, name), _make_builtin(name, name)) NdArrayField.register_builtin_func( - fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined] + fbuiltins.minimum, _make_builtin("minimum", "minimum") # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined] + fbuiltins.maximum, _make_builtin("maximum", "maximum") # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] + fbuiltins.fmod, _make_builtin("fmod", "fmod") # type: ignore[attr-defined] ) +NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) def _np_cp_setitem( diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 52aae34b3f..13c21eb516 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -132,6 +132,27 @@ def builtin_function(fun: Callable[_P, _R]) -> BuiltInFunction[_R, _P]: return BuiltInFunction(fun) +MaskT = TypeVar("MaskT", bound=Field) +FieldT = TypeVar("FieldT", bound=Union[Field, gt4py_defs.Scalar, Tuple]) + + +class WhereBuiltinFunction( + BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT] +): + def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: + if isinstance(true_field, tuple) or isinstance(false_field, tuple): + if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): + raise ValueError( + f"Either both or none can be tuple in {true_field=} and {false_field=}." # TODO(havogt) find a strategy to unify parsing and embedded error messages + ) + if len(true_field) != len(false_field): + raise ValueError( + "Tuple of different size not allowed." + ) # TODO(havogt) find a strategy to unify parsing and embedded error messages + return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return super().__call__(mask, true_field, false_field) + + @builtin_function def neighbor_sum( field: Field, @@ -164,7 +185,7 @@ def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple[Dimension, ...], /) raise NotImplementedError() -@builtin_function +@WhereBuiltinFunction def where( mask: Field, true_field: Field | gt4py_defs.ScalarT | Tuple, diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 8a4b4cbd84..49aeece87e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -98,6 +98,60 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati assert np.allclose(result.ndarray, expected) +def test_where_builtin(nd_array_implementation): + cond = np.asarray([True, False]) + true_ = np.asarray([1.0, 2.0], dtype=np.float32) + false_ = np.asarray([3.0, 4.0], dtype=np.float32) + + field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]] + expected = np.where(cond, true_, false_) + + result = fbuiltins.where(*field_inputs) + assert np.allclose(result.ndarray, expected) + + +def test_where_builtin_different_domain(nd_array_implementation): + cond = np.asarray([True, False]) + true_ = np.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + false_ = np.asarray([7.0, 8.0, 9.0, 10.0], dtype=np.float32) + + cond_field = common.field( + nd_array_implementation.asarray(cond), domain=common.domain({JDim: 2}) + ) + true_field = common.field( + nd_array_implementation.asarray(true_), + domain=common.domain({IDim: common.UnitRange(0, 2), JDim: common.UnitRange(-1, 2)}), + ) + false_field = common.field( + nd_array_implementation.asarray(false_), + domain=common.domain({JDim: common.UnitRange(-1, 3)}), + ) + + expected = np.where(cond[np.newaxis, :], true_[:, 1:], false_[np.newaxis, 1:-1]) + + result = fbuiltins.where(cond_field, true_field, false_field) + assert np.allclose(result.ndarray, expected) + + +def test_where_builtin_with_tuple(nd_array_implementation): + cond = np.asarray([True, False]) + true0 = np.asarray([1.0, 2.0], dtype=np.float32) + false0 = np.asarray([3.0, 4.0], dtype=np.float32) + true1 = np.asarray([11.0, 12.0], dtype=np.float32) + false1 = np.asarray([13.0, 14.0], dtype=np.float32) + + expected0 = np.where(cond, true0, false0) + expected1 = np.where(cond, true1, false1) + + cond_field = _make_field(cond, nd_array_implementation, dtype=bool) + field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1]) + field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1]) + + result = fbuiltins.where(cond_field, field_true, field_false) + assert np.allclose(result[0].ndarray, expected0) + assert np.allclose(result[1].ndarray, expected1) + + def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0]