Skip to content

Commit

Permalink
feat[next]: add tests with mesh with skip values (#1433)
Browse files Browse the repository at this point in the history
- Adds a mesh with skip values
- Define `common.SKIP_VALUE = -1` instead of using `-1` explicitly
- Skip tests with that mesh in embedded (will come in a next PR).
  • Loading branch information
havogt authored Jan 31, 2024
1 parent 28ed830 commit adf3a3c
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 100 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ markers = [
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_max_over: tests that use the max_over builtin',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Any,
Callable,
ClassVar,
Final,
Generic,
Never,
Optional,
Expand Down Expand Up @@ -1073,4 +1074,9 @@ def register_builtin_func(
@classmethod
def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Callable[_P, _R]:
return cls._builtin_func_map.get(func, NotImplemented)
return cls._builtin_func_map.get(func, NotImplemented)


#: Numeric value used to represent missing values in connectivities.
#: Equivalent to the `_FillValue` attribute in the UGRID Conventions
#: (see: http://ugrid-conventions.github.io/ugrid-conventions/).
SKIP_VALUE: Final[int] = -1
12 changes: 11 additions & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,16 @@ def execute_shift(
for i, p in reversed(list(enumerate(new_entry))):
# first shift applies to the last sparse dimensions of that axis type
if p is None:
offset_implementation = offset_provider[tag]
assert isinstance(offset_implementation, common.Connectivity)
cur_index = pos[offset_implementation.origin_axis.value]
assert common.is_int_index(cur_index)
if offset_implementation.mapped_index(cur_index, index) in [
None,
common.SKIP_VALUE,
]:
return None

new_entry[i] = index
break
# the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard
Expand All @@ -549,7 +559,7 @@ def execute_shift(
assert common.is_int_index(cur_index)
if offset_implementation.mapped_index(cur_index, index) in [
None,
-1,
common.SKIP_VALUE,
]:
return None
else:
Expand Down
2 changes: 2 additions & 0 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_CARTESIAN_SHIFT = "uses_cartesian_shift"
USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift"
USES_MAX_OVER = "uses_max_over"
USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"

# Skip messages (available format keys: 'marker', 'backend')
Expand Down Expand Up @@ -170,6 +171,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
XFAIL,
UNSUPPORTED_MESSAGE,
), # we can't extract the field type from scan args
(USES_MESH_WITH_SKIP_VALUES, XFAIL, UNSUPPORTED_MESSAGE),
]
GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
# floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136
Expand Down
32 changes: 15 additions & 17 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,25 @@

from next_tests import definitions as test_definitions
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 # fixture and aliases
C2E,
C2V,
E2V,
V2E,
C2EDim,
C2VDim,
Cell,
E2VDim,
Edge,
IDim,
Ioff,
JDim,
Joff,
KDim,
Koff,
V2EDim,
Vertex,
exec_alloc_descriptor,
reduction_setup,
mesh_descriptor,
)


Expand All @@ -65,16 +73,6 @@
CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type]
EmptyField: TypeAlias = gtx.Field[[], np.int32] # type: ignore [valid-type]

# TODO(ricoh): unify the following with the `ffront_test_utils.reduction_setup`
# fixture if `ffront_test_utils.reduction_setup` is not completely superseded
# by `unstructured_case`.
V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL)
E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL)
C2EDim = gtx.Dimension("C2E", kind=common.DimensionKind.LOCAL)
V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim))
E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim))
C2E = gtx.FieldOffset("C2E", source=Edge, target=(Cell, C2EDim))

ScalarValue: TypeAlias = core_defs.Scalar
FieldValue: TypeAlias = gtx.Field
FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...]
Expand Down Expand Up @@ -489,17 +487,17 @@ def cartesian_case(

@pytest.fixture
def unstructured_case(
reduction_setup, # noqa: F811 # fixtures
mesh_descriptor, # noqa: F811 # fixtures
exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, # noqa: F811 # fixtures
):
yield Case(
exec_alloc_descriptor.executor,
offset_provider=reduction_setup.offset_provider,
offset_provider=mesh_descriptor.offset_provider,
default_sizes={
Vertex: reduction_setup.num_vertices,
Edge: reduction_setup.num_edges,
Cell: reduction_setup.num_cells,
KDim: reduction_setup.k_levels,
Vertex: mesh_descriptor.num_vertices,
Edge: mesh_descriptor.num_edges,
Cell: mesh_descriptor.num_cells,
KDim: 10,
},
grid_type=common.GridType.UNSTRUCTURED,
allocator=exec_alloc_descriptor.allocator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import types
from collections import namedtuple
from typing import Any, Optional, TypeVar
from typing import Any, Protocol, TypeVar

import numpy as np
import pytest

import gt4py.next as gtx
from gt4py.next import common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import ir as itir
from gt4py.next.program_processors import processor_interface as ppi
Expand Down Expand Up @@ -118,18 +120,41 @@ def debug_itir(tree):
Cell = gtx.Dimension("Cell")
EdgeOffset = gtx.FieldOffset("EdgeOffset", source=Edge, target=(Edge,))

V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL)
E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL)
C2EDim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL)
C2VDim = gtx.Dimension("C2V", kind=gtx.DimensionKind.LOCAL)
V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim))
E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim))
C2E = gtx.FieldOffset("C2E", source=Edge, target=(Cell, C2EDim))
C2V = gtx.FieldOffset("C2V", source=Vertex, target=(Cell, C2VDim))

size = 10


@pytest.fixture
def reduction_setup():
class MeshDescriptor(Protocol):
@property
def name(self) -> str: ...

@property
def num_vertices(self) -> int: ...

@property
def num_cells(self) -> int: ...

@property
def num_edges(self) -> int: ...

@property
def num_levels(self) -> int: ...

@property
def offset_provider(self) -> dict[str, common.Connectivity]: ...


def simple_mesh() -> MeshDescriptor:
num_vertices = 9
num_cells = 8
k_levels = 10
v2edim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL)
e2vdim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL)
c2vdim = gtx.Dimension("C2V", kind=gtx.DimensionKind.LOCAL)
c2edim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL)

v2e_arr = np.array(
[
Expand Down Expand Up @@ -183,57 +208,115 @@ def reduction_setup():
assert all(len(row) == 2 for row in e2v_arr)
e2v_arr = np.asarray(e2v_arr, dtype=gtx.IndexType)

yield namedtuple(
"ReductionSetup",
return types.SimpleNamespace(
name="simple_mesh",
num_vertices=num_vertices,
num_edges=np.int32(num_edges),
num_cells=num_cells,
offset_provider={
V2E.value: gtx.NeighborTableOffsetProvider(
v2e_arr, Vertex, Edge, 4, has_skip_values=False
),
E2V.value: gtx.NeighborTableOffsetProvider(
e2v_arr, Edge, Vertex, 2, has_skip_values=False
),
C2V.value: gtx.NeighborTableOffsetProvider(
c2v_arr, Cell, Vertex, 4, has_skip_values=False
),
C2E.value: gtx.NeighborTableOffsetProvider(
c2e_arr, Cell, Edge, 4, has_skip_values=False
),
},
)


def skip_value_mesh() -> MeshDescriptor:
"""Mesh with skip values from the GT4Py quickstart guide."""

num_vertices = 7
num_cells = 6
num_edges = 12

v2e_arr = np.array(
[
[1, 8, 7, 0, -1],
[2, 8, 1, -1, -1],
[3, 9, 8, 2, -1],
[4, 10, 3, -1, -1],
[5, 11, 4, -1, -1],
[0, 6, 4, -1, -1],
[6, 7, 9, 10, 11],
],
dtype=gtx.IndexType,
)

e2v_arr = np.array(
[
[0, 5],
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6],
[6, 0],
[0, 2],
[2, 6],
[3, 6],
[4, 6],
],
dtype=gtx.IndexType,
)

c2v_arr = np.array(
[
[0, 6, 5],
[0, 2, 6],
[0, 1, 2],
[2, 3, 6],
[3, 4, 6],
[4, 5, 6],
],
dtype=gtx.IndexType,
)

c2e_arr = np.array(
[
"num_vertices",
"num_edges",
"num_cells",
"k_levels",
"V2EDim",
"E2VDim",
"C2VDim",
"C2EDim",
"V2E",
"E2V",
"C2V",
"C2E",
"inp",
"out",
"offset_provider",
"v2e_table",
"e2v_table",
[0, 6, 7], # cell 0 (neighbors: edge 0, edge 6, edge 7)
[7, 8, 9], # cell 1
[1, 2, 8], # cell 2
[3, 9, 10], # cell 3
[4, 10, 11], # cell 4
[5, 6, 11], # cell 5
],
)(
dtype=gtx.IndexType,
)

return types.SimpleNamespace(
name="skip_value_mesh",
num_vertices=num_vertices,
num_edges=num_edges,
num_cells=num_cells,
k_levels=k_levels,
V2EDim=v2edim,
E2VDim=e2vdim,
C2VDim=c2vdim,
C2EDim=c2edim,
V2E=gtx.FieldOffset("V2E", source=Edge, target=(Vertex, v2edim)),
E2V=gtx.FieldOffset("E2V", source=Vertex, target=(Edge, e2vdim)),
C2V=gtx.FieldOffset("C2V", source=Vertex, target=(Cell, c2vdim)),
C2E=gtx.FieldOffset("C2E", source=Edge, target=(Cell, c2edim)),
# inp=gtx.index_field(edge, dtype=np.int64), # TODO enable once we support gtx.index_fields in bindings
inp=gtx.as_field([Edge], np.arange(num_edges, dtype=np.int32)),
out=gtx.as_field([Vertex], np.zeros([num_vertices], dtype=np.int32)),
offset_provider={
"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4, has_skip_values=False),
"E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False),
"C2V": gtx.NeighborTableOffsetProvider(c2v_arr, Cell, Vertex, 4, has_skip_values=False),
"C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4, has_skip_values=False),
V2E.value: gtx.NeighborTableOffsetProvider(
v2e_arr, Vertex, Edge, 5, has_skip_values=True
),
E2V.value: gtx.NeighborTableOffsetProvider(
e2v_arr, Edge, Vertex, 2, has_skip_values=False
),
C2V.value: gtx.NeighborTableOffsetProvider(
c2v_arr, Cell, Vertex, 3, has_skip_values=False
),
C2E.value: gtx.NeighborTableOffsetProvider(
c2e_arr, Cell, Edge, 3, has_skip_values=False
),
},
v2e_table=v2e_arr,
e2v_table=e2v_arr,
) # type: ignore
)


__all__ = [
"exec_alloc_descriptor",
"reduction_setup",
"mesh_descriptor",
"debug_itir",
"DimsType",
"DType",
Expand All @@ -249,3 +332,14 @@ def reduction_setup():
"EdgeOffset",
"size",
]


@pytest.fixture(
params=[
simple_mesh(),
pytest.param(skip_value_mesh(), marks=pytest.mark.uses_mesh_with_skip_values),
],
ids=lambda p: p.name,
)
def mesh_descriptor(request) -> MeshDescriptor:
yield request.param
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from next_tests.integration_tests.cases import cartesian_case
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
exec_alloc_descriptor,
reduction_setup,
mesh_descriptor,
)


Expand Down
Loading

0 comments on commit adf3a3c

Please sign in to comment.