From d51cfcac5c8579c32d0018b430ba2e2ee365c636 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 1 Feb 2024 15:04:42 +0100 Subject: [PATCH] cleanup parts and tests --- src/gt4py/next/embedded/nd_array_field.py | 112 +++++++++--------- tests/next_tests/definitions.py | 1 - .../ffront_tests/test_ffront_fvm_nabla.py | 31 ++--- .../ffront_tests/test_skip_value.py | 73 ------------ .../embedded_tests/test_nd_array_field.py | 84 ++++++++++++- 5 files changed, 152 insertions(+), 149 deletions(-) delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 508cfde9f9..fef2577d0a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -28,6 +28,7 @@ from gt4py.next import common from gt4py.next.embedded import common as embedded_common, context as embedded_context from gt4py.next.ffront import fbuiltins +from gt4py.next.iterator import embedded as itir_embedded try: @@ -390,55 +391,21 @@ def inverse_image( assert common.UnitRange.is_finite(image_range) - # HANNES HACK - # map all negative indices (skip_values) to the smallest positiv index - smallest = xp.min( - self._ndarray, initial=xp.iinfo(xp.int32).max, where=self._ndarray >= 0 + restricted_mask = (self._ndarray >= image_range.start) & ( + self._ndarray < image_range.stop ) - clipped_array = xp.where(self._ndarray < 0, smallest, self._ndarray) - # END HANNES HACK - restricted_mask = (clipped_array >= image_range.start) & ( - self._ndarray < image_range.stop + relative_ranges = _hypercube( + restricted_mask, xp, ignore_mask=self._ndarray == common.SKIP_VALUE ) - # indices of non-zero elements in each dimension - nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask) - - new_dims = [] - non_contiguous_dims = [] - - for i, dim_nnz_indices in enumerate(nnz): - # Check if the indices are contiguous - first_data_index = dim_nnz_indices[0] - assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES) - last_data_index = dim_nnz_indices[-1] - assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES) - indices, counts = xp.unique(dim_nnz_indices, return_counts=True) - dim_range = self._domain[i] - - if len(xp.unique(counts)) == 1 and ( - len(indices) == last_data_index - first_data_index + 1 - ): - idx_offset = dim_range[1].start - start = idx_offset + first_data_index - assert common.is_int_index(start) - stop = idx_offset + last_data_index + 1 - assert common.is_int_index(stop) - new_dims.append( - common.named_range( - ( - dim_range[0], - (start, stop), - ) - ) - ) - else: - non_contiguous_dims.append(dim_range[0]) - if non_contiguous_dims: - raise ValueError( - f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'." - ) + if relative_ranges is None: + raise ValueError("Restriction generates non-contiguous dimensions.") + + new_dims = [ + common.named_range((d, rr + ar.start)) + for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges) + ] self._cache[cache_key] = new_dims @@ -460,6 +427,30 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ __getitem__ = restrict +def _hypercube( + select: core_defs.NDArrayObject, + xp: ModuleType, + ignore_mask: Optional[core_defs.NDArrayObject] = None, +) -> Optional[list[common.UnitRange]]: + """ + Return the hypercube that contains all True values and no False values or `None` if no such hypercube exists. 
+ + If `ignore_mask` is given, the selected values are ignored. + """ + nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select) + + slices = tuple( + slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz + ) + hcube = select[tuple(slices)] + if ignore_mask is not None: + hcube |= ignore_mask[tuple(slices)] + if not xp.all(hcube): + return None + + return [common.UnitRange(s.start, s.stop) for s in slices] + + # -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func( @@ -491,7 +482,9 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[ +def _make_reduction( + builtin_name: str, array_builtin_name: str, initial_value_op: Callable +) -> Callable[ ..., NdArrayField[common.DimsT, core_defs.ScalarT], ]: @@ -503,23 +496,26 @@ def _builtin_op( if axis not in field.domain.dims: raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.") reduce_dim_index = field.domain.dims.index(axis) - # HANNES HACK current_offset_provider = embedded_context.offset_provider.get(None) assert current_offset_provider is not None offset_definition = current_offset_provider[ axis.value ] # assumes offset and local dimension have same name - # TODO mapping of connectivity to field dimensions + assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) # TODO unclear in case of multiple local dimensions (probably just chain) - # END HANNES HACK new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) + # TODO add test that requires broadcasting + masked_array = field.array_ns.where( + field.array_ns.asarray(offset_definition.table) != common.SKIP_VALUE, + field.ndarray, + initial_value_op(field), + ) + return field.__class__.from_array( getattr(field.array_ns, array_builtin_name)( - field.ndarray, + masked_array, axis=reduce_dim_index, - initial=0, # set proper inital value - where=offset_definition.table >= 0, ), domain=new_domain, ) @@ -528,9 +524,15 @@ def _builtin_op( return _builtin_op -NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum")) -NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max")) -NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min")) +NdArrayField.register_builtin_func( + fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum", lambda x: x.dtype.scalar_type(0)) +) +NdArrayField.register_builtin_func( + fbuiltins.max_over, _make_reduction("max_over", "max", lambda x: x.array_ns.min(x._ndarray)) +) +NdArrayField.register_builtin_func( + fbuiltins.min_over, _make_reduction("min_over", "min", lambda x: x.array_ns.max(x._ndarray)) +) # -- Concrete array implementations -- diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index c95292d702..d9060a2474 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -171,7 +171,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args - (USES_MESH_WITH_SKIP_VALUES, XFAIL, UNSUPPORTED_MESSAGE), ] GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index efbf4ddbc7..c3971d6fe5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -23,7 +23,7 @@ from gt4py.next.program_processors import processor_interface as ppi from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, ) from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -70,11 +70,8 @@ def pnabla( return compute_pnabla(pp, S_M[0], sign, vol), compute_pnabla(pp, S_M[1], sign, vol) -def test_ffront_compute_zavgS(fieldview_backend): - allocator = fieldview_backend - fieldview_backend = ( - fieldview_backend if hasattr(fieldview_backend, "executor") else None - ) # TODO this pattern is dangerous +def test_ffront_compute_zavgS(exec_alloc_descriptor): + executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator setup = nabla_setup() @@ -87,26 +84,24 @@ def test_ffront_compute_zavgS(fieldview_backend): atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False ) - compute_zavgS.with_backend(fieldview_backend)( - pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} - ) + compute_zavgS.with_backend(executor)(pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v}) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) assert_close(388241977.58389181, np.max(zavgS.asnumpy())) -def test_ffront_nabla(fieldview_backend): - fieldview_backend = fieldview_backend if hasattr(fieldview_backend, "executor") else None +def test_ffront_nabla(exec_alloc_descriptor): + executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator setup = nabla_setup() - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_M = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + sign = gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator) + pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) + S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) + vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator) - pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) + pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) e2v = gtx.NeighborTableOffsetProvider( atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False @@ -115,7 +110,7 @@ def test_ffront_nabla(fieldview_backend): atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 ) - pnabla.with_backend(fieldview_backend)( + pnabla.with_backend(executor)( pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py deleted file mode 100644 index 6657ed1b62..0000000000 --- 
a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py +++ /dev/null @@ -1,73 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import numpy as np - -import gt4py.next as gtx -from gt4py.next import float64, neighbor_sum - - -CellDim = gtx.Dimension("Cell") -EdgeDim = gtx.Dimension("Edge") -E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) -E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim, E2CDim)) - -edge_to_cell_table = np.array( - [ - [0, -1], # edge 0 (neighbours: cell 0) - [2, -1], # edge 1 - [2, -1], # edge 2 - [3, -1], # edge 3 - [4, -1], # edge 4 - [5, -1], # edge 5 - [0, 5], # edge 6 (neighbours: cell 0, cell 5) - [0, 1], # edge 7 - [1, 2], # edge 8 - [1, 3], # edge 9 - [3, 4], # edge 10 - [4, 5], # edge 11 - ] -) - -E2C_offset_provider = gtx.NeighborTableOffsetProvider( - edge_to_cell_table, EdgeDim, CellDim, 2, has_skip_values=True -) - - -@gtx.field_operator -def sum_adjacent_cells(cells: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: - # type of cells(E2C) is gtx.Field[[CellDim, E2CDim], float64] - return neighbor_sum(cells(E2C), axis=E2CDim) - - -@gtx.program -def run_sum_adjacent_cells( - cells: gtx.Field[[CellDim], float64], out: gtx.Field[[EdgeDim], float64] -): - sum_adjacent_cells(cells, out=out) - - -cell_values = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) -edge_values = gtx.as_field([EdgeDim], np.zeros((12,))) - -run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) - -print(np.clip(edge_to_cell_table, 0, 5)) -ref = cell_values.ndarray[np.clip(edge_to_cell_table, 0, 5)] -print(ref) -ref = np.sum(ref, initial=0, where=edge_to_cell_table >= 0, axis=1) - - -print("sum of adjacent cells: {}".format(edge_values.asnumpy())) -print(ref) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 70fa274457..ded3098f58 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -683,14 +683,14 @@ def test_connectivity_field_inverse_image_2d_domain(): codomain=V, ) - # e2c_conn: + # e2v_conn: # ---E2V---- # |[[0 0 2] # E [1 1 2] # | [2 2 2]] # Test contiguous and non-contiguous ranges. - # For the 'e2c_conn' defined above, the only valid range including 2 + # For the 'e2v_conn' defined above, the only valid range including 2 # is [0, 3). Otherwise, the inverse image would be non-contiguous. 
image_range = UnitRange(V_START, V_STOP) result = e2v_conn.inverse_image(image_range) @@ -742,3 +742,83 @@ def test_connectivity_field_inverse_image_non_contiguous(): with pytest.raises(ValueError, match="generates non-contiguous dimensions"): e2v_conn.inverse_image(UnitRange(V_START, V_STOP)) + + +def test_connectivity_field_inverse_image_2d_domain_skip_values(): + V = Dimension("V") + E = Dimension("E") + E2V = Dimension("E2V") + + V_START, V_STOP = 0, 3 + E_START, E_STOP = 0, 4 + E2V_START, E2V_STOP = 0, 4 + + e2v_conn = common._connectivity( + np.asarray([[-1, 0, 2, -1], [1, 1, 2, 2], [2, 2, -1, -1], [-1, 2, -1, -1]]), + domain=common.domain( + [ + common.named_range((E, (E_START, E_STOP))), + common.named_range((E2V, (E2V_START, E2V_STOP))), + ] + ), + codomain=V, + ) + + # e2v_conn: + # ---E2V--------- + # |[[-1 0 2 -1] + # E [ 1 1 2 2] + # | [ 2 2 -1 -1] + # | [-1 2 -1 -1]] + + image_range = UnitRange(V_START, V_STOP) + result = e2v_conn.inverse_image(image_range) + + assert len(result) == 2 + assert result[0] == (E, UnitRange(E_START, E_STOP)) + assert result[1] == (E2V, UnitRange(E2V_START, E2V_STOP)) + + result = e2v_conn.inverse_image(UnitRange(0, 2)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(0, 2)) + assert result[1] == (E2V, UnitRange(0, 2)) + + result = e2v_conn.inverse_image(UnitRange(0, 1)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(0, 1)) + assert result[1] == (E2V, UnitRange(1, 2)) + + result = e2v_conn.inverse_image(UnitRange(1, 2)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(1, 2)) + assert result[1] == (E2V, UnitRange(0, 2)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + result = e2v_conn.inverse_image(UnitRange(1, 3)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + result = e2v_conn.inverse_image(UnitRange(2, 3)) + + +@pytest.mark.parametrize( + "select, ignore_mask, expected", + [ + ([True, True, False], None, [(0, 2)]), + ([True, False, True], None, None), + ([True, False, True], [False, True, False], [(0, 3)]), + ([[False, False, False], [False, True, True]], None, [(1, 2), (1, 3)]), + ( + [[False, True, False], [False, True, True]], + [[False, False, True], [False, False, False]], + [(0, 2), (1, 3)], + ), + ], +) +def test_hypercube(select, ignore_mask, expected): + select = np.asarray(select) + ignore_mask = np.asarray(ignore_mask) if ignore_mask is not None else None + expected = [common.unit_range(e) for e in expected] if expected is not None else None + + result = nd_array_field._hypercube(select, np, ignore_mask=ignore_mask) + + assert result == expected
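
The bounding-hypercube check introduced in `_hypercube` above, and its use with a skip-value `ignore_mask` in `inverse_image`, can be exercised in isolation. Below is a minimal standalone NumPy sketch, not taken from the patch itself: the helper name `bounding_hypercube` and the plain `(start, stop)` tuples it returns are illustrative stand-ins for the real function, which returns `common.UnitRange` objects. The asserted values mirror cases from the new `test_hypercube` and `test_connectivity_field_inverse_image_2d_domain_skip_values` tests.

import numpy as np

def bounding_hypercube(select, ignore_mask=None):
    # Bounding box of all True entries, one slice per dimension.
    nnz = np.nonzero(select)
    slices = tuple(slice(int(idx.min()), int(idx.max()) + 1) for idx in nnz)
    box = select[slices]
    if ignore_mask is not None:
        # Entries covered by the ignore mask (e.g. skip values) may be False
        # inside the box without invalidating it.
        box = box | ignore_mask[slices]
    # Valid only if the box contains no unexplained False entries.
    return [(s.start, s.stop) for s in slices] if np.all(box) else None

# Two cases from the parametrized `test_hypercube` added in this patch:
assert bounding_hypercube(np.array([True, False, True])) is None
assert bounding_hypercube(
    np.array([True, False, True]), ignore_mask=np.array([False, True, False])
) == [(0, 3)]

# One case from the new skip-value `inverse_image` test: restrict the
# connectivity to image range [0, 1) while ignoring skip-value (-1) entries.
conn = np.array(
    [[-1, 0, 2, -1],
     [ 1, 1, 2, 2],
     [ 2, 2, -1, -1],
     [-1, 2, -1, -1]]
)
restricted = (conn >= 0) & (conn < 1)
assert bounding_hypercube(restricted, ignore_mask=(conn == -1)) == [(0, 1), (1, 2)]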
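
The reworked `_make_reduction` handles skip values by replacing entries whose neighbor-table index equals `common.SKIP_VALUE` with a neutral element supplied by `initial_value_op` (0 for `neighbor_sum`, the field's minimum for `max_over`, its maximum for `min_over`) and then reducing along the local dimension. The following is a rough standalone NumPy sketch of that idea, assuming an excerpt of the edge-to-cell table and cell values from the removed test_skip_value.py; the explicit gather step stands in for the `cells(E2C)` neighbor access that produces `field.ndarray` in the real code.

import numpy as np

SKIP_VALUE = -1  # same sentinel as common.SKIP_VALUE

# Excerpt of the edge-to-cell table from the removed test (-1 marks a missing neighbor).
edge_to_cell = np.array([[0, -1], [2, -1], [0, 5], [0, 1]])
cell_values = np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])

# Gather neighbor values; clipping only keeps the skip entries indexable,
# their contribution is masked out below.
gathered = cell_values[np.clip(edge_to_cell, 0, None)]
valid = edge_to_cell != SKIP_VALUE

# neighbor_sum: neutral element 0, mirroring initial_value_op for "sum".
neighbor_sum = np.where(valid, gathered, 0.0).sum(axis=1)

# max_over: neutral element is the field's minimum, mirroring initial_value_op for "max".
max_over = np.where(valid, gathered, cell_values.min()).max(axis=1)

print(neighbor_sum)  # [1. 2. 9. 2.]
print(max_over)      # [1. 2. 8. 1.]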