
Creating a mutable array does not preserve its sharding #26338

Closed
ayaka14732 opened this issue Feb 5, 2025 · 3 comments · Fixed by #26377
Labels: bug (Something isn't working)


@ayaka14732
Member

Description

Repro:

```python
import os
# Run on CPU and fake 4 host devices so a 2x2 mesh can be built.
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.core import mutable_array
from jax._src.state.primitives import ref_swap

# 2x2 device mesh; shard both array axes across it.
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))

# Build a 2x2 int32 array laid out according to `sharding`.
a = jnp.zeros_like(mesh.device_ids, dtype=jnp.int32)
a = jax.make_array_from_callback(a.shape, sharding, lambda idx: a[idx])
print(a[...].sharding)

# Wrap it in a mutable array and read the whole value back.
a_ref = mutable_array(a)
print(a_ref[...].sharding)
```

Expected output: both lines should print the same NamedSharding.

Actual output:

```
NamedSharding(mesh=Mesh('i': 2, 'j': 2), spec=PartitionSpec('i', 'j'), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
```
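The expected behavior above can be phrased as a small regression-style check. This is a minimal sketch, assuming the same 4-device CPU setup as the repro; `same_named_sharding` is a hypothetical helper (not a JAX API), and `mutable_array` is imported from the private `jax._src.core` module exactly as in the repro, so it may move or change between JAX versions.

```python
import os
# Same CPU setup as the repro: 4 fake host devices for a 2x2 mesh.
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.core import mutable_array  # private API, as in the repro


def same_named_sharding(x, y):
    """Hypothetical helper: True iff both arrays carry equal NamedShardings."""
    return (
        isinstance(x.sharding, NamedSharding)
        and isinstance(y.sharding, NamedSharding)
        and x.sharding == y.sharding
    )


devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))
a = jax.device_put(jnp.zeros((2, 2), dtype=jnp.int32), sharding)

a_ref = mutable_array(a)
# Per the issue's actual output, this is False on jax 0.5.0 because the
# value read back from the ref has a SingleDeviceSharding.
preserved = same_named_sharding(a, a_ref[...])
print(preserved)
```

The check deliberately requires a `NamedSharding` on both sides, matching the expectation stated in the issue, rather than accepting any sharding that happens to compare equal.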

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.13.0rc3 (main, Oct  2 2024, 17:18:08) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.10.11-1rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.11-1rodete2 (2024-10-16)', machine='x86_64')
```
ayaka14732 added the bug label on Feb 5, 2025
@ayaka14732
Member Author

Note that it works under jit:

```python
import os
# Same CPU setup as above: 4 fake host devices.
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.core import mutable_array

@jax.jit
def f(a):
    # Sharding of the input as seen inside jit.
    jax.debug.visualize_array_sharding(a)

    # Wrap it in a mutable array inside jit and read the value back.
    a_ref = mutable_array(a)
    jax.debug.visualize_array_sharding(a_ref[...])

devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))

a = jnp.zeros_like(mesh.device_ids, dtype=jnp.int32)
a = jax.make_array_from_callback(a.shape, sharding, lambda idx: a[idx])

f(a)
```

Output:

```
┌──────────┬──────────┐
│          │          │
│  CPU 0   │  CPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  CPU 2   │  CPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘
┌──────────┬──────────┐
│          │          │
│  CPU 0   │  CPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  CPU 2   │  CPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘
```

@ayaka14732
Member Author

But should we allow mutable_array inside jit?

Related: #26349

@mattjj
Collaborator

mattjj commented Feb 6, 2025

Using mutable_array inside a jit is fine; see the mutable array tests. We just can't return one from a jit.
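A minimal sketch of the pattern this comment describes, assuming the private `jax._src.core.mutable_array` import used throughout this issue (jax 0.5.0): the mutable array is created and updated inside a jitted function, and only a plain array read from it is returned, never the ref itself. `increment_all` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax._src.core import mutable_array  # private API, as in the issue


@jax.jit
def increment_all(x):
    # Create the mutable array inside the jitted function.
    x_ref = mutable_array(x)
    # Read the whole buffer, add one, and write it back through the ref.
    x_ref[...] = x_ref[...] + 1
    # Return a plain (immutable) array read from the ref, not the ref itself.
    return x_ref[...]


out = increment_all(jnp.zeros(3, dtype=jnp.int32))
print(out)  # [1 1 1]
```

Returning `x_ref` instead of `x_ref[...]` is the case the comment rules out.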
