Merge pull request #135 from ROCm/ci-upstream-sync-13_1
CI: 11/08/24 upstream sync
charleshofer authored Nov 12, 2024
2 parents 9afbd23 + ced1e2b commit 0b970b8
Showing 33 changed files with 978 additions and 174 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
passing compilation options to XLA. For the moment it's undocumented and
may be in flux.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.

## jax 0.4.35 (Oct 22, 2024)

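Referring to the `jax.tree_util.register_dataclass` entry above, here is a minimal sketch of the inline-metadata style it describes. The class and field names are illustrative, and the `static` metadata key follows the pattern in the function's documentation:

```python
import dataclasses
import jax
import jax.numpy as jnp

@jax.tree_util.register_dataclass
@dataclasses.dataclass
class Point:
  x: jax.Array  # data field: becomes a pytree leaf
  y: jax.Array  # data field: becomes a pytree leaf
  # Declared inline as metadata: treated as static auxiliary data, not a leaf.
  units: str = dataclasses.field(metadata=dict(static=True))

p = Point(jnp.ones(3), jnp.zeros(3), "m")
print(jax.tree_util.tree_leaves(p))  # two leaves (x and y); `units` lives in the treedef
```
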
9 changes: 8 additions & 1 deletion docs/persistent_compilation_cache.md
@@ -1,11 +1,18 @@
# Persistent compilation cache

<!--* freshness: { reviewed: '2024-04-09' } *-->
<!--* freshness: { reviewed: '2024-11-07' } *-->

JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.

Note: if the compilation cache is not on a local filesystem,
[etils](https://pypi.org/project/etils/) needs to be installed.

```python
pip install etils
```

## Usage

### Quick start
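As a companion to the doc change above, here is a minimal sketch of enabling the cache along the lines of the Quick start section (the cache directory path is illustrative):

```python
import jax
import jax.numpy as jnp

# Point JAX at a cache directory; compiled programs are written there and
# reused by later processes that configure the same directory.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

@jax.jit
def f(x):
  return x * 2

f(jnp.arange(4))  # the first run compiles and populates the cache
```
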
2 changes: 1 addition & 1 deletion jax/_src/api.py
@@ -1599,7 +1599,7 @@ def cache_miss(*args, **kwargs):

cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple,
lambda x, s: pxla.shard_args([s], [None], [x])[0],
lambda x, s: pxla.shard_args([s], [None], [None], [x])[0],
pytree_registry=tree_util.default_registry)
_pmap_cache_clears.add(cpp_mapped_f)

23 changes: 16 additions & 7 deletions jax/_src/array.py
@@ -40,6 +40,7 @@
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding,
@@ -1110,7 +1111,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
# Look up all buffers that contain the correct slice of the logical array.
candidates_list = candidates[hashed_index(idx)]
if not candidates_list:
return pxla.shard_args([sharding], [None], [x._value],
return pxla.shard_args([sharding], [None], [None], [x._value],
canonicalize=False)[0]
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
@@ -1130,11 +1131,13 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
return dst_indices, tuple(src_indices) == tuple(dst_indices)


def _array_shard_arg(xs, shardings, layouts):
def _array_shard_arg(xs, shardings, layouts, copy_semantics):
results = []
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
batch_cs = []

for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
for i, (x, sharding, layout, cs) in enumerate(
safe_zip(xs, shardings, layouts, copy_semantics)):
x._check_if_deleted()
indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
same_layout = (True if layout is None else
@@ -1156,6 +1159,7 @@ def _array_shard_arg(xs, shardings, layouts):
batch_devs.append(list(devices))
batch_shardings.append(sharding)
batch_indices.append(i)
batch_cs.append(cs)
# Resharding starts here:
elif not same_layout:
results.append(api.device_put(x, Layout(layout, sharding)))
@@ -1165,8 +1169,12 @@ def _array_shard_arg(xs, shardings, layouts):
results.append(
shard_sharded_device_array_slow_path(x, devices, indices, sharding))

copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings)
if xla_extension_version >= 296:
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings, batch_cs)
else:
copy_outs = xc.batched_copy_array_to_devices_with_sharding( # type: ignore
batch_xs, batch_devs, batch_shardings)
for i, copy_out in safe_zip(batch_indices, copy_outs):
assert results[i] is None
results[i] = copy_out
@@ -1200,8 +1208,9 @@ def _array_local_result_handler(aval, sharding, indices):

# Token handlers

def _token_shard_arg(xs, shardings, layouts):
return _array_shard_arg([x._buf for x in xs], shardings, layouts)
def _token_shard_arg(xs, shardings, layouts, copy_semantics):
return _array_shard_arg([x._buf for x in xs], shardings, layouts,
copy_semantics)
pxla.shard_arg_handlers[core.Token] = _token_shard_arg


19 changes: 13 additions & 6 deletions jax/_src/dispatch.py
@@ -137,7 +137,7 @@ def get_token_input(
# We only use replicated sharding for the first time when the token for the
# order effect hasn't been created.
s = jax.sharding.GSPMDSharding.get_replicated(devices)
sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0])
sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0])
self.current_tokens[eff] = sharded_tok
return sharded_tok

@@ -391,6 +391,7 @@ class _DeferredShardArg:
s: Sharding
aval: core.AbstractValue
committed: bool
copy_semantics: CopySemantics

@property
def result_handler(self):
@@ -435,24 +436,27 @@ def _device_put_sharding_impl(x, aval, device, copy):
"device_put's second argument must be a Device or a Sharding which"
f" represents addressable devices, but got {s}. Please pass device or"
" Sharding which represents addressable devices.")
return _DeferredShardArg(x, s, aval, True)
return _DeferredShardArg(x, s, aval, True, copy)

# Only `Device` exists below. `Sharding` instance is handled above.
if isinstance(x, array.ArrayImpl):
if not x.is_fully_addressable:
raise ValueError(
"device_put's first argument must be a fully addressable array, but "
f"got value with devices {x.devices()}")
if device is None and copy == CopySemantics.ALIAS:
return x
if device is None:
if copy == CopySemantics.ALIAS:
return x
else:
return _DeferredShardArg(x, x.sharding, aval, x.committed, copy)
elif is_single_device_sharding(x.sharding):
device = x.sharding._device_assignment[0] if device is None else device
return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
[device])

sh = SingleDeviceSharding(pxla._get_default_device()
if device is None else device)
return _DeferredShardArg(x, sh, aval, device is not None)
return _DeferredShardArg(x, sh, aval, device is not None, copy)


def _device_put_impl(
@@ -501,12 +505,14 @@ def _batched_device_put_impl(
copy_semantics: Sequence[CopySemantics]):
ys = []
shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], []
shard_arg_copy_semantics = []
for i, (x, device, src, cp) in enumerate(zip(xs, devices, srcs, copy_semantics)):
y = _device_put_impl(x, device=device, src=src, copy=cp)
if isinstance(y, _DeferredShardArg):
shard_arg_indices.append(i)
shard_arg_xs.append(y.x)
shard_arg_shardings.append(y.s)
shard_arg_copy_semantics.append(y.copy_semantics)
ys.append(y)

if shard_arg_xs:
@@ -515,7 +521,8 @@ def _batched_device_put_impl(
# device_put handles `Layout` via a different path, so just pass `None` as
# the layout here.
shard_arg_results = pxla.shard_args(
shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs)
shard_arg_shardings, [None] * len(shard_arg_xs),
shard_arg_copy_semantics, shard_arg_xs)
for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
assert isinstance(ys[i], _DeferredShardArg)
ys[i] = ys[i].result_handler(shard_arg_result)
4 changes: 2 additions & 2 deletions jax/_src/earray.py
@@ -108,12 +108,12 @@ def global_shards(self):

# TODO(mattjj): _set_array_base_attributes

def _earray_shard_arg_handler(xs, shardings, layouts):
def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
arrs = [x._data for x in xs]
phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding)
for x, sharding in zip(xs, shardings)]
# TODO(yashkatariya): `layouts` should be converted to physical layouts.
return pxla.shard_args(phys_shardings, layouts, arrs)
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
63 changes: 45 additions & 18 deletions jax/_src/interpreters/pxla.py
@@ -61,6 +61,7 @@
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@@ -105,44 +106,69 @@ class WeakRefList(list):

### util


def to_xc_copy_semantics(copy_semantics):
if xla_extension_version < 296:
return [None] * len(copy_semantics)
out = []
for cs in copy_semantics:
if cs is None or cs == dispatch.CopySemantics.ALIAS:
out.append(xc.ArrayCopySemantics.REUSE_INPUT)
elif cs == dispatch.CopySemantics.COPY:
out.append(xc.ArrayCopySemantics.ALWAYS_COPY)
elif cs == dispatch.CopySemantics.DONATE:
out.append(xc.ArrayCopySemantics.DONATE_INPUT)
else:
assert isinstance(cs, xc.ArrayCopySemantics)
out.append(cs)
return out


def identity(x): return x

@profiler.annotate_function
def shard_args(shardings: Sequence[JSharding], layouts, args,
canonicalize=True) -> Sequence[xc.ArrayImpl]:
def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics,
args, canonicalize=True) -> Sequence[xc.ArrayImpl]:
xc_copy_semantics = to_xc_copy_semantics(copy_semantics)
del copy_semantics
# Fast path for one argument.
if len(args) == 1:
arg = args[0]
if canonicalize:
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)]([arg], shardings, layouts)

# type(arg) -> (list[indices], list[args], list[shardings])
batches = collections.defaultdict(lambda: ([], [], [], [])) # type: ignore
for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)):
return shard_arg_handlers[type(arg)]([arg], shardings, layouts,
xc_copy_semantics)

# type(arg) -> (list[indices], list[args], list[shardings], list[layouts],
# list[copy_semantics])
batches = collections.defaultdict(lambda: ([], [], [], [], [])) # type: ignore
for i, (arg, sharding, layout, cs) in enumerate(
safe_zip(args, shardings, layouts, xc_copy_semantics)):
if canonicalize:
arg = xla.canonicalize_dtype(arg)
batch = batches[type(arg)]
batch[0].append(i)
batch[1].append(arg)
batch[2].append(sharding)
batch[3].append(layout)
batch[4].append(cs)

# Call `shard_arg_handlers` per batch and build a flat list of arrays returned
# from each call in the same order as `args`. Since `batches` is grouped by
# types, we cannot simply flatten the results and we have to use the original
# indices to put each array back to its original position.
results: list[jax.Array | None] = [None] * len(args)
for t, (indices, a, s, l) in batches.items():
outs = shard_arg_handlers[t](a, s, l)
for t, (indices, a, s, l, cs) in batches.items():
outs = shard_arg_handlers[t](a, s, l, cs)
for i, out in safe_zip(indices, outs):
results[i] = out
assert all(result is not None for result in results)
return results


shard_arg_handlers: dict[
Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]]
Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any], Sequence[Any]],
Sequence[Any]]
] = {}


@@ -172,12 +198,12 @@ def is_default_layout(curr_layout, sharding, aval):
raise


def _masked_array_error(xs, shardings, layouts):
def _masked_array_error(xs, shardings, layouts, copy_semantics):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error

def _shard_np_array(xs, shardings, layouts):
def _shard_np_array(xs, shardings, layouts, copy_semantics):
results = []
for x, sharding, layout in safe_zip(xs, shardings, layouts):
devices = sharding._addressable_device_assignment
Expand All @@ -197,12 +223,12 @@ def _shard_np_array(xs, shardings, layouts):
for _t in array_types:
shard_arg_handlers[_t] = _shard_np_array

def _shard_darray(xs, shardings, layouts):
return shard_args(shardings, layouts, [x._data for x in xs])
def _shard_darray(xs, shardings, layouts, copy_semantics):
return shard_args(shardings, layouts, copy_semantics, [x._data for x in xs])
shard_arg_handlers[core.DArray] = _shard_darray

def _shard_mutable_array(xs, shardings, layouts):
return shard_args(shardings, layouts, [x._buf for x in xs])
def _shard_mutable_array(xs, shardings, layouts, copy_semantics):
return shard_args(shardings, layouts, copy_semantics, [x._buf for x in xs])
shard_arg_handlers[core.MutableArray] = _shard_mutable_array

def batched_device_put(aval: core.ShapedArray,
@@ -1135,7 +1161,8 @@ class InputsHandler:

def __init__(self, in_shardings, in_layouts, local_devices=None,
input_indices=None):
self.handler = partial(shard_args, in_shardings, in_layouts)
self.handler = partial(shard_args, in_shardings, in_layouts,
[None] * len(in_shardings))
self.in_shardings = in_shardings
self.in_layouts = in_layouts
self.local_devices = local_devices
@@ -3047,7 +3074,7 @@ def aot_cache_miss(*args, **kwargs):
JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)

def cc_shard_arg(x, sharding, layout):
return shard_args([sharding], [layout], [x])[0]
return shard_args([sharding], [layout], [None], [x])[0]


def check_arg_avals_for_call(ref_avals, arg_avals,
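To summarize the signature change threaded through the hunks above: every `shard_arg_handlers` entry now accepts a parallel `copy_semantics` sequence as a fourth argument and forwards it to `pxla.shard_args`. A minimal sketch mirroring `_shard_darray`, using a hypothetical wrapper type (`MyArray` and its `_data` attribute are illustrative, and this relies on the private `jax._src` API shown in the diff):

```python
from jax._src.interpreters import pxla

class MyArray:
  # Hypothetical container type, only to illustrate the handler signature.
  def __init__(self, data):
    self._data = data

def _shard_my_array(xs, shardings, layouts, copy_semantics):
  # The fourth positional argument is new in this change; forward it unchanged.
  return pxla.shard_args(shardings, layouts, copy_semantics,
                         [x._data for x in xs])

pxla.shard_arg_handlers[MyArray] = _shard_my_array
```
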
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
@@ -704,7 +704,7 @@ def _maybe_put(x):
aval = shaped_abstractify(x)
s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0])
result_handler = pxla.global_aval_to_result_handler(aval, s, False)
return result_handler(pxla.shard_args([s], [None], [x]))
return result_handler(pxla.shard_args([s], [None], [None], [x]))
else:
return x

