Skip to content

Commit

Permalink
Merge branch 'gt4py-workshop' of https://github.com/gridtools/gt4py i…
Browse files Browse the repository at this point in the history
…nto gt4py-workshop
  • Loading branch information
nfarabullini committed Nov 20, 2023
2 parents 5b1c813 + 1fce1d0 commit 4de1f7e
Show file tree
Hide file tree
Showing 28 changed files with 585 additions and 191 deletions.
5 changes: 4 additions & 1 deletion .gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ tasks:
pip install --upgrade pip setuptools wheel
pip install -e .
pip install -r requirements-dev.txt
pip install -i https://test.pypi.org/simple/ atlas4py
pre-commit install --install-hooks
sed 's%# start templated%# start templated\nsource /workspace/gt4py/.venv/bin/activate%' /workspace/gt4py/.git/hooks/pre-commit -i
pip install jupyterlab
deactivate
command: |
source .venv/bin/activate
wget https://raw.githubusercontent.com/mwouts/jupytext/main/binder/labconfig/default_setting_overrides.json -P ~/.jupyter/labconfig/
jupyter lab --notebook-dir=/workspace/gt4py/docs/functional --LabApp.token='' --NotebookApp.allow_origin=* --LabApp.default_url='lab/tree/Workshop.md'
jupyter lab --notebook-dir=/workspace/gt4py/docs/user/next --LabApp.token='' --ServerApp.allow_remote_access=True --NotebookApp.allow_origin=* --LabApp.default_url='lab/tree/Workshop.md'
env:
PIP_SRC: _external_src
PRE_COMMIT_HOME: /workspace/.caches/pre-commit
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ markers = [
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_if_stmts: tests that require backend support for if-statements',
Expand All @@ -344,7 +343,11 @@ markers = [
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
'uses_tuple_args: tests that require backend support for tuple arguments',
'uses_tuple_returns: tests that require backend support for tuple results',
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields'
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_scan: tests that uses scan',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
testpaths = 'tests'
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def shape(self) -> tuple[int, ...]:
def dtype(self) -> Any:
...

def astype(self, dtype: npt.DTypeLike) -> NDArrayObject:
...

def __getitem__(self, item: Any) -> NDArrayObject:
...

Expand Down
20 changes: 19 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,24 @@ def __getitem__(self, index: int | slice) -> int | UnitRange:
else:
raise IndexError("UnitRange index out of range")

def __and__(self, other: Set[Any]) -> UnitRange:
def __and__(self, other: Set[int]) -> UnitRange:
if isinstance(other, UnitRange):
start = max(self.start, other.start)
stop = min(self.stop, other.stop)
return UnitRange(start, stop)
else:
raise NotImplementedError("Can only find the intersection between UnitRange instances.")

def __le__(self, other: Set[int]):
if isinstance(other, UnitRange):
return self.start >= other.start and self.stop <= other.stop
elif len(self) == Infinity.positive():
return False
else:
return Set.__le__(self, other)

__ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented

def __str__(self) -> str:
return f"({self.start}:{self.stop})"

Expand Down Expand Up @@ -486,6 +496,14 @@ def __neg__(self) -> Field:
def __invert__(self) -> Field:
"""Only defined for `Field` of value type `bool`."""

@abc.abstractmethod
def __eq__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool`
...

@abc.abstractmethod
def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool`
...

@abc.abstractmethod
def __add__(self, other: Field | core_defs.ScalarT) -> Field:
...
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def empty(
(3, 3)
"""
dtype = core_defs.dtype(dtype)
if allocator is None and device is None:
device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0)
buffer = next_allocators.allocate(
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
Expand Down
141 changes: 71 additions & 70 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from __future__ import annotations

import dataclasses
import functools
import operator
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar
Expand All @@ -39,40 +41,38 @@
jnp: Optional[ModuleType] = None # type:ignore[no-redef]


def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable:
def _builtin_unary_op(a: NdArrayField) -> common.Field:
xp = a.__class__.array_ns
def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
first = fields[0]
assert isinstance(first, NdArrayField)
xp = first.__class__.array_ns
op = getattr(xp, array_builtin_name)
new_data = op(a.ndarray)

return a.__class__.from_array(new_data, domain=a.domain)

_builtin_unary_op.__name__ = builtin_name
return _builtin_unary_op


def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable:
def _builtin_binary_op(a: NdArrayField, b: common.Field) -> common.Field:
xp = a.__class__.array_ns
op = getattr(xp, array_builtin_name)
if hasattr(b, "__gt_builtin_func__"): # common.is_field(b):
if not a.domain == b.domain:
domain_intersection = a.domain & b.domain
a_broadcasted = _broadcast(a, domain_intersection.dims)
b_broadcasted = _broadcast(b, domain_intersection.dims)
a_slices = _get_slices_from_domain_slice(a_broadcasted.domain, domain_intersection)
b_slices = _get_slices_from_domain_slice(b_broadcasted.domain, domain_intersection)
new_data = op(a_broadcasted.ndarray[a_slices], b_broadcasted.ndarray[b_slices])
return a.__class__.from_array(new_data, domain=domain_intersection)
new_data = op(a.ndarray, xp.asarray(b.ndarray))
else:
assert isinstance(b, core_defs.SCALAR_TYPES)
new_data = op(a.ndarray, b)

return a.__class__.from_array(new_data, domain=a.domain)

_builtin_binary_op.__name__ = builtin_name
return _builtin_binary_op
domain_intersection = functools.reduce(
operator.and_,
[f.domain for f in fields if common.is_field(f)],
common.Domain(dims=tuple(), ranges=tuple()),
)
transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = []
for f in fields:
if common.is_field(f):
if f.domain == domain_intersection:
transformed.append(xp.asarray(f.ndarray))
else:
f_broadcasted = _broadcast(f, domain_intersection.dims)
f_slices = _get_slices_from_domain_slice(
f_broadcasted.domain, domain_intersection
)
transformed.append(xp.asarray(f_broadcasted.ndarray[f_slices]))
else:
assert core_defs.is_scalar_type(f)
transformed.append(f)

new_data = op(*transformed)
return first.__class__.from_array(new_data, domain=domain_intersection)

_builtin_op.__name__ = builtin_name
return _builtin_op


_Value: TypeAlias = common.Field | core_defs.ScalarT
Expand Down Expand Up @@ -135,25 +135,22 @@ def from_array(
/,
*,
domain: common.DomainLike,
dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike
dtype: Optional[core_defs.DTypeLike] = None,
) -> NdArrayField:
domain = common.domain(domain)
xp = cls.array_ns

xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type)
xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type)
array = xp.asarray(data, dtype=xp_dtype)

if dtype_like is not None:
assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type
if dtype is not None:
assert array.dtype.type == core_defs.dtype(dtype).scalar_type

assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES)

assert all(isinstance(d, common.Dimension) for d in domain.dims), domain
assert len(domain) == array.ndim
assert all(
len(r) == s or (s == 1 and r == common.UnitRange.infinity())
for r, s in zip(domain.ranges, array.shape)
)
assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape))

return cls(domain, array)

Expand All @@ -174,56 +171,54 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala

__call__ = None # type: ignore[assignment] # TODO: remap

__abs__ = _make_unary_array_field_intrinsic_func("abs", "abs")
__abs__ = _make_builtin("abs", "abs")

__neg__ = _make_unary_array_field_intrinsic_func("neg", "negative")
__neg__ = _make_builtin("neg", "negative")

__pos__ = _make_unary_array_field_intrinsic_func("pos", "positive")
__add__ = __radd__ = _make_builtin("add", "add")

__add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add")
__pos__ = _make_builtin("pos", "positive")

__sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract")
__sub__ = __rsub__ = _make_builtin("sub", "subtract")

__mul__ = __rmul__ = _make_binary_array_field_intrinsic_func("mul", "multiply")
__mul__ = __rmul__ = _make_builtin("mul", "multiply")

__truediv__ = __rtruediv__ = _make_binary_array_field_intrinsic_func("div", "divide")
__truediv__ = __rtruediv__ = _make_builtin("div", "divide")

__floordiv__ = __rfloordiv__ = _make_binary_array_field_intrinsic_func(
"floordiv", "floor_divide"
)
__floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide")

__pow__ = _make_binary_array_field_intrinsic_func("pow", "power")
__pow__ = _make_builtin("pow", "power")

__mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod")
__mod__ = __rmod__ = _make_builtin("mod", "mod")

__ne__ = _make_builtin("not_equal", "not_equal") # type: ignore[assignment] # mypy wants return `bool`

__eq__ = _make_builtin("equal", "equal") # type: ignore[assignment] # mypy wants return `bool`

def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")(
self, other
)
return _make_builtin("logical_and", "logical_and")(self, other)
raise NotImplementedError("`__and__` not implemented for non-`bool` fields.")

__rand__ = __and__

def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other)
return _make_builtin("logical_or", "logical_or")(self, other)
raise NotImplementedError("`__or__` not implemented for non-`bool` fields.")

__ror__ = __or__

def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")(
self, other
)
return _make_builtin("logical_xor", "logical_xor")(self, other)
raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.")

__rxor__ = __xor__

def __invert__(self) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_unary_array_field_intrinsic_func("invert", "invert")(self)
return _make_builtin("invert", "invert")(self)
raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.")

def _slice(
Expand All @@ -241,7 +236,7 @@ def _slice(
return new_domain, slice_


# -- Specialized implementations for intrinsic operations on array fields --
# -- Specialized implementations for builtin operations on array fields --

NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined]
NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined]
Expand All @@ -254,19 +249,18 @@ def _slice(
):
if name in ["abs", "power", "gamma"]:
continue
NdArrayField.register_builtin_func(
getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name)
)
NdArrayField.register_builtin_func(getattr(fbuiltins, name), _make_builtin(name, name))

NdArrayField.register_builtin_func(
fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined]
fbuiltins.minimum, _make_builtin("minimum", "minimum") # type: ignore[attr-defined]
)
NdArrayField.register_builtin_func(
fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined]
fbuiltins.maximum, _make_builtin("maximum", "maximum") # type: ignore[attr-defined]
)
NdArrayField.register_builtin_func(
fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined]
fbuiltins.fmod, _make_builtin("fmod", "fmod") # type: ignore[attr-defined]
)
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _np_cp_setitem(
Expand All @@ -292,7 +286,7 @@ def _np_cp_setitem(
_nd_array_implementations = [np]


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np

Expand All @@ -305,7 +299,7 @@ class NumPyArrayField(NdArrayField):
if cp:
_nd_array_implementations.append(cp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

Expand All @@ -317,7 +311,7 @@ class CuPyArrayField(NdArrayField):
if jnp:
_nd_array_implementations.append(jnp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class JaxArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = jnp

Expand Down Expand Up @@ -358,6 +352,13 @@ def _builtins_broadcast(
NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast)


def _astype(field: NdArrayField, type_: type) -> NdArrayField:
return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain)


NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field


def _get_slices_from_domain_slice(
domain: common.Domain,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any],
Expand Down
Loading

0 comments on commit 4de1f7e

Please sign in to comment.