diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index f4e35b5533..e55e13a38d 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -75,6 +75,9 @@ class Dimension: def __str__(self): return f"{self.value}[{self.kind}]" + def __call__(self, val: int) -> NamedIndex: + return self, val + class Infinity(enum.Enum): """Describes an unbounded `UnitRange`.""" @@ -272,7 +275,10 @@ def unit_range(r: RangeLike) -> UnitRange: NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType -AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange +NamedSlice: TypeAlias = ( + slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ +) +AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] RelativeIndexSequence: TypeAlias = tuple[ @@ -307,6 +313,10 @@ def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: ) +def is_named_slice(obj: AnyIndexSpec) -> TypeGuard[NamedRange]: + return isinstance(obj, slice) and (is_named_index(obj.start) and is_named_index(obj.stop)) + + def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: return ( is_int_index(v) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 94efe4d61d..f9201da247 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -146,3 +146,32 @@ def _find_index_of_dim( if dim == d: return i return None + + +def canonicalize_any_index_sequence( + index: common.AnyIndexSpec, +) -> common.AnyIndexSpec: + # TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice` + new_index: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index + if isinstance(new_index, tuple) and all(isinstance(i, slice) for i in new_index): + new_index = tuple([_named_slice_to_named_range(i) for i in new_index]) # type: ignore[arg-type, assignment] # all i's are slices as per if statement + return new_index + + +def _named_slice_to_named_range( + idx: common.NamedSlice, +) -> common.NamedRange | common.NamedSlice: + assert hasattr(idx, "start") and hasattr(idx, "stop") + if common.is_named_slice(idx): + idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = idx.start[0], idx.start[1], idx.stop[0], idx.stop[1] # type: ignore[attr-defined] + if idx_start_0 != idx_stop_0: + raise IndexError( + f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'." + ) + assert isinstance(idx_start_1, int) and isinstance(idx_stop_1, int) + return (idx_start_0, common.UnitRange(idx_start_1, idx_stop_1)) + if common.is_named_index(idx.start) and idx.stop is None: + raise IndexError(f"Upper bound needs to be specified for {idx}.") + if common.is_named_index(idx.stop) and idx.start is None: + raise IndexError(f"Lower bound needs to be specified for {idx}.") + return idx diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1bdb7161ec..38aab09df1 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -301,6 +301,7 @@ def __invert__(self) -> NdArrayField: def _slice( self, index: common.AnyIndexSpec ) -> tuple[common.Domain, common.RelativeIndexSequence]: + index = embedded_common.canonicalize_any_index_sequence(index) new_domain = embedded_common.sub_domain(self.domain, index) index_sequence = common.as_any_index_sequence(index) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 05ebd02352..1c9887cd22 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -147,6 +147,8 @@ def __call__(self, *args): def make_node(o): if isinstance(o, Node): return o + if isinstance(o, common.Dimension): + return AxisLiteral(value=o.value) if callable(o): if o.__name__ == "": return lambdadef(o) @@ -156,8 +158,6 @@ def make_node(o): return OffsetLiteral(value=o.value) if isinstance(o, core_defs.Scalar): return im.literal_from_value(o) - if isinstance(o, common.Dimension): - return AxisLiteral(value=o.value) if isinstance(o, tuple): return _f("make_tuple", *(make_node(arg) for arg in o)) if o is None: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index de511fdabb..91f15ee936 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -19,7 +19,12 @@ from gt4py.next import common from gt4py.next.common import UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions -from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain +from gt4py.next.embedded.common import ( + _slice_range, + canonicalize_any_index_sequence, + iterate_domain, + sub_domain, +) @pytest.mark.parametrize( @@ -147,3 +152,31 @@ def test_iterate_domain(): testee = list(iterate_domain(domain)) assert testee == ref + + +@pytest.mark.parametrize( + "slices, expected", + [ + [slice(I(3), I(4)), ((I, common.UnitRange(3, 4)),)], + [ + (slice(J(3), J(6)), slice(I(3), I(5))), + ((J, common.UnitRange(3, 6)), (I, common.UnitRange(3, 5))), + ], + [slice(I(1), J(7)), IndexError], + [ + slice(I(1), None), + IndexError, + ], + [ + slice(None, K(8)), + IndexError, + ], + ], +) +def test_slicing(slices, expected): + if expected is IndexError: + with pytest.raises(IndexError): + canonicalize_any_index_sequence(slices) + else: + testee = canonicalize_any_index_sequence(slices) + assert testee == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 66189cf9eb..49f74a566b 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -463,6 +463,49 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): assert indexed_field.domain.dims == expected_dimensions +def test_absolute_indexing_dim_sliced(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + indexed_field_1 = field[JDim(8) : JDim(10), IDim(5) : IDim(9)] + expected = field[(IDim, UnitRange(5, 9)), (JDim, UnitRange(8, 10))] + + assert common.is_field(indexed_field_1) + assert indexed_field_1 == expected + + +def test_absolute_indexing_dim_sliced_single_slice(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + indexed_field_1 = field[KDim(11)] + indexed_field_2 = field[(KDim, 11)] + + assert common.is_field(indexed_field_1) + assert indexed_field_1 == indexed_field_2 + + +def test_absolute_indexing_wrong_dim_sliced(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + + with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'JDim' and 'IDim'."): + field[JDim(8) : IDim(10)] + + +def test_absolute_indexing_empty_dim_sliced(): + domain = common.Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + ) + field = common._field(np.ones((5, 10, 15)), domain=domain) + with pytest.raises(IndexError, match="Lower bound needs to be specified"): + field[: IDim(10)] + + def test_absolute_indexing_value_return(): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain)