
Commit

Merge pull request #221 from ROCm/ci-upstream-sync-106_1
CI: 02/04/25 upstream sync
github-actions[bot] authored Feb 7, 2025
2 parents 0962b96 + c20bb5b commit 9fbc1c1
Showing 79 changed files with 1,449 additions and 832 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
@@ -96,7 +96,7 @@ jobs:
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 tests examples
pytest -n 4 --tb=short --maxfail=20 tests examples
documentation:
74 changes: 73 additions & 1 deletion .github/workflows/tsan.yaml
@@ -44,6 +44,11 @@ jobs:
repository: python/cpython
path: cpython
ref: "3.13"
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: numpy/numpy
path: numpy
submodules: true

- name: Restore cached CPython with TSAN
id: cache-cpython-tsan-restore
@@ -67,7 +72,7 @@
# Create archive to be used with bazel as hermetic python:
cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan
- name: Save CPython with TSAN
- name: Save TSAN CPython
id: cache-cpython-tsan-save
if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
@@ -76,6 +81,73 @@
./python-tsan.tgz
key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }}

- name: Get year & week number
id: get-date
run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT
shell: bash -l {0}

- name: Restore cached TSAN Numpy
id: cache-numpy-tsan-restore
uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: |
./wheelhouse
key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }}

- name: Build TSAN Numpy wheel
if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true'
run: |
cd numpy
# If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz
if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then
echo "Extract cpython from python-tsan.tgz"
pushd .
ls ${GITHUB_WORKSPACE}/python-tsan.tgz
cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz
ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/
popd
fi
export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH
python3 -m pip install -r requirements/build_requirements.txt
# Make sure to install a compatible Cython version (master branch is best for now)
python3 -m pip install -U git+https://github.com/cython/cython
CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized
# Create simple index and copy the wheel
mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/numpy
numpy_whl_name=($(cd dist && ls numpy*.whl))
if [ -z "${numpy_whl_name}" ]; then exit 1; fi
echo "Built TSAN Numpy wheel: ${numpy_whl_name}"
cp dist/${numpy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/numpy
cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html
<!DOCTYPE html><html><body>
<a href="numpy">numpy></a></br>
</body></html>
EOF
cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/numpy/index.html
<!DOCTYPE html><html><body>
<a href="${numpy_whl_name}">${numpy_whl_name}</a></br>
</body></html>
EOF
- name: Save TSAN Numpy wheel
id: cache-numpy-tsan-save
if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: |
./wheelhouse
key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }}

- name: Build Jax and run tests
timeout-minutes: 120
env:
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
decorator to support customizing the behavior of opaque functions under
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
* Added {func}`jax.random.multinomial`.

* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
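For readers trying out the new sampler mentioned above, here is a hedged sketch of how `jax.random.multinomial` might be called; the `(key, n, p)` argument order and the return value are assumptions based on JAX's usual sampler conventions, so check the rendered `jax.random` documentation for the exact signature and keyword arguments.

```python
# Hedged sketch of the new jax.random.multinomial; the (key, n, p) argument
# order is an assumption -- consult the jax.random docs for the exact signature.
import jax
import jax.numpy as jnp

key = jax.random.key(0)
probs = jnp.array([0.2, 0.3, 0.5])
counts = jax.random.multinomial(key, 10, probs)  # per-category counts summing to 10
print(counts)
```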
3 changes: 2 additions & 1 deletion build/test-requirements.txt
@@ -7,7 +7,8 @@ flatbuffers
hypothesis
mpmath>=1.3
pillow>=10.4.0
portpicker
# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
portpicker; python_version<"3.13"
pytest-xdist
wheel
rich
2 changes: 1 addition & 1 deletion docs/autodidax.ipynb
@@ -146,7 +146,7 @@
"around calls to `bind`. These wrappers let us control how arguments are passed\n",
"to `bind`, and in particular we follow a handy internal convention: when we\n",
"call `bind`, we pass values representing array data as positional arguments,\n",
"and we pass metadata like the `axis` argument to `sum_p` via keyword. This\n",
"and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This\n",
"calling convention simplifies some core logic (since e.g. instances of the\n",
"`Tracer` class to be defined below can only occur in positional arguments to\n",
"`bind`). The wrappers can also provide docstrings!\n",
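The change above (repeated in `docs/autodidax.md` and `docs/autodidax.py` below) updates this paragraph to refer to `reduce_sum_p`, matching the convention it describes: array data goes to `bind` positionally, while metadata such as `axis` is passed by keyword. A minimal self-contained sketch of that convention, with `Primitive` and `bind1` as simplified stand-ins for the tutorial's real definitions:

```python
# Sketch of the autodidax convention: array operands are positional arguments
# to bind, metadata like `axis` is passed by keyword. Primitive and bind1 are
# simplified stand-ins for the definitions in the tutorial.
import numpy as np

class Primitive:
    def __init__(self, name):
        self.name = name

reduce_sum_p = Primitive("reduce_sum")

def bind1(prim, *args, **params):
    # The real bind dispatches to the current interpreter; this stand-in just
    # evaluates the primitive directly for illustration.
    if prim is reduce_sum_p:
        (x,) = args
        return np.sum(x, axis=params.get("axis"))
    raise NotImplementedError(prim.name)

def reduce_sum(x, axis=None):
    return bind1(reduce_sum_p, x, axis=axis)  # data positional, axis by keyword

print(reduce_sum(np.arange(6).reshape(2, 3), axis=0))  # [3 5 7]
```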
2 changes: 1 addition & 1 deletion docs/autodidax.md
@@ -133,7 +133,7 @@ The functions that user code calls, like `add` and `sin`, are just wrappers
around calls to `bind`. These wrappers let us control how arguments are passed
to `bind`, and in particular we follow a handy internal convention: when we
call `bind`, we pass values representing array data as positional arguments,
and we pass metadata like the `axis` argument to `sum_p` via keyword. This
and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
calling convention simplifies some core logic (since e.g. instances of the
`Tracer` class to be defined below can only occur in positional arguments to
`bind`). The wrappers can also provide docstrings!
2 changes: 1 addition & 1 deletion docs/autodidax.py
@@ -123,7 +123,7 @@ def bind1(prim, *args, **params):
# around calls to `bind`. These wrappers let us control how arguments are passed
# to `bind`, and in particular we follow a handy internal convention: when we
# call `bind`, we pass values representing array data as positional arguments,
# and we pass metadata like the `axis` argument to `sum_p` via keyword. This
# and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
# calling convention simplifies some core logic (since e.g. instances of the
# `Tracer` class to be defined below can only occur in positional arguments to
# `bind`). The wrappers can also provide docstrings!
1 change: 1 addition & 0 deletions docs/jax.random.rst
@@ -53,6 +53,7 @@ Random Samplers
logistic
lognormal
maxwell
multinomial
multivariate_normal
normal
orthogonal
30 changes: 30 additions & 0 deletions docs/persistent_compilation_cache.md
@@ -168,6 +168,36 @@ so it is important for the persistent cache to be in a shared file system (eg: N
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
in subsequent runs as a result of a compilation cache miss.

### Pre-compiling multi-node programs on a single node

JAX can populate the compilation cache with programs compiled for multiple nodes
while running on a single node. Preparing the cache this way helps avoid the costly
compilation step on the cluster itself. To compile and run multi-node programs on a
single node, users can create fake remote devices using
the `jax_mock_gpu_topology` configuration option.

For instance, the snippet below instructs JAX to mock a cluster with four
nodes, each node running eight processes with each process attached to one GPU.

```python
jax.config.update("jax_mock_gpu_topology", "4x8x1")
```

After populating the cache with this config, users can run the program
without recompilation on four nodes, eight processes per node,
one GPU per process.

Important notes:

* The process running the mocked program must have the same number of GPUs
  and the same GPU model as the processes that will use the cache. For instance,
  a mocked topology `8x4x2` must run in a process with two GPUs.

* When running programs with mocked topology, the results of communications
with other nodes are undefined, so the outputs of JAX programs running
in mocked environments will likely be incorrect.


## Logging cache activity

It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
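As a minimal sketch of the pre-population workflow added above (the cache-directory option comes from the surrounding document; the path and the jitted program are placeholders, and depending on the JAX version a `jax.distributed.initialize()` call may also be needed, as in a real multi-process run):

```python
# Sketch: pre-populate the persistent cache on one node for a mocked 4x8x1 topology.
# The cache path and the jitted program are placeholders; the real run must use the
# same program (and JAX/XLA versions) as the cluster job that will read the cache.
import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/shared/jax_cache")  # shared filesystem
jax.config.update("jax_mock_gpu_topology", "4x8x1")  # 4 nodes x 8 processes x 1 GPU each

@jax.jit
def step(x):
    return jnp.sin(x).sum()

step(jnp.arange(8.0))  # compilation writes an entry into the cache directory
```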
15 changes: 10 additions & 5 deletions jax/_src/abstract_arrays.py
@@ -45,24 +45,28 @@


def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
raise ValueError(
"numpy masked arrays are not supported as direct inputs to JAX functions."
" Use arr.filled() to convert the value to a standard numpy array.")

core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error


def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype),
sharding=core.get_cur_mesh_sharding(core.P(*[None] * x.ndim)))

core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array


def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
shape = np.shape(x)
return ShapedArray(shape, dtypes.canonicalize_dtype(dtype),
sharding=core.get_cur_mesh_sharding(core.P(*[None] * len(shape))))

for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
@@ -74,7 +78,8 @@ def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
weak_type=typ is not bool,
sharding=core.get_cur_mesh_sharding())

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
6 changes: 3 additions & 3 deletions jax/_src/ad_checkpoint.py
@@ -323,7 +323,7 @@ def foo(x, y):
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
debug = api_util.tracing_debug_info(
debug = api_util.debug_info(
"checkpoint / remat", fun,
args, kwargs, static_argnums=static_argnums)
fun_, args = _remat_static_argnums(fun, static_argnums, args)
@@ -418,7 +418,7 @@ def new_fun(*dyn_args, **kwargs):
def _trace_to_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: lu.TracingDebugInfo
debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
try:
@@ -447,7 +447,7 @@ def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)

debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs)
debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
18 changes: 9 additions & 9 deletions jax/_src/api.py
@@ -57,11 +57,11 @@
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
result_paths, flat_out_axes)
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@@ -452,7 +452,7 @@ def value_and_grad_f(*args, **kwargs):
raise TypeError(f"differentiating with respect to {argnums=} requires at least "
f"{max_argnum + 1} positional arguments to be passed by the caller, "
f"but got only {len(args)} positional arguments.")
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
dbg = debug_info('value_and_grad', fun, args, kwargs)

f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
f_partial, dyn_args = argnums_partial(f, argnums, args,
@@ -1021,7 +1021,7 @@ def _get_spec(arg, i):
try:
# Duck type arrays like BCOO arrays can be passed to vmap.
return shaped_abstractify(arg).sharding.spec[i]
except TypeError:
except (IndexError, TypeError):
return None

temp_spec = None
@@ -1426,7 +1426,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

dbg = tracing_debug_info(
dbg = debug_info(
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)

@@ -2235,7 +2235,7 @@ def _check_sharding(aval, s):
f" invalid value: {s}")
if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.token_shaped_array
aval = core.get_token_aval()
if not isinstance(s, PmapSharding):
pjit.pjit_check_aval_sharding(
(s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
24 changes: 11 additions & 13 deletions jax/_src/api_util.py
@@ -31,7 +31,6 @@
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable, safe_zip)
from jax._src import traceback_util
@@ -582,7 +581,7 @@ def api_hook(fun, tag: str):
return fun


def tracing_debug_info(
def debug_info(
traced_for: str,
fun: Callable,
args: Sequence[Any],
@@ -594,14 +593,14 @@
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> core.DebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def fun_signature(fun: Callable) -> inspect.Signature | None:
Expand All @@ -619,7 +618,7 @@ def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):

# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info
# See DebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
@@ -684,20 +683,19 @@ def result_paths(_fun, _store, *args, **kwargs):

# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
debug: core.DebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
if debug is None:
return jaxpr
# TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None:
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)
if result_paths is not None:
debug = debug._replace(result_paths=tuple(result_paths))
else:
debug = debug.resolve_result_paths()
return jaxpr.replace(debug_info=debug)

def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
2 changes: 1 addition & 1 deletion jax/_src/array.py
@@ -1220,7 +1220,7 @@ def _token_shard_arg(xs, shardings, layouts, copy_semantics):

def _token_global_result_handler(global_aval, out_sharding, committed):
array_handler = _array_global_result_handler(
core.token_shaped_array, out_sharding, committed)
core.get_token_aval(), out_sharding, committed)

def wrapper(*args, **kwargs):
out_buf = array_handler(*args, **kwargs)
