cleanup parts and tests
havogt committed Feb 1, 2024
1 parent c7b01eb commit d51cfca
Showing 5 changed files with 152 additions and 149 deletions.
112 changes: 57 additions & 55 deletions src/gt4py/next/embedded/nd_array_field.py
@@ -28,6 +28,7 @@
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.ffront import fbuiltins
+from gt4py.next.iterator import embedded as itir_embedded


try:
@@ -390,55 +391,21 @@ def inverse_image(

        assert common.UnitRange.is_finite(image_range)

-        # HANNES HACK
-        # map all negative indices (skip_values) to the smallest positiv index
-        smallest = xp.min(
-            self._ndarray, initial=xp.iinfo(xp.int32).max, where=self._ndarray >= 0
-        )
-        clipped_array = xp.where(self._ndarray < 0, smallest, self._ndarray)
-        # END HANNES HACK
-
-        restricted_mask = (clipped_array >= image_range.start) & (
+        restricted_mask = (self._ndarray >= image_range.start) & (
            self._ndarray < image_range.stop
        )
-        # indices of non-zero elements in each dimension
-        nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask)
-
-        new_dims = []
-        non_contiguous_dims = []
-
-        for i, dim_nnz_indices in enumerate(nnz):
-            # Check if the indices are contiguous
-            first_data_index = dim_nnz_indices[0]
-            assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES)
-            last_data_index = dim_nnz_indices[-1]
-            assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES)
-            indices, counts = xp.unique(dim_nnz_indices, return_counts=True)
-            dim_range = self._domain[i]
-
-            if len(xp.unique(counts)) == 1 and (
-                len(indices) == last_data_index - first_data_index + 1
-            ):
-                idx_offset = dim_range[1].start
-                start = idx_offset + first_data_index
-                assert common.is_int_index(start)
-                stop = idx_offset + last_data_index + 1
-                assert common.is_int_index(stop)
-                new_dims.append(
-                    common.named_range(
-                        (
-                            dim_range[0],
-                            (start, stop),
-                        )
-                    )
-                )
-            else:
-                non_contiguous_dims.append(dim_range[0])
-
-        if non_contiguous_dims:
-            raise ValueError(
-                f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'."
-            )
+        relative_ranges = _hypercube(
+            restricted_mask, xp, ignore_mask=self._ndarray == common.SKIP_VALUE
+        )
+        if relative_ranges is None:
+            raise ValueError("Restriction generates non-contiguous dimensions.")
+
+        new_dims = [
+            common.named_range((d, rr + ar.start))
+            for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges)
+        ]

        self._cache[cache_key] = new_dims
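
For intuition about the new code path: _hypercube returns index ranges relative to the underlying array, and "rr + ar.start" shifts each one by the start of the corresponding domain range. A minimal NumPy sketch of that shift, with invented values and domain start:

import numpy as np

# Invented 1-D connectivity values on a domain starting at 5, i.e. range [5, 9).
values = np.array([3, 10, 11, 42])
image_start, image_stop = 10, 12  # the image range being inverted

mask = (values >= image_start) & (values < image_stop)
(nz,) = np.nonzero(mask)
relative_range = (int(nz.min()), int(nz.max()) + 1)  # (1, 3), positions in the array

domain_start = 5
absolute_range = (relative_range[0] + domain_start, relative_range[1] + domain_start)
print(absolute_range)  # (6, 8), which is what "rr + ar.start" produces per dimension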

@@ -460,6 +427,30 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ
__getitem__ = restrict


+def _hypercube(
+    select: core_defs.NDArrayObject,
+    xp: ModuleType,
+    ignore_mask: Optional[core_defs.NDArrayObject] = None,
+) -> Optional[list[common.UnitRange]]:
+    """
+    Return the hypercube that contains all `True` values and no `False` values,
+    or `None` if no such hypercube exists.
+
+    If `ignore_mask` is given, the positions it selects are ignored, i.e. they
+    may be either `True` or `False`.
+    """
+    nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select)
+
+    slices = tuple(
+        slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz
+    )
+    hcube = select[tuple(slices)]
+    if ignore_mask is not None:
+        hcube |= ignore_mask[tuple(slices)]
+    if not xp.all(hcube):
+        return None
+
+    return [common.UnitRange(s.start, s.stop) for s in slices]
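
A standalone NumPy sketch of what _hypercube computes, with an invented mask: take the bounding box of the True entries, then accept it only if every entry inside is True or ignored:

import numpy as np

# Invented mask: the True entries exactly fill the box rows 1:3, cols 1:4.
select = np.array([
    [False, False, False, False],
    [False, True,  True,  True],
    [False, True,  True,  True],
])

nnz = np.nonzero(select)  # per-dimension indices of the True entries
slices = tuple(slice(int(i.min()), int(i.max()) + 1) for i in nnz)
print(slices)                      # (slice(1, 3, None), slice(1, 4, None))
print(bool(select[slices].all()))  # True, so the ranges are [(1, 3), (1, 4)]

select[1, 2] = False               # poke a hole: no hypercube anymore
print(bool(select[slices].all()))  # False, so _hypercube returns None

ignore = np.zeros_like(select)
ignore[1, 2] = True                # unless an ignore_mask marks the hole
print(bool((select | ignore)[slices].all()))  # True again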


# -- Specialized implementations for builtin operations on array fields --

NdArrayField.register_builtin_func(
@@ -491,7 +482,9 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


-def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[
+def _make_reduction(
+    builtin_name: str, array_builtin_name: str, initial_value_op: Callable
+) -> Callable[
    ...,
    NdArrayField[common.DimsT, core_defs.ScalarT],
]:
@@ -503,23 +496,26 @@ def _builtin_op(
        if axis not in field.domain.dims:
            raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.")
        reduce_dim_index = field.domain.dims.index(axis)
-        # HANNES HACK
        current_offset_provider = embedded_context.offset_provider.get(None)
        assert current_offset_provider is not None
        offset_definition = current_offset_provider[
            axis.value
        ]  # assumes offset and local dimension have same name
        # TODO mapping of connectivity to field dimensions
+        assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider)
        # TODO unclear in case of multiple local dimensions (probably just chain)
-        # END HANNES HACK
        new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis])

+        # TODO add test that requires broadcasting
+        masked_array = field.array_ns.where(
+            field.array_ns.asarray(offset_definition.table) != common.SKIP_VALUE,
+            field.ndarray,
+            initial_value_op(field),
+        )
+
        return field.__class__.from_array(
            getattr(field.array_ns, array_builtin_name)(
-                field.ndarray,
+                masked_array,
                axis=reduce_dim_index,
-                initial=0,  # set proper inital value
-                where=offset_definition.table >= 0,
            ),
            domain=new_domain,
        )
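
The new reduction path neutralizes skip-value neighbor slots before reducing: entries whose connectivity slot is unused are replaced by initial_value_op(field), the reduction's neutral element. A self-contained NumPy sketch, with invented connectivity and values and a local sentinel playing the role of common.SKIP_VALUE:

import numpy as np

SKIP = -1  # sentinel for an unused neighbor slot, standing in for common.SKIP_VALUE

# Invented vertex-to-edge connectivity: one row per vertex, one column per slot.
v2e = np.array([
    [0, 2, SKIP],   # vertex 0 has only two edge neighbors
    [1, 3, 4],      # vertex 1 has all three
])
neighbor_vals = np.array([
    [10.0, 20.0, 999.0],  # 999.0 sits in a skip slot and must not contribute
    [1.0, 2.0, 3.0],
])

masked = np.where(v2e != SKIP, neighbor_vals, 0.0)  # 0 is the neutral element of sum
print(masked.sum(axis=1))  # [30.  6.]; the skip slot no longer contributes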
@@ -528,9 +524,15 @@ def _builtin_op
return _builtin_op


-NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum"))
-NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max"))
-NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min"))
+NdArrayField.register_builtin_func(
+    fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum", lambda x: x.dtype.scalar_type(0))
+)
+NdArrayField.register_builtin_func(
+    fbuiltins.max_over, _make_reduction("max_over", "max", lambda x: x.array_ns.min(x._ndarray))
+)
+NdArrayField.register_builtin_func(
+    fbuiltins.min_over, _make_reduction("min_over", "min", lambda x: x.array_ns.max(x._ndarray))
+)
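
The neutral elements passed above deserve a note: max_over uses the field's own minimum (and min_over its maximum) rather than minus or plus infinity, which stays representable for integer dtypes. A toy check with invented values:

import numpy as np

vals = np.array([[7, 3, 9]])                    # an integer field; -inf is not representable
slot_is_real = np.array([[True, True, False]])  # the last slot is a skip value

neutral = vals.min()                            # what the max_over lambda computes
masked = np.where(slot_is_real, vals, neutral)
print(masked.max(axis=1))                       # [7]; the masked 9 can no longer win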


# -- Concrete array implementations --
1 change: 0 additions & 1 deletion tests/next_tests/definitions.py
@@ -171,7 +171,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
XFAIL,
UNSUPPORTED_MESSAGE,
), # we can't extract the field type from scan args
-    (USES_MESH_WITH_SKIP_VALUES, XFAIL, UNSUPPORTED_MESSAGE),
]
GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
# floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136
@@ -23,7 +23,7 @@
from gt4py.next.program_processors import processor_interface as ppi

from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
-    fieldview_backend,
+    exec_alloc_descriptor,
)
from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import (
assert_close,
@@ -70,11 +70,8 @@ def pnabla(
return compute_pnabla(pp, S_M[0], sign, vol), compute_pnabla(pp, S_M[1], sign, vol)


-def test_ffront_compute_zavgS(fieldview_backend):
-    allocator = fieldview_backend
-    fieldview_backend = (
-        fieldview_backend if hasattr(fieldview_backend, "executor") else None
-    )  # TODO this pattern is dangerous
+def test_ffront_compute_zavgS(exec_alloc_descriptor):
+    executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator

setup = nabla_setup()

@@ -87,26 +84,24 @@ def test_ffront_compute_zavgS(fieldview_backend):
atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False
)

-    compute_zavgS.with_backend(fieldview_backend)(
-        pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v}
-    )
+    compute_zavgS.with_backend(executor)(pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v})

assert_close(-199755464.25741270, np.min(zavgS.asnumpy()))
assert_close(388241977.58389181, np.max(zavgS.asnumpy()))
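
The tests now unpack the fixture into an executor and an allocator instead of probing with hasattr. A hypothetical sketch of the descriptor's shape, inferred only from how these tests use it (the class name and field types are assumptions, not the project's actual definition):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class ExecutionAndAllocatorDescriptor:
    executor: Optional[Any]  # None would select embedded execution
    allocator: Any           # passed to gtx.as_field / gtx.zeros

# Usage mirroring the tests:
# executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator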


-def test_ffront_nabla(fieldview_backend):
-    fieldview_backend = fieldview_backend if hasattr(fieldview_backend, "executor") else None
+def test_ffront_nabla(exec_alloc_descriptor):
+    executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator

setup = nabla_setup()

-    sign = gtx.as_field([Vertex, V2EDim], setup.sign_field)
-    pp = gtx.as_field([Vertex], setup.input_field)
-    S_M = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields))
-    vol = gtx.as_field([Vertex], setup.vol_field)
+    sign = gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator)
+    pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator)
+    S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields))
+    vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator)

-    pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size)))
-    pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size)))
+    pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator)
+    pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator)

e2v = gtx.NeighborTableOffsetProvider(
atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False
@@ -115,7 +110,7 @@ def test_ffront_nabla(fieldview_backend):
atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7
)

-    pnabla.with_backend(fieldview_backend)(
+    pnabla.with_backend(executor)(
pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e}
)
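
For reference, a toy offset provider with a skip value, mirroring the call shape used above, where the trailing positional argument is has_skip_values (the table below is invented, and the snippet assumes this test module's imports of gtx, Vertex, and Edge):

import numpy as np

# Vertex 1 has only one edge neighbor; -1 marks the unused slot.
toy_v2e_table = np.array([[0, 1], [2, -1]])
toy_v2e = gtx.NeighborTableOffsetProvider(toy_v2e_table, Vertex, Edge, 2, True)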


This file was deleted.


