Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Enable jax testing #1429

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1216a88
make common.field private
havogt Nov 28, 2023
fa191a9
add allocators to test matrix
havogt Nov 28, 2023
f0b6bb6
rename exclusion_matrices to definitions
havogt Nov 29, 2023
a652e0e
Merge remote-tracking branch 'upstream/main' into enable_cupy_and_jax…
havogt Jan 4, 2024
d928039
store allocator
havogt Jan 24, 2024
6fa9f8f
based on array_ns
havogt Jan 25, 2024
816dafb
cleanup
havogt Jan 25, 2024
db5ecfa
Merge remote-tracking branch 'upstream/main' into enable_cupy_and_jax…
havogt Jan 25, 2024
ca7aa6c
requires_gpu
havogt Jan 25, 2024
0a0dc50
refactor backend selection
havogt Jan 25, 2024
a29c9ca
refactor weirdness in skip list for embedded
havogt Jan 25, 2024
034016b
fix import
havogt Jan 26, 2024
2f63bac
fix testcase
havogt Jan 26, 2024
2d7b0cf
implement execution and allocator descriptor for tests
havogt Jan 26, 2024
845810f
cleanups
havogt Jan 26, 2024
8aba24d
renaming
havogt Jan 26, 2024
c907349
missing file
havogt Jan 26, 2024
8238e8d
more renames
havogt Jan 26, 2024
470a36e
Update tests/next_tests/definitions.py
havogt Jan 29, 2024
b0bf88f
address review comment
havogt Jan 29, 2024
cdf0a32
fix import
havogt Jan 29, 2024
3f511d8
Merge remote-tracking branch 'upstream/main' into enable_jax_testing
havogt Jan 29, 2024
2dd11a5
add jax testing
havogt Jan 29, 2024
5185534
test exclusion for bool fields
havogt Jan 29, 2024
c417431
unrelated cleanup
havogt Jan 30, 2024
3cf8f6b
Merge remote-tracking branch 'upstream/main' into enable_jax_testing
havogt Nov 4, 2024
d2e6fc5
allocators and constructors
havogt Nov 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ markers = [
'requires_atlas: tests that require `atlas4py` bindings package',
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'requires_jax: tests that require the `jax` package',
'starts_from_gtir_program: tests that require backend to start lowering from GTIR program',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_constant_fields: tests that require backend support for constant fields',
Expand All @@ -264,6 +265,7 @@ markers = [
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_max_over: tests that use the max_over builtin',
'uses_bool_field: tests that use a bool field',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
Expand Down
53 changes: 52 additions & 1 deletion src/gt4py/next/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]:
return valid_layout_map


def c_layout_mapper(dims: Sequence[common.Dimension]) -> core_allocators.BufferLayoutMap:
layout_map = tuple(range(len(dims)))
assert core_allocators.is_valid_layout_map(layout_map)

return layout_map


if TYPE_CHECKING:
__horizontal_first_layout_mapper: FieldLayoutMapper = horizontal_first_layout_mapper

Expand All @@ -207,6 +214,18 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]:
device_allocators: dict[core_defs.DeviceType, FieldBufferAllocatorProtocol] = {}


class CLayoutCPUFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CPUDeviceTyping]):
"""A field buffer allocator for CPU devices that uses a C-style layout."""

def __init__(self) -> None:
super().__init__(
device_type=core_defs.DeviceType.CPU,
array_utils=core_allocators.numpy_array_utils,
layout_mapper=c_layout_mapper,
byte_alignment=1,
)


class StandardCPUFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CPUDeviceTyping]):
"""A field buffer allocator for CPU devices that uses a horizontal-first layout mapper and 64-byte alignment."""

Expand All @@ -221,9 +240,41 @@ def __init__(self) -> None:

device_allocators[core_defs.DeviceType.CPU] = StandardCPUFieldBufferAllocator()


assert is_field_allocator(device_allocators[core_defs.DeviceType.CPU])

try:
# TODO use pattern from GPU allocation (InvalidBufferAllocator)
import jax.numpy as jnp
except ImportError:
jnp = None

if jnp:
from jax import config

config.update("jax_enable_x64", True)

class StandardJAXCPUFieldBufferAllocator(
FieldBufferAllocatorProtocol[core_defs.CPUDeviceTyping]
):
@property
def __gt_device_type__(self) -> core_defs.CPUDeviceTyping:
return core_defs.DeviceType.CPU

def __gt_allocate__(
self,
domain: common.Domain,
dtype: core_defs.DType[core_defs.ScalarT],
device_id: int = 0,
aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position
) -> core_allocators.TensorBuffer[core_defs.CPUDeviceTyping, core_defs.ScalarT]:
tensor_buffer = CLayoutCPUFieldBufferAllocator().__gt_allocate__(
domain, dtype, device_id, aligned_index
)
object.__setattr__(
tensor_buffer, "ndarray", jnp.from_dlpack(tensor_buffer.ndarray)
) # TODO mutating a frozen object
return tensor_buffer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just return jax.empty(shape)



@dataclasses.dataclass(frozen=True)
class InvalidFieldBufferAllocator(FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]):
Expand Down
14 changes: 12 additions & 2 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,8 +1076,18 @@ def __setitem__(
index: common.AnyIndexSpec,
value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT,
) -> None:
# TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)`
raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.")
target_domain, target_slice = self._slice(index)

if isinstance(value, NdArrayField):
if not value.domain == target_domain:
raise ValueError(
f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}."
)
value = value.ndarray

assert hasattr(self._ndarray, "at")
# TODO must not update a field of a frozen obj
object.__setattr__(self, "_ndarray", self._ndarray.at[target_slice].set(value))

common._field.register(jnp.ndarray, JaxArrayField.from_array)

Expand Down
10 changes: 9 additions & 1 deletion tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,18 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):

@dataclasses.dataclass(frozen=True)
class EmbeddedDummyBackend:
allocator: next_allocators.FieldBufferAllocatorProtocol
allocator: next_allocators.FieldBufferAllocatorProtocol # TODO make it field constructor


numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator())
cupy_execution = EmbeddedDummyBackend(next_allocators.StandardGPUFieldBufferAllocator())
jax_execution = EmbeddedDummyBackend(next_allocators.StandardJAXCPUFieldBufferAllocator())


class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum):
NUMPY_EXECUTION = "next_tests.definitions.numpy_execution"
CUPY_EXECUTION = "next_tests.definitions.cupy_execution"
JAX_EXECUTION = "next_tests.definitions.jax_execution"


class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
Expand Down Expand Up @@ -114,6 +116,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_CARTESIAN_SHIFT = "uses_cartesian_shift"
USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift"
USES_MAX_OVER = "uses_max_over"
USES_BOOL_FIELD = "uses_bool_field"
USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"
Expand Down Expand Up @@ -174,6 +177,11 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
BACKEND_SKIP_TEST_MATRIX = {
EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST,
EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST,
EmbeddedIds.JAX_EXECUTION: EMBEDDED_SKIP_LIST
+ [
# dlpack support for `bool` arrays is not yet available in jax, see https://github.com/google/jax/issues/19352
(USES_BOOL_FIELD, XFAIL, UNSUPPORTED_MESSAGE)
],
OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def __gt_allocator__(
pytest.param(
next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu
),
pytest.param(
next_tests.definitions.EmbeddedIds.JAX_EXECUTION, marks=pytest.mark.requires_jax
),
pytest.param(
next_tests.definitions.OptionalProgramBackendId.DACE_CPU,
marks=pytest.mark.requires_dace,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:


@pytest.mark.uses_tuple_returns
@pytest.mark.uses_bool_field
def test_astype_on_tuples(cartesian_case):
@gtx.field_operator
def field_op_returning_a_tuple(
Expand Down Expand Up @@ -474,6 +475,7 @@ def cast_nested_tuple(
)


@pytest.mark.uses_bool_field
def test_astype_bool_field(cartesian_case):
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]:
Expand All @@ -485,6 +487,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]:
)


@pytest.mark.uses_bool_field
@pytest.mark.parametrize("inp", [0.0, 2.0])
def test_astype_bool_scalar(cartesian_case, inp):
@gtx.field_operator
Expand All @@ -510,6 +513,7 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], np.float32]:
)


@pytest.mark.uses_bool_field
@pytest.mark.uses_dynamic_offsets
def test_offset_field(cartesian_case):
ref = np.full(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField:


@pytest.mark.uses_tuple_returns
@pytest.mark.uses_bool_field
def test_conditional_nested_tuple(cartesian_case):
@gtx.field_operator
def conditional_nested_tuple(
Expand Down Expand Up @@ -382,6 +383,7 @@ def simple_broadcast(inp: cases.IField) -> cases.IJField:
)


@pytest.mark.uses_bool_field
def test_conditional(cartesian_case):
@gtx.field_operator
def conditional(
Expand All @@ -406,6 +408,7 @@ def conditional(
)


@pytest.mark.uses_bool_field
def test_conditional_promotion(cartesian_case):
@gtx.field_operator
def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases.IFloatField:
Expand All @@ -431,6 +434,7 @@ def conditional_promotion(a: cases.IFloatField) -> cases.IFloatField:


@pytest.mark.uses_cartesian_shift
@pytest.mark.uses_bool_field
def test_conditional_shifted(cartesian_case):
@gtx.field_operator
def conditional_shifted(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def mod_fieldop(inp1: cases.IField) -> cases.IField:
cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2)


@pytest.mark.uses_bool_field
def test_bit_xor(cartesian_case):
@gtx.field_operator
def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField:
Expand All @@ -93,6 +94,7 @@ def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolFie
cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1 ^ inp2)


@pytest.mark.uses_bool_field
def test_bit_and(cartesian_case):
@gtx.field_operator
def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField:
Expand All @@ -105,6 +107,7 @@ def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField:
cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1 & inp2)


@pytest.mark.uses_bool_field
def test_bit_or(cartesian_case):
@gtx.field_operator
def bit_or(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField:
Expand All @@ -128,6 +131,7 @@ def uneg(inp: cases.IField) -> cases.IField:
cases.verify_with_default_data(cartesian_case, uneg, ref=lambda inp1: -inp1)


@pytest.mark.uses_bool_field
def test_unary_neg_float_conversion(cartesian_case):
@gtx.field_operator
def uneg_float() -> cases.IFloatField:
Expand All @@ -140,6 +144,7 @@ def uneg_float() -> cases.IFloatField:
cases.verify(cartesian_case, uneg_float, out=out, ref=ref)


@pytest.mark.uses_bool_field
def test_unary_neg_bool_conversion(cartesian_case):
@gtx.field_operator
def uneg_bool() -> cases.IBoolField:
Expand All @@ -152,6 +157,7 @@ def uneg_bool() -> cases.IBoolField:
cases.verify(cartesian_case, uneg_bool, out=out, ref=ref)


@pytest.mark.uses_bool_field
def test_unary_invert(cartesian_case):
@gtx.field_operator
def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField:
Expand Down Expand Up @@ -212,6 +218,7 @@ def roots_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> cases.IFl
)


@pytest.mark.uses_bool_field
def test_is_values(cartesian_case):
@gtx.field_operator
def is_isinf_fieldop(inp1: cases.IFloatField) -> cases.IBoolField:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ class setup:


@pytest.mark.uses_tuple_returns
@pytest.mark.uses_bool_field
@pytest.mark.uses_scan_requiring_projector
def test_solve_nonhydro_stencil_52_like_z_q(test_setup):
cases.verify(
Expand All @@ -226,6 +227,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup):


@pytest.mark.uses_tuple_returns
@pytest.mark.uses_bool_field
def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup):
if (
test_setup.case.backend
Expand Down Expand Up @@ -253,6 +255,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup):


@pytest.mark.uses_tuple_returns
@pytest.mark.uses_bool_field
def test_solve_nonhydro_stencil_52_like(test_setup):
if (
test_setup.case.backend
Expand All @@ -274,6 +277,7 @@ def test_solve_nonhydro_stencil_52_like(test_setup):
assert np.allclose(test_setup.w_ref, test_setup.w.asnumpy())


@pytest.mark.uses_bool_field
@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup):
if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_ffront_lap(cartesian_case):
in_field,
out_field,
inout=out_field[1:-1, 1:-1],
ref=lap_ref(in_field.ndarray),
ref=lap_ref(in_field.asnumpy()),
)


Expand Down Expand Up @@ -125,5 +125,5 @@ def test_ffront_laplap(cartesian_case):
in_field,
out_field,
inout=out_field[2:-2, 2:-2],
ref=lap_ref(lap_ref(in_field.array_ns.asarray(in_field.ndarray))),
ref=lap_ref(lap_ref(in_field.as_numpy())),
)
Loading