Skip to content

Commit

Permalink
Replace device="cpu" with a special object in numpy.array_api
Browse files Browse the repository at this point in the history
This way, it does not appear that "cpu" is a portable device object across
different array API compatible libraries. See
data-apis/array-api#626.

Original NumPy Commit: 3b20ad9c5ead16282c530cf48737aa3768a77f91
  • Loading branch information
asmeurer committed Dec 11, 2023
1 parent 136bdd7 commit 1a13e76
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 45 deletions.
10 changes: 8 additions & 2 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@

import numpy as np

# Placeholder object to represent the "cpu" device (the only device NumPy
# supports).
class _cpu_device:
def __repr__(self):
return "CPU_DEVICE"
CPU_DEVICE = _cpu_device()

class Array:
"""
Expand Down Expand Up @@ -1067,7 +1073,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
if device == CPU_DEVICE:
return self
raise ValueError(f"Unsupported device {device!r}")

Expand All @@ -1082,7 +1088,7 @@ def dtype(self) -> Dtype:

@property
def device(self) -> Device:
return "cpu"
return CPU_DEVICE

# Note: mT is new in array API spec (see matrix_transpose)
@property
Expand Down
48 changes: 24 additions & 24 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def asarray(
"""
# _array_object imports in this file are inside the functions to avoid
# circular imports
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
if copy in (False, np._CopyMode.IF_NEEDED):
# Note: copy=False is not yet implemented in np.asarray
Expand Down Expand Up @@ -86,10 +86,10 @@ def arange(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))

Expand All @@ -105,10 +105,10 @@ def empty(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.empty(shape, dtype=dtype))

Expand All @@ -121,10 +121,10 @@ def empty_like(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.empty_like(x._array, dtype=dtype))

Expand All @@ -143,10 +143,10 @@ def eye(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))

Expand All @@ -169,10 +169,10 @@ def full(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
if isinstance(fill_value, Array) and fill_value.ndim == 0:
fill_value = fill_value._array
Expand All @@ -197,10 +197,10 @@ def full_like(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
res = np.full_like(x._array, fill_value, dtype=dtype)
if res.dtype not in _all_dtypes:
Expand All @@ -225,10 +225,10 @@ def linspace(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))

Expand Down Expand Up @@ -264,10 +264,10 @@ def ones(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.ones(shape, dtype=dtype))

Expand All @@ -280,10 +280,10 @@ def ones_like(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.ones_like(x._array, dtype=dtype))

Expand Down Expand Up @@ -327,10 +327,10 @@ def zeros(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.zeros(shape, dtype=dtype))

Expand All @@ -343,9 +343,9 @@ def zeros_like(
See its docstring for more information.
"""
from ._array_object import Array
from ._array_object import Array, CPU_DEVICE

_check_valid_dtype(dtype)
if device not in ["cpu", None]:
if device not in [CPU_DEVICE, None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.zeros_like(x._array, dtype=dtype))
4 changes: 2 additions & 2 deletions array_api_strict/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Protocol,
)

from ._array_object import Array
from ._array_object import Array, CPU_DEVICE
from numpy import (
dtype,
int8,
Expand All @@ -50,7 +50,7 @@ class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...

Device = Literal["cpu"]
Device = type(CPU_DEVICE)

Dtype = dtype[Union[
int8,
Expand Down
11 changes: 7 additions & 4 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from .. import ones, asarray, reshape, result_type, all, equal
from .._array_object import Array
from .._array_object import Array, CPU_DEVICE
from .._dtypes import (
_all_dtypes,
_boolean_dtypes,
Expand Down Expand Up @@ -311,12 +311,15 @@ def test_python_scalar_construtors():

def test_device_property():
a = ones((3, 4))
assert a.device == 'cpu'
assert a.device == CPU_DEVICE
assert a.device != 'cpu'

assert all(equal(a.to_device('cpu'), a))
assert all(equal(a.to_device(CPU_DEVICE), a))
assert_raises(ValueError, lambda: a.to_device('cpu'))
assert_raises(ValueError, lambda: a.to_device('gpu'))

assert all(equal(asarray(a, device='cpu'), a))
assert all(equal(asarray(a, device=CPU_DEVICE), a))
assert_raises(ValueError, lambda: asarray(a, device='cpu'))
assert_raises(ValueError, lambda: asarray(a, device='gpu'))

def test_array_properties():
Expand Down
38 changes: 25 additions & 13 deletions array_api_strict/tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
zeros_like,
)
from .._dtypes import float32, float64
from .._array_object import Array
from .._array_object import Array, CPU_DEVICE


def test_asarray_errors():
Expand All @@ -30,7 +30,8 @@ def test_asarray_errors():
# Preferably this would be OverflowError
# assert_raises(OverflowError, lambda: asarray([2**100]))
assert_raises(TypeError, lambda: asarray([2**100]))
asarray([1], device="cpu") # Doesn't error
asarray([1], device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: asarray([1], device="cpu"))
assert_raises(ValueError, lambda: asarray([1], device="gpu"))

assert_raises(ValueError, lambda: asarray([1], dtype=int))
Expand Down Expand Up @@ -58,77 +59,88 @@ def test_asarray_copy():


def test_arange_errors():
arange(1, device="cpu") # Doesn't error
arange(1, device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: arange(1, device="cpu"))
assert_raises(ValueError, lambda: arange(1, device="gpu"))
assert_raises(ValueError, lambda: arange(1, dtype=int))
assert_raises(ValueError, lambda: arange(1, dtype="i"))


def test_empty_errors():
empty((1,), device="cpu") # Doesn't error
empty((1,), device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: empty((1,), device="cpu"))
assert_raises(ValueError, lambda: empty((1,), device="gpu"))
assert_raises(ValueError, lambda: empty((1,), dtype=int))
assert_raises(ValueError, lambda: empty((1,), dtype="i"))


def test_empty_like_errors():
empty_like(asarray(1), device="cpu") # Doesn't error
empty_like(asarray(1), device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: empty_like(asarray(1), device="cpu"))
assert_raises(ValueError, lambda: empty_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: empty_like(asarray(1), dtype=int))
assert_raises(ValueError, lambda: empty_like(asarray(1), dtype="i"))


def test_eye_errors():
eye(1, device="cpu") # Doesn't error
eye(1, device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: eye(1, device="cpu"))
assert_raises(ValueError, lambda: eye(1, device="gpu"))
assert_raises(ValueError, lambda: eye(1, dtype=int))
assert_raises(ValueError, lambda: eye(1, dtype="i"))


def test_full_errors():
full((1,), 0, device="cpu") # Doesn't error
full((1,), 0, device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: full((1,), 0, device="cpu"))
assert_raises(ValueError, lambda: full((1,), 0, device="gpu"))
assert_raises(ValueError, lambda: full((1,), 0, dtype=int))
assert_raises(ValueError, lambda: full((1,), 0, dtype="i"))


def test_full_like_errors():
full_like(asarray(1), 0, device="cpu") # Doesn't error
full_like(asarray(1), 0, device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="cpu"))
assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu"))
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int))
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i"))


def test_linspace_errors():
linspace(0, 1, 10, device="cpu") # Doesn't error
linspace(0, 1, 10, device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: linspace(0, 1, 10, device="cpu"))
assert_raises(ValueError, lambda: linspace(0, 1, 10, device="gpu"))
assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype=float))
assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype="f"))


def test_ones_errors():
ones((1,), device="cpu") # Doesn't error
ones((1,), device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: ones((1,), device="cpu"))
assert_raises(ValueError, lambda: ones((1,), device="gpu"))
assert_raises(ValueError, lambda: ones((1,), dtype=int))
assert_raises(ValueError, lambda: ones((1,), dtype="i"))


def test_ones_like_errors():
ones_like(asarray(1), device="cpu") # Doesn't error
ones_like(asarray(1), device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: ones_like(asarray(1), device="cpu"))
assert_raises(ValueError, lambda: ones_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: ones_like(asarray(1), dtype=int))
assert_raises(ValueError, lambda: ones_like(asarray(1), dtype="i"))


def test_zeros_errors():
zeros((1,), device="cpu") # Doesn't error
zeros((1,), device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: zeros((1,), device="cpu"))
assert_raises(ValueError, lambda: zeros((1,), device="gpu"))
assert_raises(ValueError, lambda: zeros((1,), dtype=int))
assert_raises(ValueError, lambda: zeros((1,), dtype="i"))


def test_zeros_like_errors():
zeros_like(asarray(1), device="cpu") # Doesn't error
zeros_like(asarray(1), device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: zeros_like(asarray(1), device="cpu"))
assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))
Expand Down

0 comments on commit 1a13e76

Please sign in to comment.