Description

Repro:
```python
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.core import mutable_array
from jax._src.state.primitives import ref_swap

devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))

a = jnp.zeros_like(mesh.device_ids, dtype=jnp.int32)
a = jax.make_array_from_callback(a.shape, sharding, lambda idx: a[idx])
print(a[...].sharding)

a_ref = mutable_array(a)
print(a_ref[...].sharding)
```
Expected output: both printed lines should show the same NamedSharding.
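That is, presumably two identical lines like the following (inferred from the first line of the actual output below):

```
NamedSharding(mesh=Mesh('i': 2, 'j': 2), spec=PartitionSpec('i', 'j'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('i': 2, 'j': 2), spec=PartitionSpec('i', 'j'), memory_kind=unpinned_host)
```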
Actual output:
```
NamedSharding(mesh=Mesh('i': 2, 'j': 2), spec=PartitionSpec('i', 'j'), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
```
System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.13.0rc3 (main, Oct 2 2024, 17:18:08) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.10.11-1rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.11-1rodete2 (2024-10-16)', machine='x86_64')
```
Note that it works under jit:
```python
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.core import mutable_array

@jax.jit
def f(a):
    jax.debug.visualize_array_sharding(a)
    a_ref = mutable_array(a)
    jax.debug.visualize_array_sharding(a_ref[...])

devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))

a = jnp.zeros_like(mesh.device_ids, dtype=jnp.int32)
a = jax.make_array_from_callback(a.shape, sharding, lambda idx: a[idx])
f(a)
```
Output:
```
┌──────────┬──────────┐
│          │          │
│  CPU 0   │  CPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  CPU 2   │  CPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘
┌──────────┬──────────┐
│          │          │
│  CPU 0   │  CPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  CPU 2   │  CPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘
```
But should we allow mutable_array inside jit?
Related: #26349
Using mutable_array inside a jit is fine; see the mutable array tests. We just can't return one from a jit.
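For context, a minimal sketch of that supported pattern, using a hypothetical `accumulate` function (not from the issue): the ref is created and mutated inside the jitted function, and only its value, not the ref itself, is returned.

```python
import jax
import jax.numpy as jnp
from jax._src.core import mutable_array

@jax.jit
def accumulate(x):
    # Create a ref inside the traced function and update it in place.
    acc = mutable_array(jnp.zeros_like(x))
    acc[...] = acc[...] + x
    acc[...] = acc[...] + x
    # Return the ref's value; returning `acc` (the ref itself) is what's unsupported.
    return acc[...]

print(accumulate(jnp.arange(4.)))  # expected: [0. 2. 4. 6.]
```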