Skip to content

Commit

Permalink
Merge pull request #26377 from mattjj:maintain-mutable-array-sharding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 724405629
  • Loading branch information
Google-ML-Automation committed Feb 7, 2025
2 parents ec47763 + 719031c commit c0ba362
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
7 changes: 3 additions & 4 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,7 @@ def __init__(self, aval, buf):
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
sharding = property(lambda self: self._buf.sharding)
def __getitem__(self, idx): return self._aval._getitem(self, idx)
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
Expand All @@ -2091,10 +2092,8 @@ def mutable_array_abstract_eval(init_aval):
@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
aval = get_aval(init_val)
# TODO(mattjj): improve spelling of 'defensive copy' here, avoid circular dep
init_val = init_val.copy() if hasattr(init_val, 'copy') else init_val
return MutableArray(AbstractRef(aval), init_val)
from jax._src.lax.lax import _array_copy # pytype: disable=import-error
return MutableArray(AbstractRef(get_aval(init_val)), _array_copy(init_val))

def freeze(ref):
return freeze_p.bind(ref)
Expand Down
9 changes: 4 additions & 5 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@
from jax._src.sharding import Sharding as JSharding
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
UnspecifiedValue, get_array_mapping as _get_array_mapping,
array_mapping_to_axis_resources,
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue,
get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources,
SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding)
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
tuple_update, tuple_delete, distributed_debug_log,
Expand Down Expand Up @@ -2140,8 +2139,8 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
donated_invars, out_shardings, out_layouts):
if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
in_shardings = (*in_shardings, *(c.sharding for c in mut.in_mut))
in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut) # TODO(mattjj)
donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
out_layouts_ = iter(zip(out_shardings, out_layouts))
out_shardings, out_layouts = unzip2(
Expand Down
13 changes: 13 additions & 0 deletions tests/mutable_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jax._src import core
from jax._src import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P
import jax.numpy as jnp

from jax._src.state.types import (RefEffect)
Expand Down Expand Up @@ -241,6 +242,18 @@ def test_defensive_copy(self):
_ = jax.jit(lambda x_ref: x_ref[...])(core.mutable_array(x))
x + 1 # don't crash

def test_sharding_persists(self):
mesh = jax.make_mesh((1,), ('i',))
x = jax.device_put(jnp.arange(2), NamedSharding(mesh, P('i')))
s = x.sharding
a = core.mutable_array(x)
self.assertEqual(s, a.sharding)
self.assertEqual(s, a[...].sharding)
f = jax.jit(lambda: a[...])
y = f()
self.assertEqual(s, a.sharding)
self.assertEqual(s, y.sharding)


@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):
Expand Down

0 comments on commit c0ba362

Please sign in to comment.