Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jnp.tile for array API 2023 #20943

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ Remember to align the itemized text with the first line of an item within a list
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
been removed after being deprecated in JAX v0.4.22. Instead use
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
* Passing `A` and `reps` to {func}`jax.numpy.tile` is now deprecated. Instead,
pass them as positional arguments only.

* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
Expand Down
64 changes: 56 additions & 8 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,20 +1897,68 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
)
return tuple(moveaxis(x, axis, 0))

@util.implements(np.tile)
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
util.check_arraylike("tile", A)
@util.implements(np.tile,
extra_params=_dedent("""
x : array_like
Array containing elements to tile.
repetitions : array_like
The number of repetitions along each axis (dimension).
""")
)
def tile(
x: ArrayLike | DeprecatedArg = DeprecatedArg(),
repetitions: DimSize | Sequence[DimSize] | DeprecatedArg = DeprecatedArg(),
/, A: DeprecatedArg = DeprecatedArg(), reps: DeprecatedArg = DeprecatedArg()
) -> Array:

# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
if not isinstance(A, DeprecatedArg):
if not isinstance(x, DeprecatedArg):
raise ValueError(
"Only one of `A` and `x` (the first positional argument) should be "
"provided."
)
warnings.warn(
"Passing the `A` argument to jnp.tile by keyword is deprecated. Pass the "
"value as the first positional argument instead.",
DeprecationWarning, stacklevel=2
)
x = A
del A
if not isinstance(reps, DeprecatedArg):
if not isinstance(repetitions, DeprecatedArg):
raise ValueError(
"Only one of `reps` and `repetitions` (the second positional argument) "
"should be provided."
)
warnings.warn(
"Passing the `reps` argument to jnp.tile by keyword is deprecated. Pass the "
"value as the second positional argument instead.",
DeprecationWarning, stacklevel=2
)
repetitions = reps
del reps
if isinstance(x, DeprecatedArg):
raise ValueError(
"jnp.tile is missing a required positional argument: `x`."
)
if isinstance(reps, DeprecatedArg):
raise ValueError(
"jnp.tile is missing a required positional argument: `repetitions`."
)

util.check_arraylike("tile", x)
try:
iter(reps) # type: ignore[arg-type]
iter(repetitions) # type: ignore[arg-type]
except TypeError:
reps_tup: tuple[DimSize, ...] = (reps,)
reps_tup: tuple[DimSize, ...] = (repetitions,)
else:
reps_tup = tuple(reps) # type: ignore[assignment,arg-type]
reps_tup = tuple(repetitions) # type: ignore[assignment,arg-type]
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
for rep in reps_tup)
A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A)
A_shape = (1,) * (len(reps_tup) - ndim(x)) + shape(x)
reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
result = broadcast_to(reshape(x, [j for i in A_shape for j in [1, i]]),
[k for pair in zip(reps_tup, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps_tup)))

Expand Down
5 changes: 4 additions & 1 deletion jax/_src/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,11 @@ def shape(self) -> Shape: ...
# We use a class for deprecated args to avoid using Any/object types which can
# introduce complications and mistakes in static analysis
class DeprecatedArg:
def __init__(self, msg="Deprecated"):
self.msg = msg
return
def __repr__(self):
return "Deprecated"
return self.msg

# Mirror of dlpack.h enum
class DLDeviceType(enum.IntEnum):
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
roll as roll,
squeeze as squeeze,
stack as stack,
tile as tile,
unstack as unstack,
)

Expand Down
6 changes: 6 additions & 0 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array
dtype = _result_type(*arrays)
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)


def tile(x: Array, repetitions: tuple[int], /) -> Array:
"""Constructs an array by tiling an input array."""
return jax.numpy.tile(x, repetitions)


def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
"""Splits an array in a sequence of arrays along the given axis."""
return jax.numpy.unstack(x, axis=axis)
7 changes: 6 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,12 @@ def tensordot(a: ArrayLike, b: ArrayLike,
axes: Union[int, Sequence[int], Sequence[Sequence[int]]] = ...,
*, precision: PrecisionLike = ...,
preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ...
def tile(
A: ArrayLike | DeprecatedArg = ...,
repetitions: Union[DimSize, Sequence[DimSize]] | DeprecatedArg = ...,
x: DeprecatedArg = ...,
reps: DeprecatedArg = ...,
) -> Array: ...
def trace(a: ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...,
dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ...
def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ...
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
'tan',
'tanh',
'tensordot',
'tile',
'tril',
'triu',
'trunc',
Expand Down
Loading