Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Support for Array Api namespace as allocator #1771

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ classifiers = [
'Topic :: Scientific/Engineering :: Physics'
]
dependencies = [
"array-api-compat>=1.9.1;python_version>='3.10'",
"astunparse>=1.6.3;python_version<'3.9'",
'attrs>=21.3',
'black>=22.3',
Expand Down Expand Up @@ -237,8 +238,9 @@ module = 'gt4py.next.iterator.runtime'
[tool.pytest.ini_options]
markers = [
'all: special marker that skips all tests',
'requires_atlas: tests that require `atlas4py` bindings package',
'requires_dace: tests that require `dace` package',
'requires_atlas: tests that require the `atlas4py` bindings package',
'requires_dace: tests that require the `dace` package',
'requires_jax: tests that require the `jax` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_constant_fields: tests that require backend support for constant fields',
Expand All @@ -264,6 +266,7 @@ markers = [
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_max_over: tests that use the max_over builtin',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'slices_out_argument: tests that slice the out argument in a field_operator call',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
Expand Down
50 changes: 49 additions & 1 deletion src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Iterator,
Literal,
Protocol,
Self,
Sequence,
Tuple,
Type,
Expand Down Expand Up @@ -405,6 +406,7 @@ class DeviceType(enum.IntEnum):
MetalDeviceTyping,
VPIDeviceTyping,
ROCMDeviceTyping,
covariant=True,
)


Expand Down Expand Up @@ -454,7 +456,7 @@ def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ...

def any(self) -> bool: ...

def __getitem__(self, item: Any) -> NDArrayObject: ...
def __getitem__(self, item: Any) -> Self: ...

def __abs__(self) -> NDArrayObject: ...

Expand Down Expand Up @@ -505,3 +507,49 @@ def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...
def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...

def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ...


class MutableNDArrayObject(NDArrayObject, Protocol):
def __setitem__(self, index: Any, value: Any) -> None: ...


class ArrayApiNamespace(Protocol):
def empty(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def zeros(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def ones(self, shape: Sequence[int], *, dtype: Any = None, device: Any = None) -> Any: ...
def full(
self, shape: Sequence[int], fill_value: Scalar, *, dtype: Any = None, device: Any = None
) -> Any: ...
def asarray(self, obj: Any, *, dtype: Any = None, copy: Any = None) -> Any: ...

# @property # once all relevant implementations have this attribute
# def __array_api_version__(self) -> str: ... # noqa: ERA001

# TODO(havogt): add relevant methods and attributes or wait for the standard to provide it, see e.g. https://github.com/data-apis/array-api/issues/697


def is_array_api_namespace(obj: Any) -> TypeGuard[ArrayApiNamespace]:
# return hasattr(obj, "__array_api_version__") # noqa: ERA001 # once all relevant implementations have this attribute
return (
hasattr(obj, "empty")
and hasattr(obj, "zeros")
and hasattr(obj, "ones")
and hasattr(obj, "full")
and hasattr(obj, "asarray")
)


def to_array_api_dtype(xp: ArrayApiNamespace, dtype_: DTypeLike | None) -> Any:
"""
Converts a GT4Py `DTypeLike` to the dtype object of the given Array API namespace.

Note: For convenience `None` is passed-through as it has a consistent meaning in all Array API implementations.
"""
if dtype_ is None:
return None
else:
dtype_ = dtype(dtype_)
assert (
dtype_.tensor_shape == ()
) # TODO(havogt): support tensor shapes (or remove from our DType)
return getattr(xp, dtype_.scalar_type.__name__)
Loading
Loading