diff --git a/pyproject.toml b/pyproject.toml index 64f08e671e..d87361e35f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,6 +238,7 @@ markers = [ 'requires_atlas: tests that require `atlas4py` bindings package', 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', + 'requires_jax: tests that require the `jax` package', 'starts_from_gtir_program: tests that require backend to start lowering from GTIR program', 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_constant_fields: tests that require backend support for constant fields', @@ -264,6 +265,7 @@ markers = [ 'uses_cartesian_shift: tests that use a Cartesian connectivity', 'uses_unstructured_shift: tests that use a unstructured connectivity', 'uses_max_over: tests that use the max_over builtin', + 'uses_bool_field: tests that use a bool field', 'uses_mesh_with_skip_values: tests that use a mesh with skip values', 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 864f8c1b09..f55e39a592 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -60,7 +60,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: ... + ) -> core_allocators._NDBuffer: ... 
def is_field_allocator(obj: Any) -> TypeGuard[FieldBufferAllocatorProtocol]: @@ -160,7 +160,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ) -> core_allocators._NDBuffer: shape = domain.shape layout_map = self.layout_mapper(domain.dims) # TODO(egparedes): add support for non-empty aligned index values @@ -168,7 +168,7 @@ def __gt_allocate__( return self.buffer_allocator.allocate( shape, dtype, device_id, layout_map, self.byte_alignment, aligned_index - ) + ).ndarray if TYPE_CHECKING: @@ -199,6 +199,13 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]: return valid_layout_map +def c_layout_mapper(dims: Sequence[common.Dimension]) -> core_allocators.BufferLayoutMap: + layout_map = tuple(range(len(dims))) + assert core_allocators.is_valid_layout_map(layout_map) + + return layout_map + + if TYPE_CHECKING: __horizontal_first_layout_mapper: FieldLayoutMapper = horizontal_first_layout_mapper @@ -207,6 +214,18 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]: device_allocators: dict[core_defs.DeviceType, FieldBufferAllocatorProtocol] = {} +class CLayoutCPUFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CPUDeviceTyping]): + """A field buffer allocator for CPU devices that uses a C-style layout.""" + + def __init__(self) -> None: + super().__init__( + device_type=core_defs.DeviceType.CPU, + array_utils=core_allocators.numpy_array_utils, + layout_mapper=c_layout_mapper, + byte_alignment=1, + ) + + class StandardCPUFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CPUDeviceTyping]): """A field buffer allocator for CPU devices that uses a horizontal-first layout mapper and 64-byte alignment.""" @@ -221,9 +240,44 @@ def __init__(self) -> None: device_allocators[core_defs.DeviceType.CPU] = StandardCPUFieldBufferAllocator() - assert 
is_field_allocator(device_allocators[core_defs.DeviceType.CPU]) +try: + # TODO use pattern from GPU allocation (InvalidBufferAllocator) + import jax.numpy as jnp +except ImportError: + jnp = None + +if jnp: + from jax import config + + config.update("jax_enable_x64", True) + + class StandardJAXCPUFieldBufferAllocator( + FieldBufferAllocatorProtocol[core_defs.CPUDeviceTyping] + ): + @property + def __gt_device_type__(self) -> core_defs.CPUDeviceTyping: + return core_defs.DeviceType.CPU + + def __gt_allocate__( + self, + domain: common.Domain, + dtype: core_defs.DType[core_defs.ScalarT], + device_id: int = 0, + aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position + ) -> core_allocators._NDBuffer: + return jnp.empty(domain.shape, dtype=dtype) + # TODO + # tensor_buffer = CLayoutCPUFieldBufferAllocator().__gt_allocate__( + # domain, dtype, device_id, aligned_index + # ) + # object.__setattr__( + # tensor_buffer, "ndarray", jnp.from_dlpack(tensor_buffer.ndarray) + # ) # TODO mutating a frozen object + # return tensor_buffer + ... + @dataclasses.dataclass(frozen=True) class InvalidFieldBufferAllocator(FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]): @@ -242,7 +296,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: + ) -> core_allocators._NDBuffer: raise self.exception @@ -299,7 +353,7 @@ def allocate( aligned_index: Optional[Sequence[common.NamedIndex]] = None, allocator: Optional[FieldBufferAllocationUtil] = None, device: Optional[core_defs.Device] = None, -) -> core_allocators.TensorBuffer: +) -> core_allocators._NDBuffer: """ Allocate a TensorBuffer for the given domain and device or allocator. 
diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index dd52559e85..3a74c7cd6a 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -80,7 +80,7 @@ def empty( buffer = next_allocators.allocate( domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device ) - res = common._field(buffer.ndarray, domain=domain) + res = common._field(buffer, domain=domain) assert isinstance(res, common.MutableField) assert isinstance(res, nd_array_field.NdArrayField) return res @@ -343,9 +343,9 @@ def as_connectivity( device = core_defs.Device(*data.__dlpack_device__()) buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) # TODO(havogt): consider adding MutableNDArrayObject - buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] + buffer[...] = storage_utils.asarray(data) # type: ignore[index] connectivity_field = common._connectivity( - buffer.ndarray, codomain=codomain, domain=actual_domain, skip_value=skip_value + buffer, codomain=codomain, domain=actual_domain, skip_value=skip_value ) assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 655a1137e8..2a76420e36 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -1076,8 +1076,18 @@ def __setitem__( index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: - # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` - raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") + target_domain, target_slice = self._slice(index) + + if isinstance(value, NdArrayField): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. 
Source domain = {value.domain}, target domain = {target_domain}." + ) + value = value.ndarray + + assert hasattr(self._ndarray, "at") + # TODO must not update a field of a frozen obj + object.__setattr__(self, "_ndarray", self._ndarray.at[target_slice].set(value)) common._field.register(jnp.ndarray, JaxArrayField.from_array) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 3fef43865b..fa475bb88b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -55,16 +55,18 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): @dataclasses.dataclass(frozen=True) class EmbeddedDummyBackend: - allocator: next_allocators.FieldBufferAllocatorProtocol + allocator: next_allocators.FieldBufferAllocatorProtocol # TODO make it field constructor numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator()) cupy_execution = EmbeddedDummyBackend(next_allocators.StandardGPUFieldBufferAllocator()) +jax_execution = EmbeddedDummyBackend(next_allocators.StandardJAXCPUFieldBufferAllocator()) class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): NUMPY_EXECUTION = "next_tests.definitions.numpy_execution" CUPY_EXECUTION = "next_tests.definitions.cupy_execution" + JAX_EXECUTION = "next_tests.definitions.jax_execution" class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): @@ -114,6 +116,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_CARTESIAN_SHIFT = "uses_cartesian_shift" USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift" USES_MAX_OVER = "uses_max_over" +USES_BOOL_FIELD = "uses_bool_field" USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" @@ -174,6 +177,11 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: 
EMBEDDED_SKIP_LIST, + EmbeddedIds.JAX_EXECUTION: EMBEDDED_SKIP_LIST + + [ + # dlpack support for `bool` arrays is not yet available in jax, see https://github.com/google/jax/issues/19352 + (USES_BOOL_FIELD, XFAIL, UNSUPPORTED_MESSAGE) + ], OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 333a2dae28..3f462d8f91 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -58,6 +58,9 @@ def __gt_allocator__( pytest.param( next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu ), + pytest.param( + next_tests.definitions.EmbeddedIds.JAX_EXECUTION, marks=pytest.mark.requires_jax + ), pytest.param( next_tests.definitions.OptionalProgramBackendId.DACE_CPU, marks=pytest.mark.requires_dace, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7540d52fb3..3516daf9e2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -415,6 +415,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: @pytest.mark.uses_tuple_returns +@pytest.mark.uses_bool_field def test_astype_on_tuples(cartesian_case): @gtx.field_operator def field_op_returning_a_tuple( @@ -474,6 +475,7 @@ def cast_nested_tuple( ) +@pytest.mark.uses_bool_field def test_astype_bool_field(cartesian_case): @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: @@ -485,6 +487,7 
@@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: ) +@pytest.mark.uses_bool_field @pytest.mark.parametrize("inp", [0.0, 2.0]) def test_astype_bool_scalar(cartesian_case, inp): @gtx.field_operator @@ -510,6 +513,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], np.float32]: ) +@pytest.mark.uses_bool_field @pytest.mark.uses_dynamic_offsets def test_offset_field(cartesian_case): ref = np.full( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 29966c30ad..feafa025bc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -309,6 +309,7 @@ def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: @pytest.mark.uses_tuple_returns +@pytest.mark.uses_bool_field def test_conditional_nested_tuple(cartesian_case): @gtx.field_operator def conditional_nested_tuple( @@ -382,6 +383,7 @@ def simple_broadcast(inp: cases.IField) -> cases.IJField: ) +@pytest.mark.uses_bool_field def test_conditional(cartesian_case): @gtx.field_operator def conditional( @@ -406,6 +408,7 @@ def conditional( ) +@pytest.mark.uses_bool_field def test_conditional_promotion(cartesian_case): @gtx.field_operator def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases.IFloatField: @@ -431,6 +434,7 @@ def conditional_promotion(a: cases.IFloatField) -> cases.IFloatField: @pytest.mark.uses_cartesian_shift +@pytest.mark.uses_bool_field def test_conditional_shifted(cartesian_case): @gtx.field_operator def conditional_shifted( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 89c341e9a6..61f56c675d 100644 --- 
a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -81,6 +81,7 @@ def mod_fieldop(inp1: cases.IField) -> cases.IField: cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) +@pytest.mark.uses_bool_field def test_bit_xor(cartesian_case): @gtx.field_operator def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: @@ -93,6 +94,7 @@ def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolFie cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1 ^ inp2) +@pytest.mark.uses_bool_field def test_bit_and(cartesian_case): @gtx.field_operator def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: @@ -105,6 +107,7 @@ def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1 & inp2) +@pytest.mark.uses_bool_field def test_bit_or(cartesian_case): @gtx.field_operator def bit_or(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: @@ -128,6 +131,7 @@ def uneg(inp: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, uneg, ref=lambda inp1: -inp1) +@pytest.mark.uses_bool_field def test_unary_neg_float_conversion(cartesian_case): @gtx.field_operator def uneg_float() -> cases.IFloatField: @@ -140,6 +144,7 @@ def uneg_float() -> cases.IFloatField: cases.verify(cartesian_case, uneg_float, out=out, ref=ref) +@pytest.mark.uses_bool_field def test_unary_neg_bool_conversion(cartesian_case): @gtx.field_operator def uneg_bool() -> cases.IBoolField: @@ -152,6 +157,7 @@ def uneg_bool() -> cases.IBoolField: cases.verify(cartesian_case, uneg_bool, out=out, ref=ref) +@pytest.mark.uses_bool_field def test_unary_invert(cartesian_case): @gtx.field_operator def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: @@ -212,6 
+218,7 @@ def roots_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFl ) +@pytest.mark.uses_bool_field def test_is_values(cartesian_case): @gtx.field_operator def is_isinf_fieldop(inp1: cases.IFloatField) -> cases.IBoolField: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 505879a506..e6a019107f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -207,6 +207,7 @@ class setup: @pytest.mark.uses_tuple_returns +@pytest.mark.uses_bool_field @pytest.mark.uses_scan_requiring_projector def test_solve_nonhydro_stencil_52_like_z_q(test_setup): cases.verify( @@ -226,6 +227,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): @pytest.mark.uses_tuple_returns +@pytest.mark.uses_bool_field def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): if ( test_setup.case.backend @@ -253,6 +255,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): @pytest.mark.uses_tuple_returns +@pytest.mark.uses_bool_field def test_solve_nonhydro_stencil_52_like(test_setup): if ( test_setup.case.backend @@ -274,6 +277,7 @@ def test_solve_nonhydro_stencil_52_like(test_setup): assert np.allclose(test_setup.w_ref, test_setup.w.asnumpy()) +@pytest.mark.uses_bool_field @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): if ( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 850a20ff7e..d027679a05 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ 
b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -95,7 +95,7 @@ def test_ffront_lap(cartesian_case): in_field, out_field, inout=out_field[1:-1, 1:-1], - ref=lap_ref(in_field.ndarray), + ref=lap_ref(in_field.asnumpy()), ) @@ -125,5 +125,5 @@ def test_ffront_laplap(cartesian_case): in_field, out_field, inout=out_field[2:-2, 2:-2], - ref=lap_ref(lap_ref(in_field.array_ns.asarray(in_field.ndarray))), + ref=lap_ref(lap_ref(in_field.asnumpy())), )