Skip to content

Commit

Permalink
Merge branch 'GridTools:main' into dace-gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao authored Nov 16, 2023
2 parents 86b5d8f + b8cda74 commit 0ba4fd5
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 60 deletions.
111 changes: 52 additions & 59 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from __future__ import annotations

import dataclasses
import functools
import operator
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar
Expand All @@ -39,40 +41,38 @@
jnp: Optional[ModuleType] = None # type:ignore[no-redef]


def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable:
def _builtin_unary_op(a: NdArrayField) -> common.Field:
xp = a.__class__.array_ns
def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
first = fields[0]
assert isinstance(first, NdArrayField)
xp = first.__class__.array_ns
op = getattr(xp, array_builtin_name)
new_data = op(a.ndarray)

return a.__class__.from_array(new_data, domain=a.domain)

_builtin_unary_op.__name__ = builtin_name
return _builtin_unary_op


def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable:
def _builtin_binary_op(a: NdArrayField, b: common.Field) -> common.Field:
xp = a.__class__.array_ns
op = getattr(xp, array_builtin_name)
if hasattr(b, "__gt_builtin_func__"): # common.is_field(b):
if not a.domain == b.domain:
domain_intersection = a.domain & b.domain
a_broadcasted = _broadcast(a, domain_intersection.dims)
b_broadcasted = _broadcast(b, domain_intersection.dims)
a_slices = _get_slices_from_domain_slice(a_broadcasted.domain, domain_intersection)
b_slices = _get_slices_from_domain_slice(b_broadcasted.domain, domain_intersection)
new_data = op(a_broadcasted.ndarray[a_slices], b_broadcasted.ndarray[b_slices])
return a.__class__.from_array(new_data, domain=domain_intersection)
new_data = op(a.ndarray, xp.asarray(b.ndarray))
else:
assert isinstance(b, core_defs.SCALAR_TYPES)
new_data = op(a.ndarray, b)

return a.__class__.from_array(new_data, domain=a.domain)

_builtin_binary_op.__name__ = builtin_name
return _builtin_binary_op
domain_intersection = functools.reduce(
operator.and_,
[f.domain for f in fields if common.is_field(f)],
common.Domain(dims=tuple(), ranges=tuple()),
)
transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = []
for f in fields:
if common.is_field(f):
if f.domain == domain_intersection:
transformed.append(xp.asarray(f.ndarray))
else:
f_broadcasted = _broadcast(f, domain_intersection.dims)
f_slices = _get_slices_from_domain_slice(
f_broadcasted.domain, domain_intersection
)
transformed.append(xp.asarray(f_broadcasted.ndarray[f_slices]))
else:
assert core_defs.is_scalar_type(f)
transformed.append(f)

new_data = op(*transformed)
return first.__class__.from_array(new_data, domain=domain_intersection)

_builtin_op.__name__ = builtin_name
return _builtin_op


_Value: TypeAlias = common.Field | core_defs.ScalarT
Expand Down Expand Up @@ -174,56 +174,50 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala

__call__ = None # type: ignore[assignment] # TODO: remap

__abs__ = _make_unary_array_field_intrinsic_func("abs", "abs")
__abs__ = _make_builtin("abs", "abs")

__neg__ = _make_unary_array_field_intrinsic_func("neg", "negative")
__neg__ = _make_builtin("neg", "negative")

__pos__ = _make_unary_array_field_intrinsic_func("pos", "positive")
__add__ = __radd__ = _make_builtin("add", "add")

__add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add")
__pos__ = _make_builtin("pos", "positive")

__sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract")
__sub__ = __rsub__ = _make_builtin("sub", "subtract")

__mul__ = __rmul__ = _make_binary_array_field_intrinsic_func("mul", "multiply")
__mul__ = __rmul__ = _make_builtin("mul", "multiply")

__truediv__ = __rtruediv__ = _make_binary_array_field_intrinsic_func("div", "divide")
__truediv__ = __rtruediv__ = _make_builtin("div", "divide")

__floordiv__ = __rfloordiv__ = _make_binary_array_field_intrinsic_func(
"floordiv", "floor_divide"
)
__floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide")

__pow__ = _make_binary_array_field_intrinsic_func("pow", "power")
__pow__ = _make_builtin("pow", "power")

__mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod")
__mod__ = __rmod__ = _make_builtin("mod", "mod")

def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")(
self, other
)
return _make_builtin("logical_and", "logical_and")(self, other)
raise NotImplementedError("`__and__` not implemented for non-`bool` fields.")

__rand__ = __and__

def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other)
return _make_builtin("logical_or", "logical_or")(self, other)
raise NotImplementedError("`__or__` not implemented for non-`bool` fields.")

__ror__ = __or__

def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")(
self, other
)
return _make_builtin("logical_xor", "logical_xor")(self, other)
raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.")

__rxor__ = __xor__

def __invert__(self) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_unary_array_field_intrinsic_func("invert", "invert")(self)
return _make_builtin("invert", "invert")(self)
raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.")

def _slice(
Expand All @@ -241,7 +235,7 @@ def _slice(
return new_domain, slice_


# -- Specialized implementations for intrinsic operations on array fields --
# -- Specialized implementations for builtin operations on array fields --

NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined]
NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined]
Expand All @@ -254,19 +248,18 @@ def _slice(
):
if name in ["abs", "power", "gamma"]:
continue
NdArrayField.register_builtin_func(
getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name)
)
NdArrayField.register_builtin_func(getattr(fbuiltins, name), _make_builtin(name, name))

NdArrayField.register_builtin_func(
fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined]
fbuiltins.minimum, _make_builtin("minimum", "minimum") # type: ignore[attr-defined]
)
NdArrayField.register_builtin_func(
fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined]
fbuiltins.maximum, _make_builtin("maximum", "maximum") # type: ignore[attr-defined]
)
NdArrayField.register_builtin_func(
fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined]
fbuiltins.fmod, _make_builtin("fmod", "fmod") # type: ignore[attr-defined]
)
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _np_cp_setitem(
Expand Down
23 changes: 22 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,27 @@ def builtin_function(fun: Callable[_P, _R]) -> BuiltInFunction[_R, _P]:
return BuiltInFunction(fun)


MaskT = TypeVar("MaskT", bound=Field)
FieldT = TypeVar("FieldT", bound=Union[Field, gt4py_defs.Scalar, Tuple])


class WhereBuiltinFunction(
BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT]
):
def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R:
if isinstance(true_field, tuple) or isinstance(false_field, tuple):
if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)):
raise ValueError(
f"Either both or none can be tuple in {true_field=} and {false_field=}." # TODO(havogt) find a strategy to unify parsing and embedded error messages
)
if len(true_field) != len(false_field):
raise ValueError(
"Tuple of different size not allowed."
) # TODO(havogt) find a strategy to unify parsing and embedded error messages
return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R`
return super().__call__(mask, true_field, false_field)


@builtin_function
def neighbor_sum(
field: Field,
Expand Down Expand Up @@ -164,7 +185,7 @@ def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple[Dimension, ...], /)
raise NotImplementedError()


@builtin_function
@WhereBuiltinFunction
def where(
mask: Field,
true_field: Field | gt4py_defs.ScalarT | Tuple,
Expand Down
54 changes: 54 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,60 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati
assert np.allclose(result.ndarray, expected)


def test_where_builtin(nd_array_implementation):
cond = np.asarray([True, False])
true_ = np.asarray([1.0, 2.0], dtype=np.float32)
false_ = np.asarray([3.0, 4.0], dtype=np.float32)

field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]]
expected = np.where(cond, true_, false_)

result = fbuiltins.where(*field_inputs)
assert np.allclose(result.ndarray, expected)


def test_where_builtin_different_domain(nd_array_implementation):
cond = np.asarray([True, False])
true_ = np.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
false_ = np.asarray([7.0, 8.0, 9.0, 10.0], dtype=np.float32)

cond_field = common.field(
nd_array_implementation.asarray(cond), domain=common.domain({JDim: 2})
)
true_field = common.field(
nd_array_implementation.asarray(true_),
domain=common.domain({IDim: common.UnitRange(0, 2), JDim: common.UnitRange(-1, 2)}),
)
false_field = common.field(
nd_array_implementation.asarray(false_),
domain=common.domain({JDim: common.UnitRange(-1, 3)}),
)

expected = np.where(cond[np.newaxis, :], true_[:, 1:], false_[np.newaxis, 1:-1])

result = fbuiltins.where(cond_field, true_field, false_field)
assert np.allclose(result.ndarray, expected)


def test_where_builtin_with_tuple(nd_array_implementation):
cond = np.asarray([True, False])
true0 = np.asarray([1.0, 2.0], dtype=np.float32)
false0 = np.asarray([3.0, 4.0], dtype=np.float32)
true1 = np.asarray([11.0, 12.0], dtype=np.float32)
false1 = np.asarray([13.0, 14.0], dtype=np.float32)

expected0 = np.where(cond, true0, false0)
expected1 = np.where(cond, true1, false1)

cond_field = _make_field(cond, nd_array_implementation, dtype=bool)
field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1])
field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1])

result = fbuiltins.where(cond_field, field_true, field_false)
assert np.allclose(result[0].ndarray, expected0)
assert np.allclose(result[1].ndarray, expected1)


def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation):
inp_a = [-1.0, 4.2, 42]
inp_b = [2.0, 3.0, -3.0]
Expand Down

0 comments on commit 0ba4fd5

Please sign in to comment.