diff --git a/jax/_src/core.py b/jax/_src/core.py index 9ca86f00730e..e1b63821aba0 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2068,6 +2068,7 @@ def __init__(self, aval, buf): aval = property(lambda self: self._aval) shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) + sharding = property(lambda self: self._buf.sharding) def __getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) @@ -2091,10 +2092,8 @@ def mutable_array_abstract_eval(init_aval): @mutable_array_p.def_impl def _mutable_array_impl(init_val): from jax._src.state.types import AbstractRef # pytype: disable=import-error - aval = get_aval(init_val) - # TODO(mattjj): improve spelling of 'defensive copy' here, avoid circular dep - init_val = init_val.copy() if hasattr(init_val, 'copy') else init_val - return MutableArray(AbstractRef(aval), init_val) + from jax._src.lax.lax import _array_copy # pytype: disable=import-error + return MutableArray(AbstractRef(get_aval(init_val)), _array_copy(init_val)) def freeze(ref): return freeze_p.bind(ref) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 544bf1a07e6e..7cbeeb080d90 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -65,9 +65,8 @@ from jax._src.sharding import Sharding as JSharding from jax._src.mesh import AbstractMesh, Mesh from jax._src.sharding_impls import ( - ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED, - UnspecifiedValue, get_array_mapping as _get_array_mapping, - array_mapping_to_axis_resources, + ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue, + get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources, SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, @@ -2140,8 +2139,8 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings, out_layouts): if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr) - in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut) - in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut) + in_shardings = (*in_shardings, *(c.sharding for c in mut.in_mut)) + in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut) # TODO(mattjj) donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut) out_layouts_ = iter(zip(out_shardings, out_layouts)) out_shardings, out_layouts = unzip2( diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 81f846ec63d3..7a1d36317c5b 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -21,6 +21,7 @@ from jax._src import core from jax._src import config from jax._src import test_util as jtu +from jax.sharding import NamedSharding, PartitionSpec as P import jax.numpy as jnp from jax._src.state.types import (RefEffect) @@ -241,6 +242,18 @@ def test_defensive_copy(self): _ = jax.jit(lambda x_ref: x_ref[...])(core.mutable_array(x)) x + 1 # don't crash + def test_sharding_persists(self): + mesh = jax.make_mesh((1,), ('i',)) + x = jax.device_put(jnp.arange(2), NamedSharding(mesh, P('i'))) + s = x.sharding + a = core.mutable_array(x) + self.assertEqual(s, a.sharding) + self.assertEqual(s, a[...].sharding) + f = jax.jit(lambda: a[...]) + y = f() + self.assertEqual(s, a.sharding) + self.assertEqual(s, y.sharding) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase):