Skip to content

Commit

Permalink
feat[next] Enable embedded field view in ffront_tests (#1361)
Browse files Browse the repository at this point in the history
Enables field view in ffront_tests

New exclusion markers for some cases
- cartesian and unstructured shifts
- scan
- check for a very concrete error message in parsing: we should match this later in embedded

Adds the following features to embedded:
- support for scalar broadcast, astype, binary functions
- adds `__ne__` and `__eq__` to Field 

TODOs:
- full comparison operators for UnitRange
- full comparison operators for Fields
  • Loading branch information
havogt authored Nov 17, 2023
1 parent 39d1c09 commit ecd0b68
Show file tree
Hide file tree
Showing 21 changed files with 293 additions and 112 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,11 @@ markers = [
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
'uses_tuple_args: tests that require backend support for tuple arguments',
'uses_tuple_returns: tests that require backend support for tuple results',
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields'
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_scan: tests that uses scan',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
testpaths = 'tests'
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def shape(self) -> tuple[int, ...]:
def dtype(self) -> Any:
...

def astype(self, dtype: npt.DTypeLike) -> NDArrayObject:
...

def __getitem__(self, item: Any) -> NDArrayObject:
...

Expand Down
20 changes: 19 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,24 @@ def __getitem__(self, index: int | slice) -> int | UnitRange:
else:
raise IndexError("UnitRange index out of range")

def __and__(self, other: Set[Any]) -> UnitRange:
def __and__(self, other: Set[int]) -> UnitRange:
if isinstance(other, UnitRange):
start = max(self.start, other.start)
stop = min(self.stop, other.stop)
return UnitRange(start, stop)
else:
raise NotImplementedError("Can only find the intersection between UnitRange instances.")

def __le__(self, other: Set[int]):
if isinstance(other, UnitRange):
return self.start >= other.start and self.stop <= other.stop
elif len(self) == Infinity.positive():
return False
else:
return Set.__le__(self, other)

__ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented

def __str__(self) -> str:
return f"({self.start}:{self.stop})"

Expand Down Expand Up @@ -486,6 +496,14 @@ def __neg__(self) -> Field:
def __invert__(self) -> Field:
"""Only defined for `Field` of value type `bool`."""

@abc.abstractmethod
def __eq__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool`
...

@abc.abstractmethod
def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool`
...

@abc.abstractmethod
def __add__(self, other: Field | core_defs.ScalarT) -> Field:
...
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def empty(
(3, 3)
"""
dtype = core_defs.dtype(dtype)
if allocator is None and device is None:
device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0)
buffer = next_allocators.allocate(
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
Expand Down
30 changes: 19 additions & 11 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,22 @@ def from_array(
/,
*,
domain: common.DomainLike,
dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike
dtype: Optional[core_defs.DTypeLike] = None,
) -> NdArrayField:
domain = common.domain(domain)
xp = cls.array_ns

xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type)
xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type)
array = xp.asarray(data, dtype=xp_dtype)

if dtype_like is not None:
assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type
if dtype is not None:
assert array.dtype.type == core_defs.dtype(dtype).scalar_type

assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES)

assert all(isinstance(d, common.Dimension) for d in domain.dims), domain
assert len(domain) == array.ndim
assert all(
len(r) == s or (s == 1 and r == common.UnitRange.infinity())
for r, s in zip(domain.ranges, array.shape)
)
assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape))

return cls(domain, array)

Expand Down Expand Up @@ -194,6 +191,10 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala

__mod__ = __rmod__ = _make_builtin("mod", "mod")

__ne__ = _make_builtin("not_equal", "not_equal") # type: ignore[assignment] # mypy wants return `bool`

__eq__ = _make_builtin("equal", "equal") # type: ignore[assignment] # mypy wants return `bool`

def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_builtin("logical_and", "logical_and")(self, other)
Expand Down Expand Up @@ -285,7 +286,7 @@ def _np_cp_setitem(
_nd_array_implementations = [np]


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np

Expand All @@ -298,7 +299,7 @@ class NumPyArrayField(NdArrayField):
if cp:
_nd_array_implementations.append(cp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

Expand All @@ -310,7 +311,7 @@ class CuPyArrayField(NdArrayField):
if jnp:
_nd_array_implementations.append(jnp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class JaxArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = jnp

Expand Down Expand Up @@ -351,6 +352,13 @@ def _builtins_broadcast(
NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast)


def _astype(field: NdArrayField, type_: type) -> NdArrayField:
return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain)


NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field


def _get_slices_from_domain_slice(
domain: common.Domain,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any],
Expand Down
55 changes: 36 additions & 19 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.next import allocators as next_allocators
from gt4py.next import allocators as next_allocators, common
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -171,14 +171,14 @@ class Program:
past_node: past.Program
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = None,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
grid_type: Optional[GridType] = None,
) -> Program:
source_def = SourceDefinition.from_function(definition)
Expand Down Expand Up @@ -282,27 +282,23 @@ def itir(self) -> itir.FencilDefinition:
)

def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None:
if (
self.backend is None and DEFAULT_BACKEND is None
): # TODO(havogt): for now enable embedded execution by setting DEFAULT_BACKEND to None
self.definition(*args, **kwargs)
return

rewritten_args, size_args, kwargs = self._process_args(args, kwargs)

if not self.backend:
if self.backend is None:
warnings.warn(
UserWarning(
f"Field View Program '{self.itir.id}': Using default ({DEFAULT_BACKEND}) backend."
f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend."
)
)
backend = self.backend or DEFAULT_BACKEND

ppi.ensure_processor_kind(backend, ppi.ProgramExecutor)
self.definition(*rewritten_args, **kwargs)
return

ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor)
if "debug" in kwargs:
debug(self.itir)

backend(
self.backend(
self.itir,
*rewritten_args,
*size_args,
Expand Down Expand Up @@ -547,14 +543,14 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
foast_node: OperatorNodeT
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = None,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
grid_type: Optional[GridType] = None,
*,
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator,
Expand Down Expand Up @@ -687,9 +683,9 @@ def __call__(
# if we are reaching this from a program call.
if "out" in kwargs:
out = kwargs.pop("out")
if "offset_provider" in kwargs:
offset_provider = kwargs.pop("offset_provider", None)
if self.backend is not None:
# "out" and "offset_provider" -> field_operator as program
offset_provider = kwargs.pop("offset_provider")
args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
# TODO(tehrengruber): check all offset providers are given
# deduce argument types
Expand All @@ -705,13 +701,34 @@ def __call__(
)
else:
# "out" -> field_operator called from program in embedded execution
out.ndarray[:] = self.definition(*args, **kwargs).ndarray[:]
# TODO(egparedes): put offset_provider in ctxt var here when implementing remap
domain = kwargs.pop("domain", None)
res = self.definition(*args, **kwargs)
_tuple_assign_field(
out, res, domain=None if domain is None else common.domain(domain)
)
return
else:
# field_operator called from other field_operator in embedded execution
assert self.backend is None
return self.definition(*args, **kwargs)


def _tuple_assign_field(
target: tuple[common.Field | tuple, ...] | common.Field,
source: tuple[common.Field | tuple, ...] | common.Field,
domain: Optional[common.Domain],
):
if isinstance(target, tuple):
if not isinstance(source, tuple):
raise RuntimeError(f"Cannot assign {source} to {target}.")
for t, s in zip(target, source):
_tuple_assign_field(t, s, domain)
else:
domain = domain or target.domain
target[domain] = source[domain]


@typing.overload
def field_operator(
definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor]
Expand Down
Loading

0 comments on commit ecd0b68

Please sign in to comment.