Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow python scalars in result_type: 2024.12 revision #119

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions array_api_strict/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def isdtype(
else:
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")

def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.

Expand All @@ -208,19 +208,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
# too many extra type promotions like int64 + uint64 -> float64, and does
# value-based casting on scalar arrays.
A = []
scalars = []
for a in arrays_and_dtypes:
if isinstance(a, Array):
a = a.dtype
elif isinstance(a, (bool, int, float, complex)):
scalars.append(a)
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
A.append(a)

# remove python scalars
A = [a for a in A if not isinstance(a, (bool, int, float, complex))]

if len(A) == 0:
raise ValueError("at least one array or dtype is required")
elif len(A) == 1:
return A[0]
result = A[0]
else:
t = A[0]
for t2 in A[1:]:
t = _result_type(t, t2)
return t
result = t

if len(scalars) == 0:
return result

if get_array_api_strict_flags()['api_version'] <= '2023.12':
raise TypeError("result_type() inputs must be array_api arrays or dtypes")

# promote python scalars given the result_type for all arrays/dtypes
from ._creation_functions import empty
arr = empty(1, dtype=result)
for s in scalars:
x = arr._promote_scalar(s)
result = _result_type(x.dtype, result)

return result
23 changes: 21 additions & 2 deletions array_api_strict/tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import numpy as np

from .._creation_functions import asarray
from .._data_type_functions import astype, can_cast, isdtype
from .._data_type_functions import astype, can_cast, isdtype, result_type
from .._dtypes import (
bool, int8, int16, uint8, float64,
bool, int8, int16, uint8, float64, int64
)
from .._flags import set_array_api_strict_flags

Expand Down Expand Up @@ -70,3 +70,22 @@ def astype_device(api_version):
else:
pytest.raises(TypeError, lambda: astype(a, int8, device=None))
pytest.raises(TypeError, lambda: astype(a, int8, device=a.device))


@pytest.mark.parametrize("api_version", ['2023.12', '2024.12'])
def test_result_type_py_scalars(api_version):
if api_version <= '2023.12':
set_array_api_strict_flags(api_version=api_version)

with pytest.raises(TypeError):
result_type(int16, 3)
else:
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version=api_version)

assert result_type(int8, 3) == int8
assert result_type(uint8, 3) == uint8
assert result_type(float64, 3) == float64

with pytest.raises(TypeError):
result_type(int64, True)
2 changes: 1 addition & 1 deletion array_api_strict/tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def test_concat_errors():
assert_raises(TypeError, lambda: concat((1, 1), axis=None))
assert_raises((TypeError, ValueError), lambda: concat((1, 1), axis=None))
assert_raises(TypeError, lambda: concat([asarray([1], dtype=int8),
asarray([1], dtype=float64)]))

Expand Down
2 changes: 1 addition & 1 deletion array_api_strict/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def p(func: Callable, *args, **kwargs):
[
p(xp.can_cast, 42, xp.int8),
p(xp.can_cast, xp.int8, 42),
p(xp.result_type, 42),
p(xp.result_type, "42"),
],
)
def test_raises_on_invalid_types(func, args, kwargs):
Expand Down
Loading