From 68428488c8d1b62ac5cc9385d3d3623424dba4dd Mon Sep 17 00:00:00 2001 From: minigoel Date: Mon, 28 Oct 2024 10:47:59 -0700 Subject: [PATCH 001/112] Add a link to Intel plugin for JAX --- README.md | 2 ++ docs/installation.md | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/README.md b/README.md index c99d3db10a2a..26c0797db6b0 100644 --- a/README.md +++ b/README.md @@ -390,6 +390,7 @@ Some standouts: | Google TPU | yes | n/a | n/a | n/a | n/a | n/a | | AMD GPU | yes | no | experimental | n/a | no | no | | Apple GPU | n/a | no | n/a | experimental | n/a | n/a | +| Intel GPU | experimental | n/a | n/a | n/a | no | no | ### Instructions @@ -401,6 +402,7 @@ Some standouts: | Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` | | AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). | | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | +| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) for information on alternative installation strategies. These include compiling diff --git a/docs/installation.md b/docs/installation.md index 5b8893628d85..7cf64955722c 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -35,6 +35,7 @@ The table below shows all supported platforms and installation options. Check if | Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | | AMD GPU | {ref}`experimental ` | no | {ref}`experimental ` | n/a | no | no | | Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a | +| Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no | (install-cpu)= @@ -230,6 +231,17 @@ JAX has experimental ROCm support. There are two ways to install JAX: * Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or * Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_). +(install-intel-gpu)= +## Intel GPU + +Intel provides an experimental OneAPI plugin: intel-extension-for-openxla for Intel GPU hardware. For more details and installation instructions, refer to one of the following two methods: +1. Pip installation: [JAX acceleration on Intel GPU](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). +2. Using [Intel's XLA Docker container](https://hub.docker.com/r/intel/intel-optimized-xla). + +Please report any issues related to: +* JAX: [JAX issue tracker](https://github.com/jax-ml/jax/issues). +* Intel's OpenXLA plugin: [Intel-extension-for-openxla issue tracker](https://github.com/intel/intel-extension-for-openxla/issues). + ## Conda (community-supported) ### Conda installation From 78da9fa4322bb62e4e1cc55977cb79fc20cb0ccb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Oct 2024 15:33:24 -0700 Subject: [PATCH 002/112] Add float8_e4m3 and float8_e3m4 types support --- jax/_src/dtypes.py | 19 +++++++++++++++++++ jax/_src/export/serialization.fbs | 2 ++ jax/_src/export/serialization.py | 4 ++++ jax/_src/export/serialization_generated.py | 2 ++ jax/_src/interpreters/mlir.py | 12 ++++++------ jax/_src/lax/lax.py | 12 ++++++++++-- jax/_src/numpy/lax_numpy.py | 4 ++++ jax/_src/public_test_util.py | 14 ++++++++++++++ jax/_src/test_util.py | 17 +++++++++++++---- jax/numpy/__init__.py | 9 +++++++++ tests/dtypes_test.py | 4 ++++ 11 files changed, 87 insertions(+), 12 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index ac0418932b83..c9710c287879 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,12 +90,17 @@ def type(self) -> type: ... # fp8 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float8_e3m4: type[np.generic] | None = None +float8_e4m3: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz +_float8_e3m4_dtype: np.dtype | None = None +_float8_e4m3_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] +# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 +if hasattr(ml_dtypes, "float8_e4m3"): + float8_e4m3 = ml_dtypes.float8_e4m3 + _float8_e4m3_dtype = np.dtype(float8_e4m3) + _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e4m3_dtype) + _float8_dtypes.insert(0, _float8_e4m3_dtype) +if hasattr(ml_dtypes, "float8_e3m4"): + float8_e3m4 = ml_dtypes.float8_e3m4 + _float8_e3m4_dtype = np.dtype(float8_e3m4) + _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e3m4_dtype) + _float8_dtypes.insert(0, _float8_e3m4_dtype) + # 2-bit integer support int2: type[np.generic] | None = None uint2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 3198f83aa120..b71b377d8999 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -67,6 +67,8 @@ enum DType: byte { i4 = 15, ui4 = 16, + f8_e3m4 = 24, + f8_e4m3 = 23, f8_e4m3b11fnuz = 17, f8_e4m3fn = 18, f8_e4m3fnuz = 19, diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index e392289da64d..0d9ce961b556 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -359,6 +359,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, } +if dtypes._float8_e3m4_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 +if dtypes._float8_e4m3_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 18dd2c3cbab1..70d298020961 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,6 +53,8 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 + f8_e3m4 = 24 + f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2c0e26019e4d..54a85f92c873 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -184,13 +184,13 @@ def _is_ir_values(x: IrValues) -> bool: if dtypes.int2 is not None: assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( - ir.IntegerType.get_signless, 2 - ) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( - ir.IntegerType.get_unsigned, 2 - ) + _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) + _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) +if dtypes.float8_e3m4 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get +if dtypes.float8_e4m3 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8b6a517a54b3..7fa2dd4acbfa 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -937,11 +937,15 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): - fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), + fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)) + np.dtype(dtypes.float8_e5m2fnuz)] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3625,6 +3629,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + if dtypes.float8_e3m4 is not None: + fp8_dtypes += (dtypes.float8_e3m4,) + if dtypes.float8_e4m3 is not None: + fp8_dtypes += (dtypes.float8_e4m3,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d2e89833915d..c419c083f837 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -217,6 +217,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) +if dtypes.float8_e3m4 is not None: + float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +if dtypes.float8_e4m3 is not None: + float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 9859eb64cda2..6bbcdd08471f 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -90,6 +90,14 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } +# TODO: make this unconditional when ml_dtypes>=0.5.0 is required +if _dtypes.float8_e3m4 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 +if _dtypes.float8_e4m3 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.float8_e5m2fnuz, _dtypes.bfloat16, ] + + if _dtypes.float8_e4m3 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e4m3) + if _dtypes.float8_e3m4 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e3m4) + def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index bb81c979bc48..e7707f58fc4a 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1431,10 +1431,19 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ - _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, - _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + float_dtypes = [ + _dtypes.bfloat16, + _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e4m3fn, + _dtypes.float8_e4m3fnuz, + _dtypes.float8_e5m2, + _dtypes.float8_e5m2fnuz, + ] + if _dtypes.float8_e3m4 is not None: + float_dtypes += [_dtypes.float8_e3m4] + if _dtypes.float8_e4m3 is not None: + float_dtypes += [_dtypes.float8_e4m3] + return [np.dtype(t) for t in float_dtypes] @_cached_property def floating(self): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9be73e96adcf..9a643bf49bf0 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -273,6 +273,15 @@ except ImportError: pass +# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 +try: + from jax._src.numpy.lax_numpy import ( + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + ) +except ImportError: + pass + from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 89d70871a8f9..6c7e9e3ab712 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -64,6 +64,10 @@ fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz)] +if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] +if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes From e6f6a8af8d2bd3bec601dfd029b06d2baecd6130 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Wed, 6 Nov 2024 10:43:17 -0800 Subject: [PATCH 003/112] Move Control Flow text from Sharp Bits into its own tutorial. --- README.md | 5 +- docs/control-flow.md | 361 ++++++++++++ docs/faq.rst | 2 +- docs/jit-compilation.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 607 +-------------------- docs/notebooks/Common_Gotchas_in_JAX.md | 328 +---------- docs/stateful-computations.md | 1 + docs/tutorials.rst | 1 + 8 files changed, 379 insertions(+), 928 deletions(-) create mode 100644 docs/control-flow.md diff --git a/README.md b/README.md index 89fe51212638..ce695dd6a26d 100644 --- a/README.md +++ b/README.md @@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like. Using `jit` puts constraints on the kind of Python control flow the function can use; see -the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT) +the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` @@ -369,7 +368,7 @@ Some standouts: and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. 1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). + flow](https://jax.readthedocs.io/en/latest/control-flow.html). You'll always get loud errors if something goes wrong. You might have to use [`jit`'s `static_argnums` parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), diff --git a/docs/control-flow.md b/docs/control-flow.md new file mode 100644 index 000000000000..04eb3cac8d24 --- /dev/null +++ b/docs/control-flow.md @@ -0,0 +1,361 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + ++++ {"id": "rg4CpMZ8c3ri"} + +(control-flow)= +# Control flow and logical operators with JIT + + + +When executing eagerly (outside of `jit`), JAX code works with Python control flow and logical operators just like Numpy code. Using control flow and logical operators with `jit` is more complicated. + +In a nutshell, Python control flow and logical operators are evaluated at JIT compile time, such that the compiled function represents a single path through the [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph) (logical operators affect the path via short-circuiting). If the path depends on the values of the inputs, the function (by default) cannot be JIT compiled. The path may depend on the shape or dtype of the inputs, and the function is re-compiled every time it is called on an input with a new shape or dtype. + +```{code-cell} +from jax import grad, jit +import jax.numpy as jnp +``` + +For example, this works: + +```{code-cell} +:id: OZ_BJX0CplNC +:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c + +@jit +def f(x): + for i in range(3): + x = 2 * x + return x + +print(f(3)) +``` + ++++ {"id": "22RzeJ4QqAuX"} + +So does this: + +```{code-cell} +:id: pinVnmRWp6w6 +:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422 + +@jit +def g(x): + y = 0. + for i in range(x.shape[0]): + y = y + x[i] + return y + +print(g(jnp.array([1., 2., 3.]))) +``` + ++++ {"id": "TStltU2dqf8A"} + +But this doesn't, at least by default: + +```{code-cell} +:id: 9z38AIKclRNM +:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac +:tags: [raises-exception] + +@jit +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +# This will fail! +f(2) +``` + +Neither does this: + +```{code-cell} +:tags: [raises-exception] + +@jit +def g(x): + return (x > 0) and (x < 3) + +# This will fail! +g(2) +``` + ++++ {"id": "pIbr4TVPqtDN"} + +__What gives!?__ + +When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation. + +For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. + +To get a view of your Python code that is valid for many different argument values, JAX traces it with the `ShapedArray` abstraction as input, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. + +But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace. + +The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnames` (or `static_argnums`) argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again: + +```{code-cell} +:id: -Tzp0H7Bt1Sn +:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a + +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +f = jit(f, static_argnames='x') + +print(f(2.)) +``` + ++++ {"id": "MHm1hIQAvBVs"} + +Here's another example, this time involving a loop: + +```{code-cell} +:id: iwY86_JKvD6b +:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937 + +def f(x, n): + y = 0. + for i in range(n): + y = y + x[i] + return y + +f = jit(f, static_argnames='n') + +f(jnp.array([2., 3., 4.]), 2) +``` + ++++ {"id": "nSPTOX8DvOeO"} + +In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation + ++++ {"id": "wWdg8LTYwCW3"} + +️⚠️ **functions with argument-__value__ dependent shapes** + +These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`. + +```{code-cell} +:id: Tqe9uLmUI_Gv +:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883 + +def example_fun(length, val): + return jnp.ones((length,)) * val +# un-jit'd works fine +print(example_fun(5, 4)) +``` + +```{code-cell} +:id: fOlR54XRgHpd +:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71 +:tags: [raises-exception] + +bad_example_jit = jit(example_fun) +# this will fail: +bad_example_jit(10, 4) +``` + +```{code-cell} +:id: kH0lOD4GgFyI +:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade + +# static_argnames tells JAX to recompile on changes at these argument positions: +good_example_jit = jit(example_fun, static_argnames='length') +# first compile +print(good_example_jit(10, 4)) +# recompiles +print(good_example_jit(5, 4)) +``` + ++++ {"id": "MStx_r2oKxpp"} + +`static_argnames` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! + +Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions: + +```{code-cell} +:id: m2ABpRd8K094 +:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3 + +@jit +def f(x): + print(x) + y = 2 * x + print(y) + return y +f(2) +``` + ++++ {"id": "uCDcWG4MnVn-"} + +## Structured control flow primitives + +There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives: + + - `lax.cond` _differentiable_ + - `lax.while_loop` __fwd-mode-differentiable__ + - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static. + - `lax.scan` _differentiable_ + ++++ {"id": "Sd9xrLMXeK3A"} + +### `cond` +python equivalent: + +```python +def cond(pred, true_fun, false_fun, operand): + if pred: + return true_fun(operand) + else: + return false_fun(operand) +``` + +```{code-cell} +:id: SGxz9JOWeiyH +:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3 + +from jax import lax + +operand = jnp.array([0.]) +lax.cond(True, lambda x: x+1, lambda x: x-1, operand) +# --> array([1.], dtype=float32) +lax.cond(False, lambda x: x+1, lambda x: x-1, operand) +# --> array([-1.], dtype=float32) +``` + ++++ {"id": "lIYdn1woOS1n"} + +`jax.lax` provides two other functions that allow branching on dynamic predicates: + +- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is + like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays + rather than as functions. +- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is + like `lax.cond`, but allows switching between any number of callable choices. + +In addition, `jax.numpy` provides several numpy-style interfaces to these functions: + +- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with + three arguments is the numpy-style wrapper of `lax.select`. +- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) + is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. +- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has + an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather + than as functions. It is implemented in terms of multiple calls to `lax.select`. + ++++ {"id": "xkOFAw24eOMg"} + +### `while_loop` + +python equivalent: +``` +def while_loop(cond_fun, body_fun, init_val): + val = init_val + while cond_fun(val): + val = body_fun(val) + return val +``` + +```{code-cell} +:id: jM-D39a-c436 +:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e + +init_val = 0 +cond_fun = lambda x: x < 10 +body_fun = lambda x: x+1 +lax.while_loop(cond_fun, body_fun, init_val) +# --> array(10, dtype=int32) +``` + ++++ {"id": "apo3n3HAeQY_"} + +### `fori_loop` +python equivalent: +``` +def fori_loop(start, stop, body_fun, init_val): + val = init_val + for i in range(start, stop): + val = body_fun(i, val) + return val +``` + +```{code-cell} +:id: dt3tUpOmeR8u +:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380 + +init_val = 0 +start = 0 +stop = 10 +body_fun = lambda i,x: x+i +lax.fori_loop(start, stop, body_fun, init_val) +# --> array(45, dtype=int32) +``` + ++++ {"id": "SipXS5qiqk8e"} + +### Summary + +$$ +\begin{array} {r|rr} +\hline \ +\textrm{construct} +& \textrm{jit} +& \textrm{grad} \\ +\hline \ +\textrm{if} & ❌ & ✔ \\ +\textrm{for} & ✔* & ✔\\ +\textrm{while} & ✔* & ✔\\ +\textrm{lax.cond} & ✔ & ✔\\ +\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ +\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ +\textrm{lax.scan} & ✔ & ✔\\ +\hline +\end{array} +$$ + +
+ +$\ast$ = argument-value-independent loop condition - unrolls the loop + +
+ +## Logical operators + +`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`. + ++++ {"id": "izLTvT24dAq0"} + +## Python control flow + autodiff + +Remember that the above constraints on control flow and logical operators are relevant only with `jit`. If you just want to apply `grad` to your python functions, without `jit`, you can use regular Python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). + +```{code-cell} +:id: aAx0T3F8lLtu +:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52 + +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +print(grad(f)(2.)) # ok! +print(grad(f)(4.)) # ok! +``` diff --git a/docs/faq.rst b/docs/faq.rst index 1d2bb204f24c..44267f6f5f7d 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 51322fda9476..5e5be308068a 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -170,7 +170,7 @@ jax.jit(g)(10, 20) # Raises an error The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values. Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as `shape` or `dtype`, and not via their values. -For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). +For more detail on the interaction between Python control flow and JAX, see {ref}`control-flow`. One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical. In that case, you can consider JIT-compiling only part of the function. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 71bd4527644a..92c736957db6 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -34,7 +34,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "from jax import grad, jit\n", + "from jax import jit\n", "from jax import lax\n", "from jax import random\n", "import jax\n", @@ -1175,610 +1175,14 @@ }, { "cell_type": "markdown", + "id": "1dc0e6b2", "metadata": { "id": "rg4CpMZ8c3ri" }, "source": [ - "## 🔪 Control flow" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "izLTvT24dAq0" - }, - "source": [ - "### ✔ Python control_flow + autodiff ✔\n", - "\n", - "If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "id": "aAx0T3F8lLtu", - "outputId": "383b7bfa-1634-4d23-8497-49cb9452ca52" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.0\n", - "-4.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "print(grad(f)(2.)) # ok!\n", - "print(grad(f)(4.)) # ok!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hIfPT7WMmZ2H" - }, - "source": [ - "### Python control flow + JIT\n", - "\n", - "Using control flow with `jit` is more complicated, and by default it has more constraints.\n", - "\n", - "This works:" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "id": "OZ_BJX0CplNC", - "outputId": "60c902a2-eba1-49d7-c8c8-2f68616d660c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "24\n" - ] - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " for i in range(3):\n", - " x = 2 * x\n", - " return x\n", - "\n", - "print(f(3))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "22RzeJ4QqAuX" - }, - "source": [ - "So does this:" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "id": "pinVnmRWp6w6", - "outputId": "25e06cf2-474f-4782-af7c-4f5514b64422" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "6.0\n" - ] - } - ], - "source": [ - "@jit\n", - "def g(x):\n", - " y = 0.\n", - " for i in range(x.shape[0]):\n", - " y = y + x[i]\n", - " return y\n", - "\n", - "print(g(jnp.array([1., 2., 3.])))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TStltU2dqf8A" - }, - "source": [ - "But this doesn't, at least by default:" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "id": "9z38AIKclRNM", - "outputId": "38dd2075-92fc-4b81-fee0-b9dff8da1fac", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ConcretizationTypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n" - ] - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "# This will fail!\n", - "f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pIbr4TVPqtDN" - }, - "source": [ - "__What gives!?__\n", - "\n", - "When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n", - "\n", - "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n", - "\n", - "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", - "\n", - "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n", - "\n", - "But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n", - "\n", - "The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": { - "id": "-Tzp0H7Bt1Sn", - "outputId": "f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "f = jit(f, static_argnums=(0,))\n", - "\n", - "print(f(2.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MHm1hIQAvBVs" - }, - "source": [ - "Here's another example, this time involving a loop:" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "id": "iwY86_JKvD6b", - "outputId": "48f9b51f-bd32-466f-eac1-cd23444ce937" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(5., dtype=float32)" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def f(x, n):\n", - " y = 0.\n", - " for i in range(n):\n", - " y = y + x[i]\n", - " return y\n", - "\n", - "f = jit(f, static_argnums=(1,))\n", - "\n", - "f(jnp.array([2., 3., 4.]), 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nSPTOX8DvOeO" - }, - "source": [ - "In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wWdg8LTYwCW3" - }, - "source": [ - "️⚠️ **functions with argument-__value__ dependent shapes**\n", - "\n", - "These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`." - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": { - "id": "Tqe9uLmUI_Gv", - "outputId": "989be121-dfce-4bb3-c78e-a10829c5f883" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[4. 4. 4. 4. 4.]\n" - ] - } - ], - "source": [ - "def example_fun(length, val):\n", - " return jnp.ones((length,)) * val\n", - "# un-jit'd works fine\n", - "print(example_fun(5, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": { - "id": "fOlR54XRgHpd", - "outputId": "cf31d798-a4ce-4069-8e3e-8f9631ff4b71", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Shapes must be 1D sequences of concrete values of integer type, got (Tracedwith,).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n" - ] - } - ], - "source": [ - "bad_example_jit = jit(example_fun)\n", - "# this will fail:\n", - "bad_example_jit(10, 4)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": { - "id": "kH0lOD4GgFyI", - "outputId": "d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n", - "[4. 4. 4. 4. 4.]\n" - ] - } - ], - "source": [ - "# static_argnums tells JAX to recompile on changes at these argument positions:\n", - "good_example_jit = jit(example_fun, static_argnums=(0,))\n", - "# first compile\n", - "print(good_example_jit(10, 4))\n", - "# recompiles\n", - "print(good_example_jit(5, 4))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MStx_r2oKxpp" - }, - "source": [ - "`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!\n", - "\n", - "Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "id": "m2ABpRd8K094", - "outputId": "4f7ebe17-ade4-4e18-bd8c-4b24087c33c3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tracedwith\n", - "Tracedwith\n" - ] - }, - { - "data": { - "text/plain": [ - "Array(4, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " print(x)\n", - " y = 2 * x\n", - " print(y)\n", - " return y\n", - "f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uCDcWG4MnVn-" - }, - "source": [ - "### Structured control flow primitives\n", - "\n", - "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n", - "\n", - " - `lax.cond` _differentiable_\n", - " - `lax.while_loop` __fwd-mode-differentiable__\n", - " - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.\n", - " - `lax.scan` _differentiable_" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sd9xrLMXeK3A" - }, - "source": [ - "#### `cond`\n", - "python equivalent:\n", - "\n", - "```python\n", - "def cond(pred, true_fun, false_fun, operand):\n", - " if pred:\n", - " return true_fun(operand)\n", - " else:\n", - " return false_fun(operand)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": { - "id": "SGxz9JOWeiyH", - "outputId": "942a8d0e-5ff6-4702-c499-b3941f529ca3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([-1.], dtype=float32)" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import lax\n", - "\n", - "operand = jnp.array([0.])\n", - "lax.cond(True, lambda x: x+1, lambda x: x-1, operand)\n", - "# --> array([1.], dtype=float32)\n", - "lax.cond(False, lambda x: x+1, lambda x: x-1, operand)\n", - "# --> array([-1.], dtype=float32)" - ] - }, - { - "cell_type": "markdown", - "id": "e6622244", - "metadata": { - "id": "lIYdn1woOS1n" - }, - "source": [ - "`jax.lax` provides two other functions that allow branching on dynamic predicates:\n", - "\n", - "- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is\n", - " like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays\n", - " rather than as functions.\n", - "- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is\n", - " like `lax.cond`, but allows switching between any number of callable choices.\n", - "\n", - "In addition, `jax.numpy` provides several numpy-style interfaces to these functions:\n", - "\n", - "- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with\n", - " three arguments is the numpy-style wrapper of `lax.select`.\n", - "- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)\n", - " is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.\n", - "- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has\n", - " an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather\n", - " than as functions. It is implemented in terms of multiple calls to `lax.select`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xkOFAw24eOMg" - }, - "source": [ - "#### `while_loop`\n", - "\n", - "python equivalent:\n", - "```\n", - "def while_loop(cond_fun, body_fun, init_val):\n", - " val = init_val\n", - " while cond_fun(val):\n", - " val = body_fun(val)\n", - " return val\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": { - "id": "jM-D39a-c436", - "outputId": "552fe42f-4d32-4e25-c8c2-b951160a3f4e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(10, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "init_val = 0\n", - "cond_fun = lambda x: x < 10\n", - "body_fun = lambda x: x+1\n", - "lax.while_loop(cond_fun, body_fun, init_val)\n", - "# --> array(10, dtype=int32)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "apo3n3HAeQY_" - }, - "source": [ - "#### `fori_loop`\n", - "python equivalent:\n", - "```\n", - "def fori_loop(start, stop, body_fun, init_val):\n", - " val = init_val\n", - " for i in range(start, stop):\n", - " val = body_fun(i, val)\n", - " return val\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "id": "dt3tUpOmeR8u", - "outputId": "7819ca7c-1433-4d85-b542-f6159b0e8380" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(45, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "init_val = 0\n", - "start = 0\n", - "stop = 10\n", - "body_fun = lambda i,x: x+i\n", - "lax.fori_loop(start, stop, body_fun, init_val)\n", - "# --> array(45, dtype=int32)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SipXS5qiqk8e" - }, - "source": [ - "#### Summary\n", - "\n", - "$$\n", - "\\begin{array} {r|rr}\n", - "\\hline \\\n", - "\\textrm{construct}\n", - "& \\textrm{jit}\n", - "& \\textrm{grad} \\\\\n", - "\\hline \\\n", - "\\textrm{if} & ❌ & ✔ \\\\\n", - "\\textrm{for} & ✔* & ✔\\\\\n", - "\\textrm{while} & ✔* & ✔\\\\\n", - "\\textrm{lax.cond} & ✔ & ✔\\\\\n", - "\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n", - "\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n", - "\\textrm{lax.scan} & ✔ & ✔\\\\\n", - "\\hline\n", - "\\end{array}\n", - "$$\n", - "\n", - "
\n", - "\n", - "$\\ast$ = argument-value-independent loop condition - unrolls the loop\n", + "## 🔪 Control flow\n", "\n", - "
" + "Moved to {ref}`control-flow`." ] }, { @@ -2209,6 +1613,9 @@ " ```\n", " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n", "\n", + "## 🔪 Sharp bits covered in tutorials\n", + "- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.\n", + "- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions.\n", "\n", "## Fin.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 741fa3af063c..00955de236e7 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -31,7 +31,7 @@ JAX works great for many numerical and scientific programs, but __only if they a :id: GoK_PCxPeYcy import numpy as np -from jax import grad, jit +from jax import jit from jax import lax from jax import random import jax @@ -536,328 +536,7 @@ for subkey in subkeys: ## 🔪 Control flow -+++ {"id": "izLTvT24dAq0"} - -### ✔ Python control_flow + autodiff ✔ - -If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). - -```{code-cell} ipython3 -:id: aAx0T3F8lLtu -:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52 - -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -print(grad(f)(2.)) # ok! -print(grad(f)(4.)) # ok! -``` - -+++ {"id": "hIfPT7WMmZ2H"} - -### Python control flow + JIT - -Using control flow with `jit` is more complicated, and by default it has more constraints. - -This works: - -```{code-cell} ipython3 -:id: OZ_BJX0CplNC -:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c - -@jit -def f(x): - for i in range(3): - x = 2 * x - return x - -print(f(3)) -``` - -+++ {"id": "22RzeJ4QqAuX"} - -So does this: - -```{code-cell} ipython3 -:id: pinVnmRWp6w6 -:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422 - -@jit -def g(x): - y = 0. - for i in range(x.shape[0]): - y = y + x[i] - return y - -print(g(jnp.array([1., 2., 3.]))) -``` - -+++ {"id": "TStltU2dqf8A"} - -But this doesn't, at least by default: - -```{code-cell} ipython3 -:id: 9z38AIKclRNM -:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac -:tags: [raises-exception] - -@jit -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -# This will fail! -f(2) -``` - -+++ {"id": "pIbr4TVPqtDN"} - -__What gives!?__ - -When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation. - -For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. - -To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. - -By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. - -But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace. - -The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again: - -```{code-cell} ipython3 -:id: -Tzp0H7Bt1Sn -:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a - -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -f = jit(f, static_argnums=(0,)) - -print(f(2.)) -``` - -+++ {"id": "MHm1hIQAvBVs"} - -Here's another example, this time involving a loop: - -```{code-cell} ipython3 -:id: iwY86_JKvD6b -:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937 - -def f(x, n): - y = 0. - for i in range(n): - y = y + x[i] - return y - -f = jit(f, static_argnums=(1,)) - -f(jnp.array([2., 3., 4.]), 2) -``` - -+++ {"id": "nSPTOX8DvOeO"} - -In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation - -+++ {"id": "wWdg8LTYwCW3"} - -️⚠️ **functions with argument-__value__ dependent shapes** - -These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`. - -```{code-cell} ipython3 -:id: Tqe9uLmUI_Gv -:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883 - -def example_fun(length, val): - return jnp.ones((length,)) * val -# un-jit'd works fine -print(example_fun(5, 4)) -``` - -```{code-cell} ipython3 -:id: fOlR54XRgHpd -:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71 -:tags: [raises-exception] - -bad_example_jit = jit(example_fun) -# this will fail: -bad_example_jit(10, 4) -``` - -```{code-cell} ipython3 -:id: kH0lOD4GgFyI -:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade - -# static_argnums tells JAX to recompile on changes at these argument positions: -good_example_jit = jit(example_fun, static_argnums=(0,)) -# first compile -print(good_example_jit(10, 4)) -# recompiles -print(good_example_jit(5, 4)) -``` - -+++ {"id": "MStx_r2oKxpp"} - -`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! - -Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions: - -```{code-cell} ipython3 -:id: m2ABpRd8K094 -:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3 - -@jit -def f(x): - print(x) - y = 2 * x - print(y) - return y -f(2) -``` - -+++ {"id": "uCDcWG4MnVn-"} - -### Structured control flow primitives - -There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives: - - - `lax.cond` _differentiable_ - - `lax.while_loop` __fwd-mode-differentiable__ - - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static. - - `lax.scan` _differentiable_ - -+++ {"id": "Sd9xrLMXeK3A"} - -#### `cond` -python equivalent: - -```python -def cond(pred, true_fun, false_fun, operand): - if pred: - return true_fun(operand) - else: - return false_fun(operand) -``` - -```{code-cell} ipython3 -:id: SGxz9JOWeiyH -:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3 - -from jax import lax - -operand = jnp.array([0.]) -lax.cond(True, lambda x: x+1, lambda x: x-1, operand) -# --> array([1.], dtype=float32) -lax.cond(False, lambda x: x+1, lambda x: x-1, operand) -# --> array([-1.], dtype=float32) -``` - -+++ {"id": "lIYdn1woOS1n"} - -`jax.lax` provides two other functions that allow branching on dynamic predicates: - -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is - like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays - rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is - like `lax.cond`, but allows switching between any number of callable choices. - -In addition, `jax.numpy` provides several numpy-style interfaces to these functions: - -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with - three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) - is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has - an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather - than as functions. It is implemented in terms of multiple calls to `lax.select`. - -+++ {"id": "xkOFAw24eOMg"} - -#### `while_loop` - -python equivalent: -``` -def while_loop(cond_fun, body_fun, init_val): - val = init_val - while cond_fun(val): - val = body_fun(val) - return val -``` - -```{code-cell} ipython3 -:id: jM-D39a-c436 -:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e - -init_val = 0 -cond_fun = lambda x: x < 10 -body_fun = lambda x: x+1 -lax.while_loop(cond_fun, body_fun, init_val) -# --> array(10, dtype=int32) -``` - -+++ {"id": "apo3n3HAeQY_"} - -#### `fori_loop` -python equivalent: -``` -def fori_loop(start, stop, body_fun, init_val): - val = init_val - for i in range(start, stop): - val = body_fun(i, val) - return val -``` - -```{code-cell} ipython3 -:id: dt3tUpOmeR8u -:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380 - -init_val = 0 -start = 0 -stop = 10 -body_fun = lambda i,x: x+i -lax.fori_loop(start, stop, body_fun, init_val) -# --> array(45, dtype=int32) -``` - -+++ {"id": "SipXS5qiqk8e"} - -#### Summary - -$$ -\begin{array} {r|rr} -\hline \ -\textrm{construct} -& \textrm{jit} -& \textrm{grad} \\ -\hline \ -\textrm{if} & ❌ & ✔ \\ -\textrm{for} & ✔* & ✔\\ -\textrm{while} & ✔* & ✔\\ -\textrm{lax.cond} & ✔ & ✔\\ -\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ -\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ -\textrm{lax.scan} & ✔ & ✔\\ -\hline -\end{array} -$$ - -
- -$\ast$ = argument-value-independent loop condition - unrolls the loop - -
+Moved to {ref}`control-flow`. +++ {"id": "OxLsZUyRt_kF"} @@ -1145,6 +824,9 @@ Many such cases are discussed in detail in the sections above; here we list seve ``` This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa. +## 🔪 Sharp bits covered in tutorials +- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators. +- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions. ## Fin. diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 2ff82e0431e2..fe84fc0d7f0a 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -12,6 +12,7 @@ kernelspec: name: python3 --- +(stateful-computations)= # Stateful computations diff --git a/docs/tutorials.rst b/docs/tutorials.rst index a31517155e1a..c9c2fdb1dcc7 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -16,6 +16,7 @@ Tutorials working-with-pytrees sharded-computation stateful-computations + control-flow .. toctree:: :maxdepth: 1 From d823f1720dccf1e58d37ca91111950c4384efa02 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 12 Nov 2024 11:51:55 -0800 Subject: [PATCH 004/112] jnp.logaddexp2: simplify implementation --- jax/_src/numpy/ufuncs.py | 39 ++------------------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 93e116fa4b6a..a844ecbc28ac 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2630,16 +2630,6 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) -def _wrap_between(x, _a): - """Wraps `x` between `[-a, a]`.""" - a = _constant_like(x, _a) - two_a = _constant_like(x, 2 * _a) - zero = _constant_like(x, 0) - rem = lax.rem(lax.add(x, a), two_a) - rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) - return lax.sub(rem, a) - - @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. @@ -2668,33 +2658,8 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array(True, dtype=bool) """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) - return _logaddexp2(x1, x2) - - -@custom_jvp -def _logaddexp2(x1, x2): - amax = lax.max(x1, x2) - if dtypes.issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(lax._isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. - lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), - _constant_like(x1, np.log(2))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) - out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2)))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) - - -@_logaddexp2.defjvp -def _logaddexp2_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2) - primal_out = logaddexp2(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out + ln2 = float(np.log(2)) + return logaddexp(x1 * ln2, x2 * ln2) / ln2 @partial(jit, inline=True) From d0f36666ff9f9ae8847b0ca645b6f1ce581907e9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 14 Nov 2024 11:52:21 -0800 Subject: [PATCH 005/112] Update array-api-tests commit --- .github/workflows/jax-array-api.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 942034169e09..84dda34752f0 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻 + ref: '4bbe6be32c6995772f8f46a6ef050ba766581104' # Latest commit as of 2024-11-14 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} From 1f114b1cf79803462dba76bdb3a9576a6018f618 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 14 Nov 2024 15:23:26 -0500 Subject: [PATCH 006/112] Add numpy.put_along_axis. --- CHANGELOG.md | 1 + docs/jax.numpy.rst | 1 + jax/_src/numpy/lax_numpy.py | 101 +++++++++++++++++++++++++++++++++++- jax/_src/test_util.py | 25 +++++++++ jax/_src/util.py | 4 ++ jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 2 + tests/lax_numpy_test.py | 42 ++++++++++++++- 8 files changed, 174 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17d15c740b7e..b0b64ac71fd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.tree_util.register_dataclass` now allows metadata fields to be declared inline via {func}`dataclasses.field`. See the function documentation for examples. + * Added {func}`jax.numpy.put_along_axis`. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 3922c92d98de..30553a360155 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -337,6 +337,7 @@ namespace; they are listed below. promote_types ptp put + put_along_axis quantile r_ rad2deg diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b90004e19932..3ff38f16b38a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) + ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map @@ -11433,6 +11433,105 @@ def replace(tup, val): mode="fill" if mode is None else mode, fill_value=fill_value) +_indices = indices # argument below named 'indices' shadows the function + + +def _make_along_axis_idx(shape, indices, axis): + return tuple_replace(_indices(shape, sparse=True), axis, indices) + + +@partial(jit, static_argnames=('axis', 'inplace', 'mode')) +def put_along_axis( + arr: ArrayLike, + indices: ArrayLike, + values: ArrayLike, + axis: int | None, + inplace: bool = True, + *, + mode: str | None = None, +) -> Array: + """Put values into the destination array by matching 1d index and data slices. + + JAX implementation of :func:`numpy.put_along_axis`. + + The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which + is not possible for JAX's immutable arrays. The JAX version returns a modified + copy of the input, and adds the ``inplace`` parameter which must be set to + `False`` by the user as a reminder of this API difference. + + Args: + arr: array into which values will be put. + indices: array of indices at which to put values. + values: array of values to put into the array. + axis: the axis along which to put values. If not specified, the array will + be flattened before indexing is applied. + inplace: must be set to False to indicate that the input is not modified + in-place, but rather a modified copy is returned. + mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options, + see :attr:`jax.numpy.ndarray.at`. + + Returns: + A copy of ``a`` with specified entries updated. + + See Also: + - :func:`jax.numpy.put`: put elements into an array at given indices. + - :func:`jax.numpy.place`: place elements into an array via boolean mask. + - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing. + - :func:`jax.numpy.take`: extract values from an array at given indices. + - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis. + + Examples: + >>> from jax import numpy as jnp + >>> a = jnp.array([[10, 30, 20], [60, 40, 50]]) + >>> i = jnp.argmax(a, axis=1, keepdims=True) + >>> print(i) + [[1] + [0]] + >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False) + >>> print(b) + [[10 99 20] + [99 40 50]] + """ + if inplace: + raise ValueError( + "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays" + "are immutable. Pass inplace=False to instead return an updated array.") + + util.check_arraylike("put_along_axis", arr, indices, values) + arr = asarray(arr) + indices = asarray(indices) + values = asarray(values) + + original_axis = axis + original_arr_shape = arr.shape + + if axis is None: + arr = arr.ravel() + axis = 0 + + if not arr.ndim == indices.ndim: + raise ValueError( + "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got " + f"{arr.ndim=} and {indices.ndim=}." + ) + + try: + values = broadcast_to(values, indices.shape) + except ValueError: + raise ValueError( + "put_along_axis argument 'values' must be broadcastable to 'indices'. Got " + f"{values.shape=} and {indices.shape=}." + ) + + idx = _make_along_axis_idx(arr.shape, indices, axis) + result = arr.at[idx].set(values, mode=mode) + + if original_axis is None: + result = result.reshape(original_arr_shape) + + return result + + ### Indexing def _is_integer_index(idx: Any) -> bool: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 78de511d4ec4..e546ebd2a0f3 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -965,6 +965,31 @@ def fn(shape, dtype): size=shape, replace=False) return fn +def rand_indices_unique_along_axis(rng): + """Sample an array of given shape containing indices up to dim (exclusive), + such that the indices are unique along the given axis. + Optionally, convert some of the resulting indices to negative indices.""" + def fn(dim, shape, axis, allow_negative=True): + batch_size = math.prod(shape[:axis] + shape[axis:][1:]) + idx = [ + rng.choice(dim, size=shape[axis], replace=False) + for _ in range(batch_size) + ] + idx = np.array(idx).reshape(batch_size, shape[axis]) + idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],)) + idx = np.moveaxis(idx, -1, axis) + + # assert that indices are unique along the given axis + count = partial(np.bincount, minlength=dim) + assert (np.apply_along_axis(count, axis, idx) <= 1).all() + + if allow_negative: + mask = rng.choice([False, True], idx.shape) + idx[mask] -= dim + return idx + + return fn + def rand_bool(rng): def generator(shape, dtype): return _cast_to_shape( diff --git a/jax/_src/util.py b/jax/_src/util.py index fce342c493ed..8dcc5eaa5804 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -453,6 +453,10 @@ def tuple_update(t, idx, val): assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] +def tuple_replace(tupl, index, item): + # unlike tuple_update, works with negative indices as well + return tupl[:index] + (item,) + tupl[index:][1:] + class HashableFunction: """Decouples function equality and hash from its identity. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9be73e96adcf..2ab0a0e3d1ab 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -202,6 +202,7 @@ printoptions as printoptions, promote_types as promote_types, put as put, + put_along_axis as put_along_axis, ravel as ravel, ravel_multi_index as ravel_multi_index, repeat as repeat, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d391abd46e13..339174136234 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ...) -> Array: ... def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ... +def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, + axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ... def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7c2728af415e..a1817f528f27 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,7 +51,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning +from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace config.parse_flags_with_absl() @@ -5962,6 +5962,45 @@ def np_fun(a, i, v): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis) + for a_shape in nonempty_array_shapes + for axis in list(range(-len(a_shape), len(a_shape))) + for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)] + for v_shape in [(), (1,), i_shape] + ] + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None) + for a_shape in nonempty_array_shapes + for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)] + for v_shape in [(), (1,), i_shape] + ], + dtype=jtu.dtypes.all, + mode=[None, "promise_in_bounds", "clip"], + ) + def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode): + a_rng = jtu.rand_default(self.rng()) + if axis is None: + size = math.prod(a_shape) + else: + size = a_shape[axis] + i_rng = jtu.rand_indices_unique_along_axis(self.rng()) + + def args_maker(): + a = a_rng(a_shape, dtype) + i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis) + v = a_rng(v_shape, dtype) + return a, i, v + + def np_fun(a, i, v): + a_copy = a.copy() + np.put_along_axis(a_copy, i, v, axis=axis) + return a_copy + + jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + def test_rot90_error(self): with self.assertRaisesRegex( ValueError, @@ -6229,7 +6268,6 @@ def testWrappedSignaturesMatch(self): 'nditer', 'nested_iters', 'poly1d', - 'put_along_axis', 'putmask', 'real_if_close', 'recarray', From 4a3e1155b9ae4d3c941f3aacbb12c4cf43cdf059 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 14 Nov 2024 13:07:15 -0800 Subject: [PATCH 007/112] cleanup: delete unused argument from internal reduction helper --- jax/_src/numpy/reductions.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 08f11d0cb6ad..5acad86eabef 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -82,7 +82,7 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: ReductionOp = Callable[[Any, Any], Any] -def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike, +def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, *, has_identity: bool = True, preproc: Callable[[ArrayLike], ArrayLike] | None = None, bool_op: ReductionOp | None = None, @@ -215,7 +215,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: - return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric, + return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric, bool_op=lax.bitwise_or, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.psum, @@ -301,7 +301,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: - return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric, + return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric, bool_op=lax.bitwise_and, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) @@ -386,7 +386,7 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, + return _reduction(a, "max", lax.max, -np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) @@ -468,7 +468,7 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None, def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, + return _reduction(a, "min", lax.min, np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) @@ -548,7 +548,7 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, + return _reduction(a, "all", lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @@ -604,7 +604,7 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, + return _reduction(a, "any", lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @@ -664,7 +664,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: arr = lax_internal.asarray(a) init_val = np.array(-1, dtype=dtype or arr.dtype) - return _reduction(arr, name="reduce_bitwise_and", np_fun=None, op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, + return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -673,7 +673,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_bitwise_or", np_fun=None, op=lax.bitwise_or, init_val=0, preproc=_require_integer, + return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -682,7 +682,7 @@ def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_bitwise_xor", np_fun=None, op=lax.bitwise_xor, init_val=0, preproc=_require_integer, + return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -691,7 +691,7 @@ def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_logical_and", np_fun=None, op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool, + return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -700,7 +700,7 @@ def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_logical_or", np_fun=None, op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool, + return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -709,7 +709,7 @@ def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_logical_xor", np_fun=None, op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool, + return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) From c40d405e439ac782eb47949dd94362aeed29bc00 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 14 Nov 2024 14:03:58 -0800 Subject: [PATCH 008/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ecdba3f23b20e684c5e67a5ddb4f004de724f6df. PiperOrigin-RevId: 696642961 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fdb6b1607816..043b9d019eb1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "2a7890387f812c17fb5f17eec961ee52ac3e059d" -XLA_SHA256 = "cfe1eebc643355f55e6422451cbd750ac6a7f096ed8d6a0605238e4d8ce6d0d1" +XLA_COMMIT = "ecdba3f23b20e684c5e67a5ddb4f004de724f6df" +XLA_SHA256 = "bfb87208d43324cdb20e03c9802360a580062b913e975b1470148dd99dfbb0d1" def repo(): tf_http_archive( From 41a0493e56154d66cda48fac20a2e6c1e7b13a50 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 12 Nov 2024 20:34:18 -0800 Subject: [PATCH 009/112] Add shard map replication rule for ffi_call. --- jax/experimental/shard_map.py | 60 +++++++++++++++++++---------------- tests/extend_test.py | 16 ++++++++++ tests/shard_map_test.py | 10 ++++++ 3 files changed, 59 insertions(+), 27 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 9391d7ddf546..4ad248c17ee2 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -51,6 +51,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) +from jax._src.extend import ffi from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, @@ -1290,30 +1291,38 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): @register_check(control_flow.conditionals.cond_p) def _cond_rule(mesh, *in_rep, branches): _, *args_rep = in_rep - true_out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) - false_out_rep = _check_rep(mesh, branches[1].jaxpr, args_rep) - if not true_out_rep == false_out_rep: - raise Exception("The true and false branches of cond produced mismatched " - f"replication types {true_out_rep} and {false_out_rep}. " - "Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return true_out_rep + out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) + for branch in branches[1:]: + out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) + if not out_rep_ == out_rep: + raise Exception("The branches of cond produced mismatched replication " + "types. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_rep=False argument " + "to shard_map") + return out_rep @register_rewrite(control_flow.conditionals.cond_p) def _cond_rewrite(mesh, in_rep, *args, branches): pred_rep, *args_rep = in_rep - _, true_out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) - _, false_out_rep = _replication_rewrite_nomatch(mesh, branches[1], args_rep) - out_rep = map(op.and_, true_out_rep, false_out_rep) + _, out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) + for branch in branches[1:]: + _, out_rep_ = _replication_rewrite_nomatch(mesh, branch, args_rep) + if out_rep: + out_rep = map(op.and_, out_rep, out_rep_) + else: + out_rep = out_rep_ out_rep = map(partial(op.and_, pred_rep), out_rep) - branches_ = ( - _replication_rewrite_match(mesh, branches[0], args_rep, out_rep), - _replication_rewrite_match(mesh, branches[1], args_rep, out_rep), - ) + branches_ = tuple(_replication_rewrite_match(mesh, branch, args_rep, out_rep) + for branch in branches) out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) return out_vals, out_rep +@register_check(control_flow.conditionals.platform_index_p) +def _platform_index_rule(mesh, *_, **__): + return set(mesh.axis_names) +register_norewrite(control_flow.conditionals.platform_index_p) + @register_rewrite(core.closed_call_p) def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep) @@ -1363,20 +1372,17 @@ def fwd_jaxpr_thunk_(*zeros): def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) - -# TODO(mattjj): make standard_check handle multiple outputs, share code @register_check(control_flow.solves.linear_solve_p) -def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs): - in_rep_ = [r for r in in_rep if r is not None] - assert in_rep - if not in_rep_[:-1] == in_rep_[1:]: - msg = ("shard_map check_rep rewrite failed. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a workaround pass the " - "check_rep=False argument to shard_map") - raise Exception(msg) - return [in_rep_[0]] * len(jaxprs.solve.out_avals) +def _linear_solve_check(mesh, *in_rep, jaxprs, **_): + out_rep = _standard_check(control_flow.solves.linear_solve_p, mesh, *in_rep) + return [out_rep] * len(jaxprs.solve.out_avals) register_standard_rewrite(control_flow.solves.linear_solve_p) +@register_check(ffi.ffi_call_p) +def _ffi_call_check(mesh, *in_rep, result_avals, **_): + out_rep = _standard_check(ffi.ffi_call_p, mesh, *in_rep) + return [out_rep] * len(result_avals) +register_standard_rewrite(ffi.ffi_call_p) del _check_rules[lax.tie_p] diff --git a/tests/extend_test.py b/tests/extend_test.py index b4af8bc23e16..a59c94eab5fc 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -24,6 +24,7 @@ from jax import lax import jax.extend as jex import jax.numpy as jnp +import jax.sharding as shd from jax._src import abstract_arrays from jax._src import api @@ -38,6 +39,7 @@ from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo from jax._src.lax import linalg as lax_linalg_internal +from jax.experimental.shard_map import shard_map jax.config.parse_flags_with_absl() @@ -342,6 +344,20 @@ def testInvalidResultType(self): ValueError, "All elements of result_shape_dtypes.*position 1"): jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))() + @jtu.run_on_devices("gpu", "cpu") + def testShardMap(self): + mesh = jtu.create_mesh((1,), ("i",)) + x = self.rng().randn(8, 4, 5).astype(np.float32) + + @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'), + out_specs=shd.PartitionSpec('i')) + def f(x): + return ffi_call_geqrf(x) + + f(x) # eager mode doesn't crash + jax.jit(f)(x) # neither does JIT + self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text()) + def ffi_call_geqrf(x, **kwargs): if jtu.test_device_matches(["cpu"]): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index df24315ce110..84017bab5122 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1050,6 +1050,16 @@ def f(a): a = jnp.array([True, False]) shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + def test_switch_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + x = jnp.arange(4) + + def f(n, x, y): + return jax.lax.switch( + n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y) + + shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) + def test_eager_custom_jvp_basic(self): @jax.custom_jvp def foo(x): From a115b2cec508787ebf94061daa5d62feefd60cb3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 14 Nov 2024 16:05:30 -0800 Subject: [PATCH 010/112] Update array-api-tests commit --- .github/workflows/jax-array-api.yml | 2 +- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 84dda34752f0..763a4c04be5d 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '4bbe6be32c6995772f8f46a6ef050ba766581104' # Latest commit as of 2024-11-14 + ref: 'a3f3f376308e64f0ac15b307dfe27be945409e41' # Latest commit as of 2024-11-14 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index 6e625e708d7a..73e1c51fc8af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ filterwarnings = [ # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", - "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", "default:The .* method is good for exploring strategies.*", # NOTE: this is probably not where you want to add code to suppress a From 9a0e9e55d81e8ea1b1fd2fa4eaf67074f5908bec Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 14 Nov 2024 17:31:16 -0800 Subject: [PATCH 011/112] [sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in `wrap_with_sharding_op`. Also lower shardings with `Collective` axes correctly to HloSharding. PiperOrigin-RevId: 696703030 --- jax/_src/interpreters/mlir.py | 14 ++++++++++++ jax/_src/lax/lax.py | 41 ++++++++++------------------------- jax/_src/lax/parallel.py | 40 +++++++++++++++++++--------------- jax/_src/mesh.py | 19 ++++++++++++++++ jax/_src/sharding_impls.py | 7 ++++-- tests/pjit_test.py | 26 ++++++++++++++++++++++ 6 files changed, 99 insertions(+), 48 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index bef465c6aa75..ee3c929b26f7 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2474,6 +2474,20 @@ def _wrap_with_spmd_op(name: str, wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape") +def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): + if sharding_proto is None: + proto = aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + else: + proto = sharding_proto + # TODO(yashkatariya): Setting all axes as unspecified should work even when + # any axes is Collective because that's what happens in partial auto shmap. + # Do that after tests for it exists. + unspecified_dims = (set(range(aval.ndim)) + if aval.sharding.mesh.are_all_axes_collective else None) + return wrap_with_sharding_op( + ctx, op, aval, proto, unspecified_dims=unspecified_dims) + + def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c45d8f5c80b2..b780aab870e9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2203,14 +2203,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): for op, in_aval in zip(ops, in_avals): if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: out.append(op) - elif in_aval.sharding.mesh.are_all_axes_collective: - out.append(op) else: - # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains - # CompilerShardingAxis, then specify `unspecified_dims` via - # `wrap_with_sharding_op`. - sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() - out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) + proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto)) return out @@ -2226,10 +2221,7 @@ def _nary_lower_hlo(op: Callable, ctx, out = op(*args) if config.sharding_in_types.value: - if aval_out.sharding.mesh.are_all_axes_collective: - return [out] - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] else: return [out] @@ -2646,8 +2638,7 @@ def _integer_pow_lowering(ctx, x, *, y): out, = lowering(ctx, x, y=y) if config.sharding_in_types.value: aval_out, = ctx.avals_out - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(integer_pow_p, _integer_pow_lowering) @@ -3029,8 +3020,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, if config.sharding_in_types.value: if sharding is not None: assert aval_out.sharding == sharding - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -3765,8 +3755,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): if config.sharding_in_types.value: if out_type is not None: assert aval_out.sharding == out_type - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp) + result = mlir.lower_sharding_under_shit(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) return [result] @@ -4231,8 +4220,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, if config.sharding_in_types.value: if sharding is not None: assert sharding == aval_out.sharding - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, @@ -4645,8 +4633,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) out = mlir.reshape(ctx, x, aval_out) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] def _reshape_staging_rule( @@ -4726,8 +4713,7 @@ def _transpose_lower(ctx, x, *, permutation): permutation = [*permutation, *trailing_dims] out = hlo.transpose(x, mlir.dense_int_array(permutation)) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] transpose_p = standard_primitive( @@ -4868,8 +4854,7 @@ def _select_hlo_lowering_opaque(ctx, which, *cases): def _add_shit_to_select(ctx, op, aval_out): if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return mlir.wrap_with_sharding_op(ctx, op, aval_out, proto) + return mlir.lower_sharding_under_shit(ctx, op, aval_out) return op def _select_hlo_lowering(ctx, which, *cases): @@ -5241,8 +5226,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): with ir.InsertionPoint(reducer_region): hlo.return_([reducer(*reducer_region.arguments)]) if config.sharding_in_types.value: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, op.result, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)] return op.results mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp, @@ -5941,8 +5925,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding): out = mlir.iota(ctx, aval_out, dimension=dimension) if config.sharding_in_types.value: assert aval_out.sharding == sharding - proto = sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(iota_p, _iota_lower) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 3a1c1ef3bcf1..c8cea6a9df5b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,9 +24,11 @@ from jax import tree_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes -from jax._src import sharding_impls +from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, + NamedSharding, PartitionSpec as P) from jax._src.core import AxisName, ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -635,9 +637,15 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") - out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), - arg.dtype) for arg in args] + if config.sharding_in_types.value: + out_avals = [ + ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) + for arg in args + ] + else: + out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype) + for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def _check_axis_names(axes): @@ -673,10 +681,7 @@ def _positional_reduce(aval, arg): _replica_groups(ctx.module_context.axis_env, named_axes, axis_index_groups)) axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) def all_reduce(aval, x): if is_spmd: @@ -694,7 +699,11 @@ def all_reduce(aval, x): else: op = hlo.AllReduceOp( [x.type], [x], replica_groups=replica_groups, **other_args) - scalar_aval = core.ShapedArray((), aval.dtype) + if config.sharding_in_types.value: + scalar_aval = core.ShapedArray( + (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P())) + else: + scalar_aval = core.ShapedArray((), aval.dtype) scalar_type = mlir.aval_to_ir_type(scalar_aval) reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_block): @@ -778,7 +787,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): axis_context = ctx.module_context.axis_context is_manual = ( - isinstance(axis_context, sharding_impls.SPMDAxisContext) + isinstance(axis_context, SPMDAxisContext) and axis_context.manual_axes ) if is_manual: @@ -896,7 +905,7 @@ def _all_to_all_lowering( raise ValueError('Replica groups must be equally sized') is_spmd = isinstance( ctx.module_context.axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -1129,10 +1138,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, x_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) if not tiled: new_shape = list(x_aval.shape) new_shape.insert(all_gather_dimension, 1) @@ -1260,7 +1266,7 @@ def _reduce_scatter_lowering( axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -1489,7 +1495,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: device_id = hlo.partition_id() diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 082c443fade4..6c6017c4b2b7 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -107,6 +107,17 @@ class AxisTypes(enum.Enum): User = enum.auto() Collective = enum.auto() +def axis_names_to_types(axis_types) -> dict[str, AxisTypes]: + if axis_types is None: + return {} + d = {} + for t, names in axis_types.items(): + if isinstance(names, tuple): + for n in names: + d[n] = t + else: + d[names] = t + return d _mesh_object_dict = {} # type: ignore @@ -269,6 +280,10 @@ def shape_tuple(self): def axis_sizes(self) -> tuple[int, ...]: return self.devices.shape + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @property def size(self): return math.prod(self.shape.values()) if self.devices.ndim else 0 @@ -390,6 +405,10 @@ def axis_names(self): def axis_sizes(self) -> tuple[int, ...]: return self._axis_sizes + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @functools.cached_property def size(self): return math.prod(self._axis_sizes) if self._axis_sizes else 0 diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 9b847f15d86a..8957a6186339 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -137,9 +137,12 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - if self._manual_axes: + mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() + if t == mesh_lib.AxisTypes.Collective} + manual_axes = self._manual_axes.union(mesh_manual_axes) + if manual_axes: axis_names = self.mesh.axis_names - for manual_axis in self._manual_axes: + for manual_axis in manual_axes: special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL replicated_mesh_axes = [] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8a63bbe39099..7196a6335960 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5225,6 +5225,32 @@ def f(x, y): self.assertArraysEqual(out, (np_inp * np_inp) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + def test_shard_map_dot(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh.are_all_axes_collective) + self.assertTrue(y.sharding.mesh.are_all_axes_collective) + allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) + z = x @ allgatherd_y + return jax.lax.psum(z, axis_name='y') + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', None))(x, y) + self.assertEqual(z.sharding.spec, P('x', None)) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From f652b6ad6aa44e586ee8989a39ca95b63205cec3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 15 Nov 2024 06:03:54 -0800 Subject: [PATCH 012/112] Set __module__ attribute for objects in jax.numpy --- jax/_src/dtypes.py | 3 + jax/_src/numpy/index_tricks.py | 12 +- jax/_src/numpy/lax_numpy.py | 188 +++++++++++++++++++++++++++++++- jax/_src/numpy/polynomial.py | 14 +++ jax/_src/numpy/reductions.py | 40 ++++++- jax/_src/numpy/setops.py | 14 ++- jax/_src/numpy/ufunc_api.py | 5 +- jax/_src/numpy/ufuncs.py | 114 +++++++++++++++++++ jax/_src/numpy/vectorize.py | 5 +- tests/package_structure_test.py | 11 +- 10 files changed, 396 insertions(+), 10 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f5b0c3fd68b1..1c5e285ba08a 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -343,6 +343,7 @@ def _issubclass(a: Any, b: Any) -> bool: # TODO(jakevdp): consider whether to disallow None here. We allow it # because np.issubdtype allows it (and treats it as equivalent to float64). +@set_module('jax.numpy') def issubdtype(a: DTypeLike | ExtendedDType | None, b: DTypeLike | ExtendedDType | None) -> bool: """Returns True if first argument is a typecode lower/equal in type hierarchy. @@ -458,6 +459,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, } +@set_module('jax.numpy') def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool: """Returns a boolean indicating whether a provided dtype is of a specified kind. @@ -650,6 +652,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy "JAX's internal logic; please report it to the JAX maintainers." ) +@set_module('jax.numpy') def promote_types(a: DTypeLike, b: DTypeLike) -> DType: """Returns the type to which a binary operation should cast its arguments. diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index 90a17000cf16..ec67d7489f30 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -24,10 +24,14 @@ arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose ) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module import numpy as np +export = set_module('jax.numpy') + + __all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] @@ -87,7 +91,7 @@ def __getitem__(self, key: slice | tuple[slice, ...]) -> Array: return stack(output_arr, 0) -mgrid = _Mgrid() +mgrid = export(_Mgrid()) class _Ogrid: @@ -129,7 +133,7 @@ def __getitem__( return meshgrid(*output, indexing='ij', sparse=True) -ogrid = _Ogrid() +ogrid = export(_Ogrid()) _IndexType = Union[ArrayLike, str, slice] @@ -279,7 +283,7 @@ class RClass(_AxisConcat): op_name = "r_" -r_ = RClass() +r_ = export(RClass()) class CClass(_AxisConcat): @@ -327,7 +331,7 @@ class CClass(_AxisConcat): op_name = "c_" -c_ = CClass() +c_ = export(CClass()) s_ = np.s_ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4cf37f6f7d67..4c261d11196a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,13 +68,16 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace) + ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, + tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map import numpy as np import opt_einsum +export = set_module('jax.numpy') + for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: try: cuda_plugin_extension = importlib.import_module( @@ -116,6 +119,7 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: printoptions = np.printoptions set_printoptions = np.set_printoptions +@export def iscomplexobj(x: Any) -> bool: """Check if the input is a complex number or an array containing complex elements. @@ -327,6 +331,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: return clip(val, min_val, max_val).astype(dtype) +@export def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: """Load JAX arrays from npy files. @@ -376,6 +381,7 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> ### implementations of numpy functions in terms of lax +@export @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise minimum of the input arrays. @@ -427,6 +433,7 @@ def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise maximum of the input arrays. @@ -476,6 +483,7 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: """Return True if arg1 is equal or lower than arg2 in the type hierarchy. @@ -522,6 +530,7 @@ def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) +@export def isscalar(element: Any) -> bool: """Return True if the input is a scalar. @@ -620,6 +629,7 @@ def isscalar(element: Any) -> bool: iterable = np.iterable +@export def result_type(*args: Any) -> DType: """Return the result of applying JAX promotion rules to the inputs. @@ -663,6 +673,7 @@ def result_type(*args: Any) -> DType: return dtypes.result_type(*args) +@export @jit def trunc(x: ArrayLike) -> Array: """Round input to the nearest integer towards zero. @@ -739,6 +750,7 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, return result[0, 0, out_order] +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None, @@ -814,6 +826,7 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision=precision, preferred_element_type=preferred_element_type) +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None, @@ -899,6 +912,7 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision=precision, preferred_element_type=preferred_element_type) +@export def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range: None | Array | Sequence[ArrayLike] = None, weights: ArrayLike | None = None) -> Array: @@ -950,6 +964,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, return linspace(range[0], range[1], bins_int + 1, dtype=dtype) +@export def histogram(a: ArrayLike, bins: ArrayLike = 10, range: Sequence[ArrayLike] | None = None, weights: ArrayLike | None = None, @@ -1031,6 +1046,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, return counts, bin_edges +@export def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1120,6 +1136,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = return hist, edges[0], edges[1] +@export def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1229,6 +1246,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, return hist, bin_edges_by_dim +@export def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. @@ -1307,6 +1325,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: return lax.transpose(a, axes_) +@export def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: """Permute the axes/dimensions of an array. @@ -1336,6 +1355,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: return lax.transpose(a, axes) +@export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose the last two dimensions of an array. @@ -1389,6 +1409,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return lax.transpose(x, axes) +@export @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: """Rotate an array by 90 degrees counterclockwise in the plane specified by axes. @@ -1472,6 +1493,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) +@export def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Reverse the order of elements of an array along the given axis. @@ -1539,6 +1561,7 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) +@export def fliplr(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 1. @@ -1565,6 +1588,7 @@ def fliplr(m: ArrayLike) -> Array: return _flip(asarray(m), 1) +@export def flipud(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 0. @@ -1590,6 +1614,8 @@ def flipud(m: ArrayLike) -> Array: util.check_arraylike("flipud", m) return _flip(asarray(m), 0) + +@export @jit def iscomplex(x: ArrayLike) -> Array: """Return boolean array showing where the input is complex. @@ -1613,6 +1639,8 @@ def iscomplex(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) + +@export @jit def isreal(x: ArrayLike) -> Array: """Return boolean array showing where the input is real. @@ -1637,6 +1665,7 @@ def isreal(x: ArrayLike) -> Array: return lax.eq(i, _lax_const(i, 0)) +@export @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: """Return the angle of a complex valued number or array. @@ -1688,6 +1717,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: return ufuncs.degrees(result) if deg else result +@export @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, @@ -1800,6 +1830,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, return arr +@export @jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: @@ -1862,6 +1893,8 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result + +@export @partial(jit, static_argnames=("axis", "edge_order")) def gradient( f: ArrayLike, @@ -1992,6 +2025,7 @@ def gradient_along_axis(a, h, axis): return a_grad[0] if len(axis_tuple) == 1 else a_grad +@export def isrealobj(x: Any) -> bool: """Check if the input is not a complex number or an array containing complex elements. @@ -2026,6 +2060,7 @@ def isrealobj(x: Any) -> bool: return not iscomplexobj(x) +@export def reshape( a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), @@ -2129,6 +2164,7 @@ def reshape( return asarray(a).reshape(shape, order=order) +@export @partial(jit, static_argnames=('order',), inline=True) def ravel(a: ArrayLike, order: str = "C") -> Array: """Flatten array into a 1-dimensional shape. @@ -2182,6 +2218,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: return reshape(a, (size(a),), order) +@export def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = 'raise', order: str = 'C') -> Array: """Convert multi-dimensional indices into flat indices. @@ -2273,6 +2310,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], return result +@export def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: """Convert flat indices into multi-dimensional indices. @@ -2336,6 +2374,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: for s, i in safe_zip(shape, out_indices)) +@export @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: """Return a new array with specified shape. @@ -2387,6 +2426,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: return reshape(arr, new_shape) +@export def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Remove one or more length-1 axes from array @@ -2457,6 +2497,7 @@ def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: return lax.squeeze(a, axis) +@export def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: """Insert dimensions of length 1 into array @@ -2527,6 +2568,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: return lax.expand_dims(a, axis) +@export @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: """Swap two axes of an array. @@ -2574,6 +2616,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: return lax.transpose(a, list(perm)) +@export def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: """Move an array axis to a new position @@ -2639,6 +2682,7 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) - return lax.transpose(a, perm) +@export @partial(jit, static_argnames=('equal_nan',)) def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -2783,6 +2827,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f +@export def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, @@ -2865,6 +2910,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, ) -> Array | tuple[Array, ...]: ... +@export def where(condition, x=None, y=None, /, *, size=None, fill_value=None): """Select elements from two arrays based on a condition. @@ -2940,6 +2986,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): return util._where(condition, x, y) +@export def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], @@ -3007,6 +3054,7 @@ def select( return lax.select_n(*broadcast_arrays(idx, *choicelist)) +@export def bincount(x: ArrayLike, weights: ArrayLike | None = None, minlength: int = 0, *, length: int | None = None ) -> Array: @@ -3099,6 +3147,7 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | core.Tracer] ) -> tuple[int | core.Tracer, ...]: ... +@export def broadcast_shapes(*shapes): """Broadcast input shapes to a common output shape. @@ -3139,6 +3188,7 @@ def broadcast_shapes(*shapes): return lax.broadcast_shapes(*shapes) +@export def broadcast_arrays(*args: ArrayLike) -> list[Array]: """Broadcast arrays to a common shape. @@ -3178,6 +3228,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: return util._broadcast_arrays(*args) +@export def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: """Broadcast an array to a specified shape. @@ -3254,6 +3305,7 @@ def _split(op: str, ary: ArrayLike, for start, end in zip(split_indices[:-1], split_indices[1:])] +@export def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3317,6 +3369,7 @@ def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, return _split("split", ary, indices_or_sections, axis=axis) +@export def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays vertically. @@ -3351,6 +3404,7 @@ def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("vsplit", ary, indices_or_sections, axis=0) +@export def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays horizontally. @@ -3391,6 +3445,7 @@ def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1) +@export def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays depth-wise. @@ -3432,6 +3487,7 @@ def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("dsplit", ary, indices_or_sections, axis=2) +@export def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3457,6 +3513,7 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array return _split("array_split", ary, indices_or_sections, axis=axis) +@export @jit def clip( arr: ArrayLike | None = None, @@ -3528,6 +3585,7 @@ def clip( return asarray(arr) +@export @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Round input evenly to the given number of decimals. @@ -3599,12 +3657,14 @@ def _round_float(x: ArrayLike) -> Array: return _round_float(a) +@export @partial(jit, static_argnames=('decimals',)) def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Alias of :func:`jax.numpy.round`""" return round(a, decimals, out) +@export @jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. @@ -3643,6 +3703,7 @@ def fix(x: ArrayLike, out: None = None) -> Array: return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) +@export @jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: ArrayLike | None = None, @@ -3708,6 +3769,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, return out +@export @partial(jit, static_argnames=('equal_nan',)) def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -3756,6 +3818,7 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, return reductions.all(isclose(a, b, rtol, atol, equal_nan)) +@export def nonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: @@ -3863,6 +3926,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, return out +@export def flatnonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array: """Return indices of nonzero elements in a flattened array @@ -3908,6 +3972,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None, return nonzero(ravel(a), size=size, fill_value=fill_value)[0] +@export @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, axis: int = -1, period: ArrayLike = 2 * pi) -> Array: @@ -4337,6 +4402,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") +@export def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], mode: str | Callable[..., Any] = "constant", **kwargs) -> Array: """Add padding to an array. @@ -4493,6 +4559,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], ### Array-creation functions +@export def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: """Join arrays along a new axis. @@ -4559,6 +4626,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], return concatenate(new_arrays, axis=axis, dtype=dtype) +@export @partial(jit, static_argnames="axis") def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: """Unstack an array along an axis. @@ -4599,6 +4667,8 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: ) return tuple(moveaxis(x, axis, 0)) + +@export def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: """Construct an array by repeating ``A`` along specified dimensions. @@ -4662,6 +4732,7 @@ def _concatenate_array(arr: ArrayLike, axis: int | None, return lax.reshape(arr, shape, dimensions) +@export def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> Array: """Join arrays along an existing axis. @@ -4725,6 +4796,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] +@export def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: """Join arrays along an existing axis. @@ -4765,6 +4837,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: return jax.numpy.concatenate(arrays, axis=axis) +@export def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Vertically stack arrays. @@ -4825,6 +4898,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0, dtype=dtype) +@export def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Horizontally stack arrays. @@ -4885,6 +4959,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) +@export def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Stack arrays depth-wise. @@ -4945,6 +5020,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=2, dtype=dtype) +@export def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: """Stack arrays column-wise. @@ -5005,6 +5081,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: return concatenate(arrs, axis=1) +@export def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: """Construct an array by stacking slices of choice arrays. @@ -5129,6 +5206,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: return asarray(xs), 1 +@export @jit def block(arrays: ArrayLike | list[ArrayLike]) -> Array: """Create an array from a list of blocks. @@ -5212,6 +5290,7 @@ def atleast_1d(x: ArrayLike, /) -> Array: @overload def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 1 dimension. @@ -5266,6 +5345,7 @@ def atleast_2d(x: ArrayLike, /) -> Array: @overload def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 2 dimensions. @@ -5329,6 +5409,7 @@ def atleast_3d(x: ArrayLike, /) -> Array: @overload def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 3 dimensions. @@ -5405,6 +5486,7 @@ def _supports_buffer_protocol(obj): return True +@export def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5597,6 +5679,7 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x +@export def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: @@ -5662,6 +5745,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result +@export def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, *, copy: bool | None = None, device: xc.Device | Sharding | None = None) -> Array: @@ -5743,6 +5827,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) +@export def copy(a: ArrayLike, order: str | None = None) -> Array: """Return a copy of the array. @@ -5791,6 +5876,7 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: return array(a, copy=True, order=order) +@export def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5833,6 +5919,7 @@ def zeros_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5875,6 +5962,7 @@ def ones_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5924,6 +6012,7 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No return device +@export def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5972,6 +6061,7 @@ def full(shape: Any, fill_value: ArrayLike, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -6028,6 +6118,7 @@ def full_like(a: ArrayLike | DuckTypedArray, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def zeros(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of zeros. @@ -6064,6 +6155,7 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def ones(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of ones. @@ -6100,6 +6192,7 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def empty(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an empty array. @@ -6143,6 +6236,7 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore "with a single tuple argument for the shape?") +@export def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: """Check if two arrays are element-wise equal. @@ -6184,6 +6278,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: return reductions.all(eq) +@export def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: """Check if two arrays are element-wise equal. @@ -6224,6 +6319,7 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. +@export def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: r"""Convert a buffer into a 1-D JAX array. @@ -6271,6 +6367,7 @@ def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) +@export def fromfile(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromfile. @@ -6289,6 +6386,7 @@ def fromfile(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def fromiter(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromiter. @@ -6307,6 +6405,7 @@ def fromiter(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None) -> Array: """Construct a JAX array via DLPack. @@ -6367,6 +6466,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, return from_dlpack(x, device=device, copy=copy) +@export def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = float, **kwargs) -> Array: """Create an array from a function applied over indices. @@ -6453,6 +6553,7 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) +@export def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: """Convert a string of text into 1-D JAX array. @@ -6481,6 +6582,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) +@export def eye(N: DimSize, M: DimSize | None = None, k: int | ArrayLike = 0, dtype: DTypeLike | None = None, @@ -6560,6 +6662,7 @@ def _eye(N: DimSize, M: DimSize | None = None, return (i + offset == j).astype(dtype) +@export def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: """Create a square identity matrix @@ -6593,6 +6696,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: return eye(n, dtype=dtype) +@export def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, step: ArrayLike | None = None, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -6760,6 +6864,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@export def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, @@ -6885,6 +6990,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return (result, delta) if retstep else result +@export def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: DTypeLike | None = None, axis: int = 0) -> Array: @@ -6970,6 +7076,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(ufuncs.power(base, lin), dtype) +@export def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: DTypeLike | None = None, axis: int = 0) -> Array: """Generate geometrically-spaced values. @@ -7044,6 +7151,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) +@export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: """Construct N-dimensional grid arrays from N 1-dimensional vectors. @@ -7125,6 +7233,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output +@export @jit def i0(x: ArrayLike) -> Array: r"""Calculate modified Bessel function of first kind, zeroth order. @@ -7174,6 +7283,7 @@ def _i0_jvp(primals, tangents): primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) +@export def ix_(*args: ArrayLike) -> tuple[Array, ...]: """Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. @@ -7237,6 +7347,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: ... +@export def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: """Generate arrays of grid indices. @@ -7287,6 +7398,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, return stack(output, 0) if output else array([], dtype=dtype) +@export def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) -> Array: """Construct an array from repeated elements. @@ -7431,6 +7543,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) +@export @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: @@ -7490,6 +7603,7 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) +@export def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: r"""Return an array with ones on and below the diagonal and zeros elsewhere. @@ -7546,6 +7660,7 @@ def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None return lax_internal._tri(dtype, (N, M), k) +@export @partial(jit, static_argnames=('k',)) def tril(m: ArrayLike, k: int = 0) -> Array: r"""Return lower triangle of an array. @@ -7607,6 +7722,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) +@export @partial(jit, static_argnames=('k',)) def triu(m: ArrayLike, k: int = 0) -> Array: r"""Return upper triangle of an array. @@ -7672,6 +7788,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) +@export @partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -7737,6 +7854,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int return reductions.sum(a, axis=(-2, -1), dtype=dtype) +@export def mask_indices(n: int, mask_func: Callable[[ArrayLike, int], Array], k: int = 0, *, size: int | None = None) -> tuple[Array, Array]: @@ -7796,6 +7914,7 @@ def _triu_size(n, m, k): return mk * (mk + 1) // 2 + mk * (m - k - mk) +@export def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of upper triangle of an array of size ``(n, m)``. @@ -7854,6 +7973,7 @@ def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of lower triangle of an array of size ``(n, m)``. @@ -7912,6 +8032,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. @@ -7969,6 +8090,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. @@ -8026,6 +8148,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array: """Return a copy of the array with the diagonal overwritten. @@ -8107,6 +8230,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n)) +@export def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a multidimensional array. @@ -8142,6 +8266,8 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: .format(ndim)) return (lax.iota(int_, n),) * ndim + +@export def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a given array. @@ -8183,6 +8309,8 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: return diag_indices(s[0], ndim=nd) + +@export @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: @@ -8234,6 +8362,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, return a[..., i, j] if offset >= 0 else a[..., j, i] +@export def diag(v: ArrayLike, k: int = 0) -> Array: """Returns the specified diagonal or constructs a diagonal array. @@ -8297,6 +8426,8 @@ def _diag(v, k): else: raise ValueError("diag input must be 1d or 2d") + +@export def diagflat(v: ArrayLike, k: int = 0) -> Array: """Return a 2-D array with the flattened input array laid out on the diagonal. @@ -8353,6 +8484,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: # TODO(jakevdp): add support for N-dimensional inputs as in NumPy v2.2 +@export def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: """Trim leading and/or trailing zeros of the input array. @@ -8407,6 +8539,8 @@ def trim_zeros_tol(filt, tol, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] + +@export @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None @@ -8461,6 +8595,7 @@ def append( return concatenate([arr, values], axis=axis) +@export def delete( arr: ArrayLike, obj: ArrayLike | slice, @@ -8585,6 +8720,7 @@ def delete( return a[tuple(slice(None) for i in range(axis)) + (mask,)] +@export def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, axis: int | None = None) -> Array: """Insert entries into an array at specified indices. @@ -8684,6 +8820,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, return out +@export def apply_along_axis( func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs ) -> Array: @@ -8761,6 +8898,7 @@ def apply_along_axis( return func(arr) +@export def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, axes: Sequence[int]) -> Array: """Apply a function repeatedly over specified axes. @@ -8819,6 +8957,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, ### Tensor contraction operations +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -8908,6 +9047,7 @@ def dot(a: ArrayLike, b: ArrayLike, *, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -9031,6 +9171,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( a: ArrayLike, b: ArrayLike, *, @@ -9079,6 +9220,7 @@ def vdot( preferred_element_type=preferred_element_type) +@export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -9134,6 +9276,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, signature="(n),(n)->()")(x1_arr, x2_arr) +@export def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = 2, *, precision: PrecisionLike = None, @@ -9279,6 +9422,7 @@ def einsum( out_type=None, ) -> Array: ... +@export def einsum( subscripts, /, *operands, @@ -9554,6 +9698,7 @@ def einsum_path( optimize: bool | str | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... +@export def einsum_path( subscripts, /, *operands, @@ -9787,6 +9932,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -9843,6 +9989,7 @@ def inner( preferred_element_type=preferred_element_type) +@export @partial(jit, inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """Compute the outer product of two arrays. @@ -9877,6 +10024,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: return ravel(a)[:, None] * ravel(b)[None, :] +@export @partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis')) def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int | None = None): @@ -9977,6 +10125,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) +@export @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: """Compute the Kronecker product of two input arrays. @@ -10022,6 +10171,7 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) +@export @partial(jit, static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False @@ -10085,6 +10235,7 @@ def vander( ### Misc +@export def argwhere( a: ArrayLike, *, @@ -10150,6 +10301,7 @@ def argwhere( return result.reshape(result.shape[0], ndim(a)) +@export def argmax(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the maximum value of an array. @@ -10205,6 +10357,7 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the minimum value of an array. @@ -10260,6 +10413,7 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def nanargmax( a: ArrayLike, axis: int | None = None, @@ -10327,6 +10481,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export def nanargmin( a: ArrayLike, axis: int | None = None, @@ -10387,6 +10542,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def sort( a: ArrayLike, @@ -10450,6 +10606,7 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result +@export @jit def sort_complex(a: ArrayLike) -> Array: """Return a sorted copy of complex array. @@ -10487,6 +10644,7 @@ def sort_complex(a: ArrayLike) -> Array: return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) +@export @partial(jit, static_argnames=('axis',)) def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: """Sort a sequence of keys in lexicographic order. @@ -10564,6 +10722,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1] +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def argsort( a: ArrayLike, @@ -10644,6 +10803,7 @@ def argsort( return lax.rev(indices, dimensions=[dimension]) if descending else indices +@export @partial(jit, static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns a partially-sorted copy of an array. @@ -10714,6 +10874,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) +@export @partial(jit, static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns indices that partially sort an array. @@ -10818,6 +10979,8 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: dimension=ax) return a + +@export def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], axis: int | Sequence[int] | None = None) -> Array: """Roll the elements of an array along a specified axis. @@ -10871,6 +11034,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], return _roll_static(arr, shift, axis) +@export @partial(jit, static_argnames=('axis', 'start')) def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: """Roll the specified axis to a given position. @@ -10936,6 +11100,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: return moveaxis(a, axis, start) +@export @partial(jit, static_argnames=('axis', 'bitorder')) def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array: """Pack array of bits into a uint8 array. @@ -11020,6 +11185,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar return swapaxes(packed, axis, -1) +@export @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, @@ -11111,6 +11277,7 @@ def unpackbits( return swapaxes(unpacked, axis, -1) +@export def take( a: ArrayLike, indices: ArrayLike, @@ -11268,6 +11435,7 @@ def _normalize_index(index, axis_size): return lax.select(index < 0, lax.add(index, axis_size_val), index) +@export @partial(jit, static_argnames=('axis', 'mode', 'fill_value')) def take_along_axis( arr: ArrayLike, @@ -11462,6 +11630,7 @@ def _make_along_axis_idx(shape, indices, axis): return tuple_replace(_indices(shape, sparse=True), axis, indices) +@export @partial(jit, static_argnames=('axis', 'inplace', 'mode')) def put_along_axis( arr: ArrayLike, @@ -12206,6 +12375,7 @@ def clamp_index(i: DimSize, which: str): return start, step, slice_size +@export def blackman(M: int) -> Array: """Return a Blackman window of size M. @@ -12236,6 +12406,7 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) +@export def bartlett(M: int) -> Array: """Return a Bartlett window of size M. @@ -12266,6 +12437,7 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) +@export def hamming(M: int) -> Array: """Return a Hamming window of size M. @@ -12296,6 +12468,7 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) +@export def hanning(M: int) -> Array: """Return a Hanning window of size M. @@ -12326,6 +12499,7 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) +@export def kaiser(M: int, beta: ArrayLike) -> Array: """Return a Kaiser window of size M. @@ -12368,6 +12542,8 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) + +@export @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the greatest common divisor of two arrays. @@ -12414,6 +12590,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd +@export @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the least common multiple of two arrays. @@ -12461,6 +12638,7 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) +@export def extract(condition: ArrayLike, arr: ArrayLike, *, size: int | None = None, fill_value: ArrayLike = 0) -> Array: """Return the elements of an array that satisfy a condition. @@ -12522,6 +12700,7 @@ def extract(condition: ArrayLike, arr: ArrayLike, return compress(ravel(condition), ravel(arr), size=size, fill_value=fill_value) +@export def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, *, size: int | None = None, fill_value: ArrayLike = 0, out: None = None) -> Array: """Compress an array along a given axis using a boolean condition. @@ -12616,6 +12795,7 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, return moveaxis(result, 0, axis) +@export @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, bias: bool = False, ddof: int | None = None, @@ -12774,6 +12954,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze() +@export @partial(jit, static_argnames=('rowvar',)) def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array: r"""Compute the Pearson correlation coefficients. @@ -12903,6 +13084,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) +@export @partial(jit, static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: @@ -12992,6 +13174,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', return impl(asarray(a), asarray(v), side, dtype) # type: ignore +@export @partial(jit, static_argnames=('right', 'method')) def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str | None = None) -> Array: @@ -13047,6 +13230,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, ) +@export def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], funclist: list[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: @@ -13154,6 +13338,7 @@ def _tile_to_size(arr: Array, size: int) -> Array: return arr[:size] if arr.size > size else arr +@export def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: bool = True) -> Array: """Update array elements based on a mask. @@ -13229,6 +13414,7 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape) +@export def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = None, *, inplace: bool = True) -> Array: """Put elements into an array at given indices. diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 10cc90575cef..19388b903e5d 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -33,6 +33,10 @@ from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, _where) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module + + +export = set_module('jax.numpy') @jit @@ -57,6 +61,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) +@export def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: r"""Returns the roots of a polynomial given the coefficients ``p``. @@ -116,6 +121,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: return _roots_with_zeros(p_arr, num_leading_zeros) +@export @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, full: bool = False, w: ArrayLike | None = None, cov: bool = False @@ -287,6 +293,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, return c +@export @jit def poly(seq_of_zeros: ArrayLike) -> Array: r"""Returns the coefficients of a polynomial for the given sequence of roots. @@ -369,6 +376,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: return a +@export @partial(jit, static_argnames=['unroll']) def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: r"""Evaluates the polynomial at specific values. @@ -432,6 +440,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: return y +@export @jit def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the sum of the two polynomials. @@ -489,6 +498,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) +@export @partial(jit, static_argnames=('m',)) def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: r"""Returns the coefficients of the integration of specified order of a polynomial. @@ -557,6 +567,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array return true_divide(concatenate((p_arr, k_arr)), coeff) +@export @partial(jit, static_argnames=('m',)) def polyder(p: ArrayLike, m: int = 1) -> Array: r"""Returns the coefficients of the derivative of specified order of a polynomial. @@ -607,6 +618,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array: return p_arr[:-m] * coeff[::-1] +@export def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: r"""Returns the product of two polynomials. @@ -673,6 +685,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - return convolve(a1_arr, a2_arr, mode='full') +@export def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: r"""Returns the quotient and remainder of polynomial division. @@ -732,6 +745,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> return q, u_arr +@export @jit def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the difference of two polynomials. diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 5acad86eabef..bc85bc3e8761 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -37,9 +37,11 @@ from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( canonicalize_axis as _canonicalize_axis, maybe_named_axis, - NumpyComplexWarning) + set_module, NumpyComplexWarning) +export = set_module('jax.numpy') + _all = builtins.all _lax_const = lax_internal._const @@ -222,6 +224,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) +@export def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: @@ -296,6 +299,7 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) + @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, @@ -307,6 +311,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None initial=initial, where_=where, promote_integers=promote_integers) +@export def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, @@ -391,6 +396,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where_=where, parallel_reduce=lax.pmax) +@export def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -473,6 +479,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where_=where, parallel_reduce=lax.pmin) +@export def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -552,6 +559,7 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether all array elements along a given axis evaluate to True. @@ -608,6 +616,7 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether any of the array elements along a given axis evaluate to True. @@ -714,6 +723,7 @@ def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) +@export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -721,6 +731,7 @@ def amin(a: ArrayLike, axis: Axis = None, out: None = None, return min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) +@export def amax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -740,6 +751,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): return size +@export def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: @@ -843,6 +855,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, * @overload def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ... +@export def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: """Compute the weighed average. @@ -953,6 +966,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, return avg +@export def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -1093,6 +1107,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy return _upcast_f16(computation_dtype), np.dtype(dtype) +@export def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -1185,6 +1200,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) +@export def ptp(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: r"""Return the peak-to-peak range along a given axis. @@ -1236,6 +1252,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, return lax.sub(x, y) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def count_nonzero(a: ArrayLike, axis: Axis = None, keepdims: bool = False) -> Array: @@ -1295,6 +1312,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], return out +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1377,6 +1395,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1459,6 +1478,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1542,6 +1562,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1625,6 +1646,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, where: ArrayLike | None = None) -> Array: @@ -1716,6 +1738,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out return td +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1818,6 +1841,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: return lax.convert_element_type(result, dtype) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1939,6 +1963,7 @@ def _cumulative_reduction( return result +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -1975,6 +2000,7 @@ def cumsum(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2010,6 +2036,7 @@ def cumprod(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2059,6 +2086,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None, fill_nan=True, fill_value=0) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2115,6 +2143,7 @@ def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None, a, axis, dtype, out, promote_integers=True) +@export def cumulative_sum( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, @@ -2176,6 +2205,7 @@ def cumulative_sum( return out +@export def cumulative_prod( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, @@ -2239,6 +2269,7 @@ def cumulative_prod( # Quantiles # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2295,6 +2326,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2475,7 +2507,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = result.reshape(keepdim) return lax.convert_element_type(result, a.dtype) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2531,7 +2565,9 @@ def percentile(a: ArrayLike, q: ArrayLike, return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2591,6 +2627,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, method=method, keepdims=keepdims) +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, @@ -2642,6 +2679,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, keepdims=keepdims, method='midpoint') +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 6491a7617d8d..0d5ea905becc 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -35,10 +35,12 @@ from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan from jax._src.numpy.util import check_arraylike, promote_dtypes -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module from jax._src.typing import Array, ArrayLike +export = set_module('jax.numpy') + _lax_const = lax_internal._const @@ -88,6 +90,7 @@ def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: return arr, num_unique1 + num_unique2 +@export def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set difference of two 1D arrays. @@ -175,6 +178,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value) +@export def union1d(ar1: ArrayLike, ar2: ArrayLike, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set union of two 1D arrays. @@ -278,6 +282,7 @@ def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, return where(arange(len(vals)) < num_results, vals, fill_value) +@export def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set-wise xor of elements in two arrays. @@ -417,6 +422,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as return vals +@export def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return_indices: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]: @@ -524,6 +530,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return int1d +@export def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: bool = False, invert: bool = False, *, method='auto') -> Array: @@ -652,6 +659,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo return ret[0] if len(ret) == 1 else ret +@export def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False, return_counts: bool = False, axis: int | None = None, *, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None): @@ -863,6 +871,7 @@ class _UniqueInverseResult(NamedTuple): inverse_indices: Array +@export def unique_all(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueAllResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -945,6 +954,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) +@export def unique_counts(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueCountsResult: """Return unique values from x, along with counts. @@ -1005,6 +1015,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, return _UniqueCountsResult(values=values, counts=counts) +@export def unique_inverse(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueInverseResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -1070,6 +1081,7 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None, return _UniqueInverseResult(values=values, inverse_indices=inverse_indices) +@export def unique_values(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Return unique values from x, along with indices, inverse indices, and counts. diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 27e2973b212b..5dbd67e62a9f 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -33,6 +33,8 @@ import numpy as np +export = set_module("jax.numpy") + _AT_INPLACE_WARNING = """\ Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. @@ -40,7 +42,7 @@ """ -@set_module('jax.numpy') +@export class ufunc: """Universal functions which operation element-by-element on arrays. @@ -586,6 +588,7 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: return result.reshape(*np.shape(A), *np.shape(B)) +@export def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, *, identity: Any = None) -> ufunc: """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index a844ecbc28ac..bbbce9733aa5 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -38,6 +38,10 @@ promote_shapes, _where, check_no_float0s) from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy import reductions +from jax._src.util import set_module + + +export = set_module('jax.numpy') _lax_const = lax._const @@ -75,6 +79,7 @@ def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc: return decorator +@export @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. @@ -119,18 +124,21 @@ def fabs(x: ArrayLike, /) -> Array: return lax.abs(*promote_args_inexact('fabs', x)) +@export @partial(jit, inline=True) def bitwise_invert(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_invert', x)) +@export @partial(jit, inline=True) def bitwise_not(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_not', x)) +@export @partial(jit, inline=True) def invert(x: ArrayLike, /) -> Array: """Compute the bitwise inversion of an input. @@ -223,6 +231,7 @@ def negative(x: ArrayLike, /) -> Array: return lax.neg(*promote_args('negative', x)) +@export @partial(jit, inline=True) def positive(x: ArrayLike, /) -> Array: """Return element-wise positive values of the input. @@ -271,6 +280,7 @@ def positive(x: ArrayLike, /) -> Array: return lax.asarray(*promote_args('positive', x)) +@export @partial(jit, inline=True) def sign(x: ArrayLike, /) -> Array: r"""Return an element-wise indication of sign of the input. @@ -321,6 +331,7 @@ def sign(x: ArrayLike, /) -> Array: return lax.sign(*promote_args('sign', x)) +@export @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. @@ -359,6 +370,7 @@ def floor(x: ArrayLike, /) -> Array: return lax.floor(*promote_args_inexact('floor', x)) +@export @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: """Round input to the nearest integer upwards. @@ -397,6 +409,7 @@ def ceil(x: ArrayLike, /) -> Array: return lax.ceil(*promote_args_inexact('ceil', x)) +@export @partial(jit, inline=True) def exp(x: ArrayLike, /) -> Array: """Calculate element-wise exponential of the input. @@ -438,6 +451,7 @@ def exp(x: ArrayLike, /) -> Array: return lax.exp(*promote_args_inexact('exp', x)) +@export @partial(jit, inline=True) def log(x: ArrayLike, /) -> Array: """Calculate element-wise natural logarithm of the input. @@ -475,6 +489,7 @@ def log(x: ArrayLike, /) -> Array: return lax.log(*promote_args_inexact('log', x)) +@export @partial(jit, inline=True) def expm1(x: ArrayLike, /) -> Array: """Calculate ``exp(x)-1`` of each element of the input. @@ -519,6 +534,7 @@ def expm1(x: ArrayLike, /) -> Array: return lax.expm1(*promote_args_inexact('expm1', x)) +@export @partial(jit, inline=True) def log1p(x: ArrayLike, /) -> Array: """Calculates element-wise logarithm of one plus input, ``log(x+1)``. @@ -559,6 +575,7 @@ def log1p(x: ArrayLike, /) -> Array: return lax.log1p(*promote_args_inexact('log1p', x)) +@export @partial(jit, inline=True) def sin(x: ArrayLike, /) -> Array: """Compute a trigonometric sine of each element of input. @@ -590,6 +607,7 @@ def sin(x: ArrayLike, /) -> Array: return lax.sin(*promote_args_inexact('sin', x)) +@export @partial(jit, inline=True) def cos(x: ArrayLike, /) -> Array: """Compute a trigonometric cosine of each element of input. @@ -620,6 +638,7 @@ def cos(x: ArrayLike, /) -> Array: return lax.cos(*promote_args_inexact('cos', x)) +@export @partial(jit, inline=True) def tan(x: ArrayLike, /) -> Array: """Compute a trigonometric tangent of each element of input. @@ -650,6 +669,7 @@ def tan(x: ArrayLike, /) -> Array: return lax.tan(*promote_args_inexact('tan', x)) +@export @partial(jit, inline=True) def arcsin(x: ArrayLike, /) -> Array: r"""Compute element-wise inverse of trigonometric sine of input. @@ -691,6 +711,7 @@ def arcsin(x: ArrayLike, /) -> Array: return lax.asin(*promote_args_inexact('arcsin', x)) +@export @partial(jit, inline=True) def arccos(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric cosine of input. @@ -733,6 +754,7 @@ def arccos(x: ArrayLike, /) -> Array: return lax.acos(*promote_args_inexact('arccos', x)) +@export @partial(jit, inline=True) def arctan(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric tangent of input. @@ -773,6 +795,7 @@ def arctan(x: ArrayLike, /) -> Array: return lax.atan(*promote_args_inexact('arctan', x)) +@export @partial(jit, inline=True) def sinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic sine of input. @@ -827,6 +850,7 @@ def sinh(x: ArrayLike, /) -> Array: return lax.sinh(*promote_args_inexact('sinh', x)) +@export @partial(jit, inline=True) def cosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic cosine of input. @@ -880,6 +904,7 @@ def cosh(x: ArrayLike, /) -> Array: return lax.cosh(*promote_args_inexact('cosh', x)) +@export @partial(jit, inline=True) def arcsinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic sine of input. @@ -929,6 +954,7 @@ def arcsinh(x: ArrayLike, /) -> Array: return lax.asinh(*promote_args_inexact('arcsinh', x)) +@export @jit def arccosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic cosine of input. @@ -984,6 +1010,7 @@ def arccosh(x: ArrayLike, /) -> Array: return result +@export @partial(jit, inline=True) def tanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic tangent of input. @@ -1037,6 +1064,7 @@ def tanh(x: ArrayLike, /) -> Array: return lax.tanh(*promote_args_inexact('tanh', x)) +@export @partial(jit, inline=True) def arctanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic tangent of input. @@ -1085,6 +1113,7 @@ def arctanh(x: ArrayLike, /) -> Array: return lax.atanh(*promote_args_inexact('arctanh', x)) +@export @partial(jit, inline=True) def sqrt(x: ArrayLike, /) -> Array: """Calculates element-wise non-negative square root of the input array. @@ -1117,6 +1146,7 @@ def sqrt(x: ArrayLike, /) -> Array: return lax.sqrt(*promote_args_inexact('sqrt', x)) +@export @partial(jit, inline=True) def cbrt(x: ArrayLike, /) -> Array: """Calculates element-wise cube root of the input array. @@ -1144,6 +1174,7 @@ def cbrt(x: ArrayLike, /) -> Array: """ return lax.cbrt(*promote_args_inexact('cbrt', x)) + def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.add.at.""" if a.dtype == bool: @@ -1152,6 +1183,7 @@ def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: return a.at[indices].add(b).astype(bool) return a.at[indices].add(b) + @binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) def add(x: ArrayLike, y: ArrayLike, /) -> Array: """Add two arrays element-wise. @@ -1182,6 +1214,7 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.multiply.at.""" if a.dtype == bool: @@ -1191,6 +1224,7 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: else: return a.at[indices].mul(b) + @binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: """Multiply two arrays element-wise. @@ -1221,6 +1255,7 @@ def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) + @binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and) def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise AND operation elementwise. @@ -1250,6 +1285,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) + @binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or) def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise OR operation elementwise. @@ -1279,6 +1315,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) + @binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor) def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise XOR operation elementwise. @@ -1309,6 +1346,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) +@export @partial(jit, inline=True) def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise. @@ -1364,12 +1402,14 @@ def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.shift_left(*promote_args_numeric("left_shift", x, y)) +@export @partial(jit, inline=True) def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.left_shift`.""" return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) +@export @partial(jit, inline=True) def equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x == y``. @@ -1419,6 +1459,7 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.eq(*promote_args("equal", x, y)) +@export @partial(jit, inline=True) def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x != y``. @@ -1472,6 +1513,7 @@ def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.subtract.at.""" return a.at[indices].subtract(b) + @binary_ufunc(identity=None, at=_subtract_at) def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: """Subtract two arrays element-wise. @@ -1502,6 +1544,7 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.sub(*promote_args("subtract", x, y)) +@export @partial(jit, inline=True) def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the arctangent of x1/x2, choosing the correct quadrant. @@ -1557,6 +1600,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) +@export @partial(jit, inline=True) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise minimum of the input arrays. @@ -1617,6 +1661,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.min(*promote_args("minimum", x, y)) +@export @partial(jit, inline=True) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise maximum of the input arrays. @@ -1676,6 +1721,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.max(*promote_args("maximum", x, y)) +@export @partial(jit, inline=True) def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: """Calculate element-wise base ``x`` exponential of ``y``. @@ -1722,6 +1768,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.pow(*promote_args_inexact("float_power", x, y)) +@export @partial(jit, inline=True) def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise next floating point value after ``x`` towards ``y``. @@ -1749,6 +1796,7 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.nextafter(*promote_args_inexact("nextafter", x, y)) +@export @partial(jit, inline=True) def spacing(x: ArrayLike, /) -> Array: """Return the spacing between ``x`` and the next adjacent number. @@ -1856,6 +1904,7 @@ def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) +@export @partial(jit, inline=True) def logical_not(x: ArrayLike, /) -> Array: """Compute NOT bool(x) element-wise. @@ -1901,6 +1950,8 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], lax_op(x.real, y.real)) return lax_op(x, y) + +@export @partial(jit, inline=True) def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x >= y``. @@ -1946,6 +1997,7 @@ def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) +@export @partial(jit, inline=True) def greater(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x > y``. @@ -1992,6 +2044,7 @@ def greater(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.gt, *promote_args("greater", x, y)) +@export @partial(jit, inline=True) def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x <= y``. @@ -2038,6 +2091,7 @@ def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) +@export @partial(jit, inline=True) def less(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x < y``. @@ -2083,42 +2137,58 @@ def less(x: ArrayLike, y: ArrayLike, /) -> Array: """ return _complex_comparison(lax.lt, *promote_args("less", x, y)) + # Array API aliases +@export @partial(jit, inline=True) def acos(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccos`""" return arccos(*promote_args('acos', x)) + +@export @partial(jit, inline=True) def acosh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccosh`""" return arccosh(*promote_args('acosh', x)) + +@export @partial(jit, inline=True) def asin(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsin`""" return arcsin(*promote_args('asin', x)) + +@export @partial(jit, inline=True) def asinh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsinh`""" return arcsinh(*promote_args('asinh', x)) + +@export @partial(jit, inline=True) def atan(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan`""" return arctan(*promote_args('atan', x)) + +@export @partial(jit, inline=True) def atanh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctanh`""" return arctanh(*promote_args('atanh', x)) + +@export @partial(jit, inline=True) def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan2`""" return arctan2(*promote_args('atan2', x1, x2)) + +@export @jit def bitwise_count(x: ArrayLike, /) -> Array: r"""Counts the number of 1 bits in the binary representation of the absolute value @@ -2154,6 +2224,8 @@ def bitwise_count(x: ArrayLike, /) -> Array: # Following numpy we take the absolute value and return uint8. return lax.population_count(abs(x)).astype('uint8') + +@export @partial(jit, inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. @@ -2205,12 +2277,14 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_fn(x1, x2) +@export @partial(jit, inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.right_shift`.""" return right_shift(x1, x2) +@export @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. @@ -2246,12 +2320,14 @@ def absolute(x: ArrayLike, /) -> Array: return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) +@export @partial(jit, inline=True) def abs(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.absolute`.""" return absolute(x) +@export @jit def rint(x: ArrayLike, /) -> Array: """Rounds the elements of x to the nearest integer @@ -2291,6 +2367,7 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) +@export @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. @@ -2330,6 +2407,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) +@export @partial(jit, inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the division of x1 by x2 element-wise @@ -2368,11 +2446,13 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.div(x1, x2) +@export def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.true_divide`.""" return true_divide(x1, x2) +@export @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the floor division of x1 by x2 element-wise @@ -2427,6 +2507,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _float_divmod(x1, x2)[0] +@export @jit def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: """Calculates the integer quotient and remainder of x1 by x2 element-wise @@ -2481,6 +2562,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod +@export def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise base ``x1`` exponential of ``x2``. @@ -2565,6 +2647,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) +@export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.power`""" return power(x1, x2) @@ -2604,6 +2687,7 @@ def _pow_int_int(x1, x2): return acc +@export @jit def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. @@ -2630,6 +2714,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) +@export @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. @@ -2662,6 +2747,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return logaddexp(x1 * ln2, x2 * ln2) / ln2 +@export @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of ``x`` element-wise. @@ -2684,6 +2770,7 @@ def log2(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) +@export @partial(jit, inline=True) def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise @@ -2707,6 +2794,7 @@ def log10(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) +@export @partial(jit, inline=True) def exp2(x: ArrayLike, /) -> Array: """Calculate element-wise base-2 exponential of input. @@ -2741,6 +2829,7 @@ def exp2(x: ArrayLike, /) -> Array: return lax.exp2(x) +@export @jit def signbit(x: ArrayLike, /) -> Array: """Return the sign bit of array elements. @@ -2813,6 +2902,7 @@ def _normalize_float(x): return lax.bitcast_convert_type(x1, int_type), x2 +@export @jit def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute x1 * 2 ** x2 @@ -2862,6 +2952,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(isinf(x1) | (x1 == 0), x1, x) +@export @jit def frexp(x: ArrayLike, /) -> tuple[Array, Array]: """Split floating point values into mantissa and twos exponent. @@ -2915,6 +3006,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) +@export @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Returns element-wise remainder of the division. @@ -2962,11 +3054,13 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) +@export def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.remainder`""" return remainder(x1, x2) +@export @jit def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise floating-point modulo operation. @@ -3008,6 +3102,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.rem(*promote_args_numeric("fmod", x1, x2)) +@export @partial(jit, inline=True) def square(x: ArrayLike, /) -> Array: """Calculate element-wise square of the input array. @@ -3057,6 +3152,7 @@ def square(x: ArrayLike, /) -> Array: return lax.square(x) +@export @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: r"""Convert angles from degrees to radians. @@ -3091,6 +3187,7 @@ def deg2rad(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, np.pi / 180)) +@export @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: r"""Convert angles from radians to degrees. @@ -3126,15 +3223,19 @@ def rad2deg(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, 180 / np.pi)) +@export def degrees(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.rad2deg`""" return rad2deg(x) + +@export def radians(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.deg2rad`""" return deg2rad(x) +@export @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: """Return element-wise complex-conjugate of the input. @@ -3164,11 +3265,13 @@ def conjugate(x: ArrayLike, /) -> Array: return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) +@export def conj(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.conjugate`""" return conjugate(x) +@export @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: """Return element-wise imaginary of part of the complex argument. @@ -3200,6 +3303,7 @@ def imag(val: ArrayLike, /) -> Array: return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) +@export @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: """Return element-wise real part of the complex argument. @@ -3231,6 +3335,7 @@ def real(val: ArrayLike, /) -> Array: return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) +@export @jit def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: """Return element-wise fractional and integral parts of the input array. @@ -3264,6 +3369,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: return x - whole, whole +@export @partial(jit, inline=True) def isfinite(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is finite. @@ -3304,6 +3410,7 @@ def isfinite(x: ArrayLike, /) -> Array: return lax.full_like(x, True, dtype=np.bool_) +@export @jit def isinf(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is infinite. @@ -3359,6 +3466,7 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) +@export def isposinf(x, /, out=None): """ Return boolean array indicating whether each element of input is positive infinite. @@ -3392,6 +3500,7 @@ def isposinf(x, /, out=None): return _isposneginf(np.inf, x, out) +@export def isneginf(x, /, out=None): """ Return boolean array indicating whether each element of input is negative infinite. @@ -3425,6 +3534,7 @@ def isneginf(x, /, out=None): return _isposneginf(-np.inf, x, out) +@export @partial(jit, inline=True) def isnan(x: ArrayLike, /) -> Array: """Returns a boolean array indicating whether each element of input is ``NaN``. @@ -3459,6 +3569,7 @@ def isnan(x: ArrayLike, /) -> Array: return lax.ne(x, x) +@export @jit def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the heaviside step function. @@ -3508,6 +3619,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) +@export @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: r""" @@ -3556,6 +3668,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(idx_inf, _lax_const(x, np.inf), x) +@export @partial(jit, inline=True) def reciprocal(x: ArrayLike, /) -> Array: """Calculate element-wise reciprocal of the input. @@ -3589,6 +3702,7 @@ def reciprocal(x: ArrayLike, /) -> Array: return lax.integer_pow(x, -1) +@export @jit def sinc(x: ArrayLike, /) -> Array: r"""Calculate the normalized sinc function. diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index e7a0e2142327..f1e6d399b97b 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -23,9 +23,11 @@ from jax._src import config from jax import lax from jax._src.numpy import lax_numpy as jnp -from jax._src.util import safe_map as map, safe_zip as zip +from jax._src.util import set_module, safe_map as map, safe_zip as zip +export = set_module('jax.numpy') + # See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html _DIMENSION_NAME = r'\w+' _CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME) @@ -185,6 +187,7 @@ def new_func(*args, **kwargs): return new_func, dynamic_args, dynamic_kwargs +@export def vectorize(pyfunc, *, excluded=frozenset(), signature=None): """Define a vectorized function with broadcasting. diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 71d48c2b121c..9bc8d0f6d71c 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -32,6 +32,14 @@ class PackageStructureTest(jtu.JaxTestCase): @parameterized.parameters([ # TODO(jakevdp): expand test to other public modules. _mod("jax.errors", exclude=["JaxRuntimeError"]), + _mod( + "jax.numpy", + exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating", + "dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo", + "flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim", + "number", "object_", "printoptions", "save", "savez", "set_printoptions", + "shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"] + ), _mod("jax.nn.initializers"), _mod( "jax.tree_util", @@ -46,7 +54,8 @@ def test_exported_names_match_module(self, module_name, include, exclude): if name not in include and (name.startswith('_') or name in exclude): continue obj = getattr(module, name) - if isinstance(obj, types.ModuleType): + if obj is None or isinstance(obj, (bool, int, float, complex, types.ModuleType)): + # No __module__ attribute expected. continue self.assertEqual(obj.__module__, module_name, f"{obj} has {obj.__module__=}, expected {module_name}") From 1471702adc286bcf40e87c42877d538b4d589f90 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 15 Nov 2024 06:41:14 -0800 Subject: [PATCH 013/112] [Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to treat 1D (N,) as (1, N) and then tile it as (1, 128) PiperOrigin-RevId: 696870258 --- jaxlib/mosaic/dialect/tpu/layout.h | 33 +++++++---- .../tpu/transforms/apply_vector_layout.cc | 56 ++++++++++++------ .../tpu/transforms/infer_vector_layout.cc | 59 +++++++++++-------- 3 files changed, 95 insertions(+), 53 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 66217858fa7d..6edad713b17a 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" @@ -39,6 +38,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "absl/log/check.h" namespace mlir::tpu { @@ -259,18 +259,23 @@ class VectorLayout { int layout_rank() const { return layout_rank(implicit_dim_); } bool operator==(const VectorLayout &other) const; - bool operator!=(const VectorLayout &other) const { - return !(*this == other); - } - - // How many tiles fit in each vector register. - int64_t tilesPerVreg(const std::array target_shape) const { - const int64_t tile_elems = tiling_[0] * tiling_[1]; - const int64_t vreg_capacity = packing() * target_shape[0] * target_shape[1]; + bool operator!=(const VectorLayout &other) const { return !(*this == other); } + + static int64_t tilesPerVreg(const std::array target_shape, + const int8_t bitwidth, + const std::array tiling) { + CHECK_NE(0, bitwidth) << "bitwidth cannot be 0"; + const int64_t tile_elems = tiling[0] * tiling[1]; + const int64_t vreg_capacity = + (32 / bitwidth) * target_shape[0] * target_shape[1]; const auto [tiles_per_vreg, rem] = std::div(vreg_capacity, tile_elems); CHECK_EQ(rem, 0); return tiles_per_vreg; } + // How many tiles fit in each vector register. + int64_t tilesPerVreg(const std::array target_shape) const { + return VectorLayout::tilesPerVreg(target_shape, bitwidth_, tiling_); + } int64_t sublanesPerTile(const std::array target_shape) const { auto [sublanes_per_tile, rem] = @@ -283,8 +288,16 @@ class VectorLayout { // // We never reuse the same vector register to store data of multiple rows, // so only the minormost dimension can increase. + static std::array vregSlice(std::array target_shape, + const int8_t bitwidth, + const std::array tiling) { + return { + tiling[0], + VectorLayout::tilesPerVreg(target_shape, bitwidth, tiling) * tiling[1]}; + } + std::array vregSlice(std::array target_shape) const { - return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]}; + return VectorLayout::vregSlice(target_shape, bitwidth_, tiling_); } template diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8792503f4636..2732b63d7638 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2554,7 +2554,10 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(res_layout.has_value()); auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank(); - if (dimension >= num_untiled_dims) { + if (res_ty.getRank() == 1 && + res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) { + tiling_dim = 1; + } else if (dimension >= num_untiled_dims) { tiling_dim = dimension - num_untiled_dims; } @@ -2576,6 +2579,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: result/input offsets mismatch."); } + if (layout.implicit_dim() != res_layout->implicit_dim()) { + return op.emitOpError( + "Not implemented: result/input implicit dim mismatch."); + } + if (i > 1) { auto curr_offsets = layout.offsets(); auto last_operand_offsets = layouts_in[i - 1]->offsets(); @@ -2611,29 +2619,47 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, if (!tiling_dim.has_value()) { out_vregs = concatenate(operand_vregs, dimension); } else { - if (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) { + bool is_rank1_with_no_implicit_dim = res_ty.getRank() == 1 && + res_layout->implicit_dim() == + VectorLayout::ImplicitDim::kNone; + if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kMinor || + is_rank1_with_no_implicit_dim) { return op.emitOpError("Not implemented: implicit dim"); } + if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && + res_layout->bitwidth() != 32) { + return op.emitOpError( + "Not implemented: only 32-bit bitwidth supported for SecondMinor " + "implicit dim"); + } if (res_layout->offsets()[tiling_dim.value()] != 0) { return op.emitOpError("Not implemented: result non-zero offset."); } - if (!res_layout->hasNativeTiling(ctx.target_shape)) { + if (!res_layout->hasNativeTiling(ctx.target_shape) && + res_ty.getRank() != 1) { return op.emitOpError("Not implemented: Non native tiling in concat."); } int64_t offset_at_dim = 0; { for (int i = 0; i < op.getNumOperands(); ++i) { - auto operand = op.getOperand(i); - auto const &layout = *layouts_in[i]; - - auto vty = cast(operand.getType()); - auto shape = vty.getShape(); - - auto starting_point = offset_at_dim; - auto offset_amount = - starting_point % layout.tiling()[tiling_dim.value()]; - if (offset_amount != layout.offsets()[tiling_dim.value()]) { + Value operand = op.getOperand(i); + const Layout &layout = *layouts_in[i]; + xla::Array vreg_array = operand_vregs[i]; + std::array vreg_slice = layout->vregSlice(ctx.target_shape); + std::array tiling = layout->tiling(); + + VectorType vty = cast(operand.getType()); + ArrayRef shape = vty.getShape(); + + int64_t starting_point = offset_at_dim; + int64_t offset_amount = + starting_point % vreg_slice[tiling_dim.value()]; + if (offset_amount >= tiling[tiling_dim.value()]) { + return op.emitError( + "Not implemented: Input offsets outside of the first tile"); + } + if (offset_amount != layout->offsets()[tiling_dim.value()]) { return op.emitOpError( "Not implemented: Relayout not called, unaligned dims " "concatenated without proper offsets. Ensure that " @@ -2649,10 +2675,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, auto &vreg = operand_vregs[i]; const auto &layout = layouts_in[i]; - if (layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError("Not implemented: implicit dim"); - } - const int64_t operand_offset = *layout->offsets()[tiling_dim.value()]; if (operand_offset != 0) { // We are offset, so we must blend with the previous vreg. diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index bf668b8ecb52..30486b6e995c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -770,14 +770,11 @@ class VectorLayoutInferer { LogicalResult infer(tpu::ConcatenateOp op) { TPU_CHECK_OP(!op.getSources().empty(), "Need at least one vector to concatenate"); - auto res_rank = op.getType().getRank(); - auto dimension = op.getDimension(); + int64_t res_rank = op.getType().getRank(); + uint32_t dimension = op.getDimension(); TPU_CHECK_OP(0 <= dimension && dimension < res_rank, "Expect a valid concatenate dimension"); - if (res_rank == 1) { - NYI("Support concatenation with 1D vectors"); - } - auto res_ty = op.getResult().getType(); + VectorType res_ty = op.getResult().getType(); int8_t bitwidth = res_ty.getElementTypeBitWidth(); std::optional tiling_dim; @@ -790,29 +787,39 @@ class VectorLayoutInferer { if (tiling_dim.has_value()) { int64_t starting_point = 0; - auto first_layout = getLayout(op.getSources().front()); - auto op_layouts = getLayoutFromOperands(op); + Layout first_layout = getLayout(op.getSources().front()); + SmallVector op_layouts = getLayoutFromOperands(op); SmallVector in_layouts; in_layouts.reserve(op.getSources().size()); - auto native_tiling = nativeTiling(bitwidth); - + // Set implicit dim to treat 1D as (1, N) and tile it as (1, 128) + std::array tiling = + res_rank == 1 ? std::array{1L, target_shape_[1]} + : nativeTiling(bitwidth); + ImplicitDim implicit_dim = + res_rank == 1 ? ImplicitDim::kSecondMinor : ImplicitDim::kNone; + std::array vreg_slice = + VectorLayout::vregSlice(target_shape_, bitwidth, tiling); for (int i = 0; i < op.getSources().size(); ++i) { // Compute the offset per source. // Ex: for a cat of (10, 128), (10, 128) on dim 0, where the - // vreg_sice for that dim is 8, the first source starts at + // vreg_slice for that dim is 8, the first source starts at // offset 0, and overflows the vreg // by 2, so the offset for the second input is 2. - auto op_shape = + ArrayRef op_shape = cast(op.getSources()[i].getType()).getShape(); - auto offset_amount = starting_point % native_tiling[tiling_dim.value()]; - auto op_layout = op_layouts[i]; + Layout op_layout = op_layouts[i]; + int64_t offset_amount = starting_point % vreg_slice[tiling_dim.value()]; + if (offset_amount >= tiling[tiling_dim.value()]) { + return op.emitError( + "Not implemented: Input offsets outside of the first tile"); + } SmallVector in_idx{op_layout->offsets()[0].value_or(0), op_layout->offsets()[1].value_or(0)}; in_idx[tiling_dim.value()] = offset_amount; starting_point += op_shape[dimension]; in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]}, - native_tiling, ImplicitDim::kNone)); + tiling, implicit_dim)); } SmallVector res_layout_offsets( {first_layout->offsets()[0].value_or(0), @@ -821,13 +828,13 @@ class VectorLayoutInferer { // TODO(mvoz): A tiny optimization we could do here later is to // no-op setting tiling when sublane dim size is aligned to sublane // tiling. - auto res_layout = + VectorLayout res_layout = VectorLayout(bitwidth, {res_layout_offsets[0], res_layout_offsets[1]}, - native_tiling, ImplicitDim::kNone); + tiling, implicit_dim); setLayout(op, in_layouts, res_layout); return success(); } else { - auto layout = getLayout(op.getSources().front()); + Layout layout = getLayout(op.getSources().front()); // When concatenating vectors with replicated offsets, we want to reset // the replicated offset to zero. Because we are not sure if the // replicated value from each vector are same. @@ -1464,11 +1471,11 @@ class VectorLayoutInferer { // unfolding, it's still a no-op, but we need to // add support in apply-vector-layout. LayoutOffsets offsets = {0, layout.offsets()[1]}; - setLayout(op, - VectorLayout(layout.bitwidth(), offsets, tiling, - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), offsets, tiling, - implicit_dim)); + setLayout( + op, + VectorLayout(layout.bitwidth(), offsets, tiling, + layout.implicit_dim()), + VectorLayout(layout.bitwidth(), offsets, tiling, implicit_dim)); return success(); } sublane_tiling /= 2; @@ -1845,9 +1852,9 @@ class VectorLayoutInferer { "only 32-bit random bit generation supported"); // TODO: b/342054464 - Support implicit dims for PRNGRandomBitsOp. LayoutOffsets offsets = {0, 0}; - setOutLayout(op, VectorLayout( - kNativeBitwidth, offsets, nativeTiling(kNativeBitwidth), - ImplicitDim::kNone)); + setOutLayout( + op, VectorLayout(kNativeBitwidth, offsets, + nativeTiling(kNativeBitwidth), ImplicitDim::kNone)); return success(); } From 23e9142d2873436472991b4a96f14234f472d8df Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 15 Nov 2024 08:49:35 -0800 Subject: [PATCH 014/112] Lower threefry as an out-of-line MLIR function on TPU. On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size. --- jax/_src/prng.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index d2df5d8bbace..2256e12da1d4 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -891,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): return tuple(x) -_threefry2x32_lowering_rule = mlir.lower_fun( +# Since the unrolled lowering is large, emit it as an out-of-line function. +_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=False), - multiple_results=True) + multiple_results=True)) _threefry2x32_cpu_lowering_rule = mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=True), From 5f9428443219afb80192e16eb078368eeb7c48ef Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 15 Nov 2024 12:14:55 -0800 Subject: [PATCH 015/112] Add missing functions to jax.numpy type interface --- jax/numpy/__init__.pyi | 45 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 339174136234..af7b056fcbb0 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -29,6 +29,46 @@ _Device = Device ComplexWarning: type +class ufunc: + def __init__(self, func: Callable[..., Any], /, + nin: int, nout: int, *, + name: str | None = None, + nargs: int | None = None, + identity: Any = None, + call: Callable[..., Any] | None = None, + reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None, + ): ... + @property + def nin(self) -> int: ... + @property + def nout(self) -> int: ... + @property + def nargs(self) -> int: ... + @property + def identity(self) -> builtins.bool | int | float: ... + def __call__(self, *args: ArrayLike) -> Any: ... + def reduce(self, a: ArrayLike, /, *, + axis: int | None = 0, + dtype: DTypeLike | None = None, + out: None = None, + keepdims: builtins.bool = False, + initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def accumulate(self, a: ArrayLike, /, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, + inplace: builtins.bool = True) -> Array: ... + def reduceat(self, a: ArrayLike, indices: Any, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ... + class BinaryUfunc(Protocol): @property def nin(self) -> int: ... @@ -39,9 +79,10 @@ class BinaryUfunc(Protocol): @property def identity(self) -> builtins.bool | int | float: ... def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ... - def reduce(self, arr: ArrayLike, /, *, + def reduce(self, a: ArrayLike, /, *, axis: int | None = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: builtins.bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: ... @@ -434,6 +475,8 @@ def fromfile(*args, **kwargs): ... def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = ..., **kwargs) -> Array: ... def fromiter(*args, **kwargs): ... +def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, + *, identity: Any = None) -> ufunc: ... def fromstring( string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str ) -> Array: ... From 5f1e3f5644b6705b21b5e030d241a514c244c2c4 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 15 Nov 2024 11:26:52 -0800 Subject: [PATCH 016/112] Add an example on logical operators to the tutorial. --- docs/control-flow.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/control-flow.md b/docs/control-flow.md index 04eb3cac8d24..7cb959f3e434 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -340,6 +340,39 @@ $\ast$ = argument-value-independent loop condition - unrolls the loop `jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`. +For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar. + +```{code-cell} +def python_check_positive_even(x): + is_even = x % 2 == 0 + # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated. + return is_even and (x > 0) + +@jit +def jax_check_positive_even(x): + is_even = x % 2 == 0 + # `logical_and` does not short circuit, so `x > 0` is always evaluated. + return jnp.logical_and(is_even, x > 0) + +print(python_check_positive_even(24)) +print(jax_check_positive_even(24)) +``` + +When the JAX version with `logical_and` is applied to an array, it returns elementwise values. + +```{code-cell} +x = jnp.array([-1, 2, 5]) +print(jax_check_positive_even(x)) +``` + +Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior. + +```{code-cell} +:tags: [raises-exception] + +print(python_check_positive_even(x)) +``` + +++ {"id": "izLTvT24dAq0"} ## Python control flow + autodiff From 1780ff2964803c292dfa81adbba0f738ebafc0b9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 15 Nov 2024 13:27:42 -0800 Subject: [PATCH 017/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/195f45b7082930033f6533a160b0f8f7f1cbfb40. PiperOrigin-RevId: 696984108 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 043b9d019eb1..e7ae7fe718a6 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ecdba3f23b20e684c5e67a5ddb4f004de724f6df" -XLA_SHA256 = "bfb87208d43324cdb20e03c9802360a580062b913e975b1470148dd99dfbb0d1" +XLA_COMMIT = "195f45b7082930033f6533a160b0f8f7f1cbfb40" +XLA_SHA256 = "75e77091bae789175f3de24efee9debf8835b167770490db75571bf65c27b727" def repo(): tf_http_archive( From 81cdc882aee6ab1ddb48dea6144fa52d0dc7a9c9 Mon Sep 17 00:00:00 2001 From: barnesjoseph Date: Fri, 15 Nov 2024 13:44:31 -0800 Subject: [PATCH 018/112] DOC: update main landing page style Co-authored-by: Jake VanderPlas --- docs/_static/jax-hero.svg | 118 ++++++++++++++++++ docs/_static/style.css | 255 +++++++++++++++++++++++++++++++++++++- docs/hero.html | 8 ++ docs/index.rst | 19 ++- docs/requirements.txt | 3 +- 5 files changed, 394 insertions(+), 9 deletions(-) create mode 100644 docs/_static/jax-hero.svg create mode 100644 docs/hero.html diff --git a/docs/_static/jax-hero.svg b/docs/_static/jax-hero.svg new file mode 100644 index 000000000000..04626f43eacd --- /dev/null +++ b/docs/_static/jax-hero.svg @@ -0,0 +1,118 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/_static/style.css b/docs/_static/style.css index 2c1dfcbcbf08..32033940e8c4 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1,34 +1,279 @@ @import url("theme.css"); +/* Base LP sidebar modifications */ +body:has(.hero) .sidebar-toggle, +body:has(.hero) .bd-sidebar-secondary { + display: none !important; +} + +body:has(.hero) .search-button { + display: flex !important; +} + +body:has(.hero) .primary-toggle { + display: inline-block !important; +} + +body:has(.hero) .prev-next-footer { + display: none; +} + +body:has(.hero) .bd-article-container { + max-width: unset !important; +} + +body:has(.hero) .bd-page-width { + max-width: unset !important; +} + +body:has(.hero) .bd-article { + display: flex; + flex-direction: column; + padding: 0; +} + +body:has(.hero) .bd-container { + flex-direction: column; +} + +@media (min-width: 960px) { + body:has(.hero) .bd-header-article { + justify-content: center; + } + + body:has(.hero) .header-article-items, + body:has(.hero) .bd-article > section { + max-width: 65rem !important; + align-self: center; + width: -moz-available; + width: -webkit-fill-available; + width: fill-available; + } +} + +/* Custom CSS */ + :root { --block-bg-opacity: .5; } +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) { + padding: 0; +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * { + padding-inline: 2rem !important; +} + +@media (max-width: 768px) { + .bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * { + padding-inline: 1rem !important; + } +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) h1 { + display: none; +} + .wy-side-nav-search { background-color: #fff; } +.getting-started, +.user-guides, .installation { - background-color: rgba(78, 150, 253, var(--block-bg-opacity)); + background: #3C4043; + color: white; + height: 170px; + border: none !important; + border-radius: 12px; +} + +.getting-started:hover, +.user-guides:hover, +.installation:hover { + background: #AECBFA; + color: #202124; + transform: unset !important; +} + +.getting-started .sd-card-body, +.user-guides .sd-card-body, +.installation .sd-card-body { + display: flex; + align-items: center; + justify-content: center; + font: 500 24px 'Roboto'; +} + +.getting-started .sd-card-title, +.user-guides .sd-card-title, +.installation .sd-card-title { + display: flex; + flex-direction: column; + align-items: center; + gap: 12px; +} + +.getting-started svg, +.user-guides svg, +.installation svg { + color: #8AB4F8; +} + +.getting-started:hover svg, +.user-guides:hover svg, +.installation:hover svg { + color: #3C4043; +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > .hero { + padding-inline: 2rem 0 !important; } -.getting-started { - background-color: rgba(0, 169, 154, var(--block-bg-opacity)); +.hero { + display: grid; + grid: auto-flow / 1fr .6fr; + align-items: center; + background: rgb(32,33,36); + background: linear-gradient(90deg, rgba(32,33,36,1) 0%, rgba(39,45,56,1) 100%); + position: relative; + overflow: hidden; + border-radius: 24px; } -.user-guides { - background-color: rgba(171, 0, 182, var(--block-bg-opacity)); +.hero > img { + position: absolute; + top: 0; + right: 0; + height: 100%; + background: transparent !important; +} + +.hero-left { + padding-block: 24px; + display: flex; + flex-direction: column; +} + +.hero-left img { + width: 100px; + height: auto; + position: relative; + margin-bottom: 16px; + background: transparent !important; +} + +.hero-left h2 { + font: 500 32px 'Google Sans'; + color: white; + margin-top: 0; +} + +.hero-left p { + font: 400 16px 'Roboto'; + color: white; +} + +@media (max-width: 1295px) { + .hero > img { + right: -75px; + } +} + +@media (max-width: 750px) { + .hero { + grid: auto-flow / 1fr; + } + + .hero-left { + padding-right: 2rem; + } + + .hero > img { + display: none; + } +} + +.product-offerings { + margin-block: 32px !important; +} + +.product-offerings .sd-card-title { + font: 400 24px 'Google Sans'; +} + +.color-cards { + background: #E8EAED; + color: #222832; + padding: 48px 12px 0 12px; + margin-bottom: 0 !important; + border-radius: 24px 24px 0 0; +} + +.color-cards > div { + gap: 24px 0; +} + +.color-cards + p { + background: #E8EAED; + padding: 24px 12px 48px 12px; + font-weight: 600; + color: #222832; + border-radius: 0 0 24px 24px; +} + +.color-cards + p > a { + color: #222832; +} + +.color-cards + p > a:hover, +html[data-theme="dark"] .color-cards + p > a:hover { + color: #e89217; +} + +html[data-theme="dark"] .color-cards, +html[data-theme="dark"] .hero, +html[data-theme="dark"] .color-cards + p, +html[data-theme="dark"] .color-cards + p > a { + background: #202124; + color: white; } .ecosystem-grid { font-size: smaller; } +.ecosystem-grid > div { + gap: 20px; +} + +.ecosystem-grid .sd-col { + border: 1px solid #dadce0; + border-radius: 8px; + width: calc(50% - 10px); + padding: 16px; +} + +.ecosystem-grid .sd-col > p { + display: flex; + flex-direction: column; + gap: 10px; +} + +.ecosystem-grid .sd-col > p > svg { + color: #00897B; +} + .ecosystem-grid ul { list-style-type: none; padding-inline-start: 0.5em; } +.ecosystem-grid a { + text-decoration: none; +} + div.red-background pre { background-color: rgba(244, 204, 204, var(--block-bg-opacity)); } diff --git a/docs/hero.html b/docs/hero.html new file mode 100644 index 000000000000..a2ee3b8e206f --- /dev/null +++ b/docs/hero.html @@ -0,0 +1,8 @@ +
+
+ +

High performance array computing

+

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

+
+ +
\ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 5f3bce5cf7da..ba8ebcbdd128 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,10 +1,22 @@ JAX: High performance array computing ===================================== -JAX is a Python library for accelerator-oriented array computation and program transformation, -designed for high-performance numerical computing and large-scale machine learning. +.. raw:: html + + + + +.. raw:: html + :file: hero.html .. grid:: 3 + :class-container: product-offerings :margin: 0 :padding: 0 :gutter: 0 @@ -31,6 +43,7 @@ designed for high-performance numerical computing and large-scale machine learni The same code executes on multiple backends, including CPU, GPU, & TPU .. grid:: 3 + :class-container: color-cards .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Installation :columns: 12 6 6 4 @@ -59,7 +72,7 @@ JAX itself is narrowly-scoped and focuses on efficient array operations & progra transformations. Built around JAX is an evolving ecosystem of machine learning and numerical computing tools; the following is just a small sample of what is out there: -.. grid:: 4 +.. grid:: 2 :class-container: ecosystem-grid .. grid-item:: :material-outlined:`hub;2em` **Neural networks** diff --git a/docs/requirements.txt b/docs/requirements.txt index 41d8aa6d9ee7..bfbb4e271d42 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,8 @@ absl-py ipython>=8.8.0 # 8.7.0 has ipython3 lexer error +pydata-sphinx-theme==0.14.4 # v0.15 breaks sidebar toggling sphinx>=7.3.2,<8.0 # 7.3.0 breaks sphinx-book-theme; 8.0 breaks myst-nb 1.1 -sphinx-book-theme>=1.0.1 # Older versions fail to pin pydata-sphinx-theme +sphinx-book-theme==1.1.1 # v1.1.2 requires pydata-sphinx-theme v0.15 sphinx-copybutton>=0.5.0 sphinx-remove-toctrees sphinx-design From 225a2a5f8bfe710e6a4aecb182d5bdd87683193b Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 15 Nov 2024 10:30:13 -0800 Subject: [PATCH 019/112] Consolidate material on PRNGs and add a short summary to Key Concepts. --- README.md | 2 +- docs/key-concepts.md | 40 +++ docs/notebooks/Common_Gotchas_in_JAX.ipynb | 307 +-------------------- docs/notebooks/Common_Gotchas_in_JAX.md | 148 +--------- docs/random-numbers.md | 20 +- jax/_src/errors.py | 4 +- 6 files changed, 65 insertions(+), 456 deletions(-) diff --git a/README.md b/README.md index 1395ae23a46e..b001a8ceeb15 100644 --- a/README.md +++ b/README.md @@ -348,7 +348,7 @@ Some standouts: 1. [In-place mutating updates of arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. diff --git a/docs/key-concepts.md b/docs/key-concepts.md index daab2c9fdde4..91f0c953462e 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -189,3 +189,43 @@ tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the le in a tree. You can learn more in the {ref}`working-with-pytrees` tutorial. + +(key-concepts-prngs)= +## Pseudorandom numbers + +Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`: + +```{code-cell} +from jax import random + +key = random.key(43) +print(key) +``` + +The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {func}`jax.random` functions. +Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated. + +```{code-cell} +print(random.normal(key)) +print(random.normal(key)) +``` + +**The rule of thumb is: never reuse keys (unless you want identical outputs).** + +In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: + +```{code-cell} +for i in range(3): + new_key, subkey = random.split(key) + del key # The old key is consumed by split() -- we must never use it again. + + val = random.normal(subkey) + del subkey # The subkey is consumed by normal(). + + print(f"draw {i}: {val}") + key = new_key # new_key is safe to use in the next iteration. +``` + +Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. + +For more on pseudo random numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 92c736957db6..02077d2a6b00 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -865,312 +865,9 @@ "id": "MUycRNh6e50W" }, "source": [ - "## 🔪 Random numbers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O8vvaVt3MRG2" - }, - "source": [ - "> _If all scientific papers whose results are in doubt because of bad\n", - "> `rand()`s were to disappear from library shelves, there would be a\n", - "> gap on each shelf about as big as your fist._ - Numerical Recipes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qikt9pPW9L5K" - }, - "source": [ - "### RNGs and state\n", - "You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "rr9FeP41fynt", - "outputId": "df0ceb15-96ec-4a78-e327-c77f7ea3a745" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.2726690048900553\n", - "0.6304191979771206\n", - "0.6933648856441533\n" - ] - } - ], - "source": [ - "print(np.random.random())\n", - "print(np.random.random())\n", - "print(np.random.random())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ORMVVGZJgSVi" - }, - "source": [ - "Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "7Pyp2ajzfPO2" - }, - "outputs": [], - "source": [ - "np.random.seed(0)\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n", - "# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n", - "# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aJIxHVXCiM6m" - }, - "source": [ - "This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "GAHaDCYafpAF" - }, - "outputs": [], - "source": [ - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n", - "\n", - "# Let's exhaust the entropy in this PRNG statevector\n", - "for i in range(311):\n", - " _ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n", + "## 🔪 Random numbers\n", "\n", - "# Next call iterates the RNG state for a new batch of fake \"entropy\".\n", - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n", - "# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N_mWnleNogps" - }, - "source": [ - "The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n", - "\n", - "The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Uvq7nV-j4vKK" - }, - "source": [ - "### JAX PRNG" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "COjzGBpO4tzL" - }, - "source": [ - "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", - "\n", - "The random state is described by a special array element that we call a __key__:" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "id": "yPHE7KTWgAWs", - "outputId": "ae8af0ee-f19e-474e-81b6-45e894eb2fc3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([0, 0], dtype=uint32)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = random.key(0)\n", - "key" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XjYyWYNfq0hW" - }, - "source": [ - "JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n", - "\n", - "Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "7zUdQMynoE5e", - "outputId": "23a86b72-dfb9-410a-8e68-22b48dc10805" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.20584226]\n", - "[0 0]\n", - "[-0.20584226]\n", - "[0 0]\n" - ] - } - ], - "source": [ - "print(random.normal(key, shape=(1,)))\n", - "print(key)\n", - "# No no no!\n", - "print(random.normal(key, shape=(1,)))\n", - "print(key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hQN9van8rJgd" - }, - "source": [ - "Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "ASj0_rSzqgGh", - "outputId": "2f13f249-85d1-47bb-d503-823eca6961aa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [0 0]\n", - " \\---SPLIT --> new key [4146024105 967050713]\n", - " \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tqtFVE4MthO3" - }, - "source": [ - "We propagate the __key__ and make new __subkeys__ whenever we need a new random number:" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "jbC34XLor2Ek", - "outputId": "4059a2e2-0205-40bc-ad55-17709d538871" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [4146024105 967050713]\n", - " \\---SPLIT --> new key [2384771982 3928867769]\n", - " \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0KLYUluz3lN3" - }, - "source": [ - "We can generate more than one __subkey__ at a time:" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "id": "lEi08PJ4tfkX", - "outputId": "1f280560-155d-4c04-98e8-c41d72ee5b01" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.37533438]\n", - "[0.98645043]\n", - "[0.14553197]\n" - ] - } - ], - "source": [ - "key, *subkeys = random.split(key, 4)\n", - "for subkey in subkeys:\n", - " print(random.normal(subkey, shape=(1,)))" + "JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial." ] }, { diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 00955de236e7..f35c5ead13b7 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -384,153 +384,7 @@ jnp.sum(jnp.array(x)) ## 🔪 Random numbers -+++ {"id": "O8vvaVt3MRG2"} - -> _If all scientific papers whose results are in doubt because of bad -> `rand()`s were to disappear from library shelves, there would be a -> gap on each shelf about as big as your fist._ - Numerical Recipes - -+++ {"id": "Qikt9pPW9L5K"} - -### RNGs and state -You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness: - -```{code-cell} ipython3 -:id: rr9FeP41fynt -:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 - -print(np.random.random()) -print(np.random.random()) -print(np.random.random()) -``` - -+++ {"id": "ORMVVGZJgSVi"} - -Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up. - -```{code-cell} ipython3 -:id: 7Pyp2ajzfPO2 - -np.random.seed(0) -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044, -# 2481403966, 4042607538, 337614300, ... 614 more numbers..., -# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0) -``` - -+++ {"id": "aJIxHVXCiM6m"} - -This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne twister state vector: - -```{code-cell} ipython3 -:id: GAHaDCYafpAF - -_ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0) - -# Let's exhaust the entropy in this PRNG statevector -for i in range(311): - _ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0) - -# Next call iterates the RNG state for a new batch of fake "entropy". -_ = np.random.uniform() -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([1499117434, 2949980591, 2242547484, -# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0) -``` - -+++ {"id": "N_mWnleNogps"} - -The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user. - -The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow. - -+++ {"id": "Uvq7nV-j4vKK"} - -### JAX PRNG - -+++ {"id": "COjzGBpO4tzL"} - -JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. - -The random state is described by a special array element that we call a __key__: - -```{code-cell} ipython3 -:id: yPHE7KTWgAWs -:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 - -key = random.key(0) -key -``` - -+++ {"id": "XjYyWYNfq0hW"} - -JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! - -Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__: - -```{code-cell} ipython3 -:id: 7zUdQMynoE5e -:outputId: 23a86b72-dfb9-410a-8e68-22b48dc10805 - -print(random.normal(key, shape=(1,))) -print(key) -# No no no! -print(random.normal(key, shape=(1,))) -print(key) -``` - -+++ {"id": "hQN9van8rJgd"} - -Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number: - -```{code-cell} ipython3 -:id: ASj0_rSzqgGh -:outputId: 2f13f249-85d1-47bb-d503-823eca6961aa - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "tqtFVE4MthO3"} - -We propagate the __key__ and make new __subkeys__ whenever we need a new random number: - -```{code-cell} ipython3 -:id: jbC34XLor2Ek -:outputId: 4059a2e2-0205-40bc-ad55-17709d538871 - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "0KLYUluz3lN3"} - -We can generate more than one __subkey__ at a time: - -```{code-cell} ipython3 -:id: lEi08PJ4tfkX -:outputId: 1f280560-155d-4c04-98e8-c41d72ee5b01 - -key, *subkeys = random.split(key, 4) -for subkey in subkeys: - print(random.normal(subkey, shape=(1,))) -``` +JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial. +++ {"id": "rg4CpMZ8c3ri"} diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 2ad1eadb0968..00f77e3473bb 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -17,6 +17,10 @@ kernelspec: +> _If all scientific papers whose results are in doubt because of bad +> `rand()`s were to disappear from library shelves, there would be a +> gap on each shelf about as big as your fist._ - Numerical Recipes + In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. @@ -35,6 +39,19 @@ import numpy as np np.random.seed(0) ``` +Repeated calls to NumPy's stateful pseudorandom number generators (PRNGs) mutate the global state and give a stream of pseudorandom numbers: + +```{code-cell} +:id: rr9FeP41fynt +:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 + +print(np.random.random()) +print(np.random.random()) +print(np.random.random()) +``` + +Underneath the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up. + You can inspect the content of the state using the following command. ```{code-cell} @@ -109,7 +126,7 @@ Further, when executing in multi-device environments, execution efficiency would ### Explicit random state -To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: +To avoid these issues, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: ```{code-cell} from jax import random @@ -137,6 +154,7 @@ Re-using the same key, even with different {mod}`~jax.random` APIs, can result i **The rule of thumb is: never reuse keys (unless you want identical outputs).** +JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: ```{code-cell} diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 590f68ac0b3b..6540fd1f5d41 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -677,7 +677,7 @@ class KeyReuseError(JAXTypeError): KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 This sort of key reuse is problematic because the JAX PRNG is stateless, and keys - must be manually split; For more information on this see `Sharp Bits: Random Numbers - `_. + must be manually split; For more information on this see `the Pseudorandom Numbers + tutorial `_. """ pass From 8525ef2b23f12affcff23b9a54d4d2515acb671f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 15 Nov 2024 17:41:42 -0800 Subject: [PATCH 020/112] [sharding_in_types] Don't emit a wsc under full manual mode to avoid increasing HLO size by a lot PiperOrigin-RevId: 697048126 --- jax/_src/interpreters/mlir.py | 20 +++++++++----------- jax/_src/mesh.py | 2 +- tests/pjit_test.py | 8 ++++---- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index ee3c929b26f7..a1d326162a1c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2475,17 +2475,15 @@ def _wrap_with_spmd_op(name: str, def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): - if sharding_proto is None: - proto = aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() - else: - proto = sharding_proto - # TODO(yashkatariya): Setting all axes as unspecified should work even when - # any axes is Collective because that's what happens in partial auto shmap. - # Do that after tests for it exists. - unspecified_dims = (set(range(aval.ndim)) - if aval.sharding.mesh.are_all_axes_collective else None) - return wrap_with_sharding_op( - ctx, op, aval, proto, unspecified_dims=unspecified_dims) + # Don't emit a wsc under full manual mode to avoid increasing HLO size. + if aval.sharding.mesh._are_all_axes_collective: + return op + proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + if sharding_proto is None else sharding_proto) + # TODO(yashkatariya): Enable this + # unspecified_dims = (set(range(aval.ndim)) + # if aval.sharding.mesh._any_axis_collective else None) + return wrap_with_sharding_op(ctx, op, aval, proto) def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 6c6017c4b2b7..a2ab261fa0e9 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -426,7 +426,7 @@ def empty(self): return self.size == 0 @functools.cached_property - def are_all_axes_collective(self) -> bool: + def _are_all_axes_collective(self) -> bool: if self.axis_types is None: return False return all(t == AxisTypes.Collective for t in self.axis_types.keys()) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7196a6335960..be1f9cfc267a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5208,8 +5208,8 @@ def test_shard_map_full_manual(self): arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) def g(x, y): - self.assertTrue(x.sharding.mesh.are_all_axes_collective) - self.assertTrue(y.sharding.mesh.are_all_axes_collective) + self.assertTrue(x.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.sharding.mesh._are_all_axes_collective) return x * y @jax.jit @@ -5232,8 +5232,8 @@ def test_shard_map_dot(self): arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) def g(x, y): - self.assertTrue(x.sharding.mesh.are_all_axes_collective) - self.assertTrue(y.sharding.mesh.are_all_axes_collective) + self.assertTrue(x.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.sharding.mesh._are_all_axes_collective) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y return jax.lax.psum(z, axis_name='y') From 609dfac29452e4842c62168c9c9036f38976a57d Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Mon, 11 Nov 2024 22:59:24 -0800 Subject: [PATCH 021/112] Adds a flag to control proxy env checking. name typo fix. Fixes comments. --- jax/_src/clusters/cluster.py | 6 ------ jax/_src/distributed.py | 38 ++++++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 2fb13fde72cf..69ef77a6421d 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -49,12 +49,6 @@ def auto_detect_unset_distributed_params(cls, initialization_timeout: int | None, ) -> tuple[str | None, int | None, int | None, Sequence[int] | None]: - - if all(p is not None for p in (coordinator_address, num_processes, - process_id, local_device_ids)): - return (coordinator_address, num_processes, process_id, - local_device_ids) - # First, we check the spec detection method because it will ignore submitted values # If if succeeds. if cluster_detection_method is not None: diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 5b9130fc0455..f80f90bde186 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -27,6 +27,13 @@ logger = logging.getLogger(__name__) +_CHECK_PROXY_ENVS = config.bool_flag( + name="jax_check_proxy_envs", + default=True, + help="Checks proxy vars in user envs and emit warnings.", +) + + class State: process_id: int = 0 num_processes: int = 1 @@ -55,16 +62,17 @@ def initialize(self, if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')): local_device_ids = list(map(int, env_ids.split(","))) - (coordinator_address, num_processes, process_id, local_device_ids) = ( - clusters.ClusterEnv.auto_detect_unset_distributed_params( - coordinator_address, - num_processes, - process_id, - local_device_ids, - cluster_detection_method, - initialization_timeout, - ) - ) + if None in (coordinator_address, num_processes, process_id, local_device_ids): + (coordinator_address, num_processes, process_id, local_device_ids) = ( + clusters.ClusterEnv.auto_detect_unset_distributed_params( + coordinator_address, + num_processes, + process_id, + local_device_ids, + cluster_detection_method, + initialization_timeout, + ) + ) if coordinator_address is None: raise ValueError('coordinator_address should be defined.') @@ -92,8 +100,10 @@ def initialize(self, self.process_id = process_id - # Emit a warning about PROXY variables if they are in the user's env: - proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()] + proxy_vars = [] + if _CHECK_PROXY_ENVS.value: + proxy_vars = [key for key in os.environ.keys() + if '_proxy' in key.lower()] if len(proxy_vars) > 0: vars = " ".join(proxy_vars) + ". " @@ -179,7 +189,9 @@ def initialize(coordinator_address: str | None = None, ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments. Otherwise, you must provide the ``coordinator_address``, - ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`. + ``num_processes``, ``process_id``, and ``local_device_ids`` arguments + to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster + environment auto detection will be skipped. Please note: on some systems, particularly HPC clusters that only access external networks through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to From 626aea017b6c60b346f5e7edebfc5bbf116ff4cf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 16 Nov 2024 11:10:45 -0800 Subject: [PATCH 022/112] Deduplicate constants in StableHLO lowering. The goal of this change is to reduce the size of the generated code: we frequently built thousands of scalar 0s, for example. --- jax/_src/interpreters/mlir.py | 41 ++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a1d326162a1c..477ba6880eda 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1752,6 +1752,38 @@ def _emit_lowering_rule_as_fun(lowering_rule, return func_op +class HashableLiteral: + """Hashable wrapper of core.Literal, used for deduplicating IR constants.""" + + __slots__ = ["value"] + + value: core.Literal + + def __init__(self, value): + self.value = value + + def __hash__(self): + h = self.value.hash + return id(self.value.val) if h is None else h + + def __eq__(self, other): + if self is other: + return True + if type(self.value.val) != type(other.value.val): + return False + if self.value.aval != other.value.aval: + return False + if isinstance(self.value.val, (bool, int, float, complex)): + return self.value == other.value + if isinstance(self.value.val, (np.generic, np.ndarray)): + return np.array_equal( + self.value.val, other.value.val, + equal_nan=np.issubdtype(self.value.val.dtype, np.inexact)) + # Since the use case is constant deduplication, it's safe to return + # False in unhandled cases. + return False + + def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, name_stack: source_info_util.NameStack, tokens: TokenSet, @@ -1767,9 +1799,16 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, IR function, in the order of ctx.shape_poly_state.dim_vars. """ assert "gpu" not in ctx.platforms + cached_ir_consts: dict[HashableLiteral, IrValues] = {} + def read(v: core.Atom) -> IrValues: if type(v) is core.Literal: - return ir_constant(xla.canonicalize_dtype(v.val)) + h = HashableLiteral(v) + c = cached_ir_consts.get(h) + if c is None: + c = ir_constant(xla.canonicalize_dtype(v.val)) + cached_ir_consts[h] = c + return c else: assert isinstance(v, core.Var) return env[v] From 1d519f4ce3cd4a621b8f7e1bceab75317ed5db24 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 16 Nov 2024 13:38:23 -0800 Subject: [PATCH 023/112] Return a ndarray in shape_as_value if the shape is known to be constant. --- jax/_src/lax/lax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b780aab870e9..9c183ae93d41 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4547,6 +4547,8 @@ def shape_as_value(shape: core.Shape): """Converts a shape that may contain Poly values into a JAX value.""" if len(shape) == 0: return full((0,), np.array(0, np.int64)) + if core.is_constant_shape(shape): + return np.asarray(shape, dtype=np.int64) dims = [ expand_dims(convert_element_type(core.dimension_as_value(d), np.int64), (0,)) From 7b9914d711593dca8725d46aa1dadb2194284519 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 16 Nov 2024 13:39:24 -0800 Subject: [PATCH 024/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9ab7d704d7fe7e73fc3976adc2ccec070bc9a2ea. PiperOrigin-RevId: 697222155 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e7ae7fe718a6..af2fab8ed55f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "195f45b7082930033f6533a160b0f8f7f1cbfb40" -XLA_SHA256 = "75e77091bae789175f3de24efee9debf8835b167770490db75571bf65c27b727" +XLA_COMMIT = "9ab7d704d7fe7e73fc3976adc2ccec070bc9a2ea" +XLA_SHA256 = "6944ceaa425cacd30a54cca3cd6c4cb88b79f219d421fb97fa87ffbf06007143" def repo(): tf_http_archive( From 8a6c560b2562be13de5c0808db143b614db531ee Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 16 Nov 2024 14:29:20 -0800 Subject: [PATCH 025/112] Use a direct StableHLO lowering for pow. This is slightly faster than lowering via tracing, and the code is simpler also. --- jax/_src/lax/lax.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b780aab870e9..cf16c9b99935 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2574,15 +2574,12 @@ def _pow_jvp_rhs(g, ans, x, y): def _pow_lower(ctx, x, y): x_aval, y_aval = ctx.avals_in - out_aval, = ctx.avals_out - convert = mlir.lower_fun( - partial(convert_element_type, new_dtype=out_aval.dtype), False) - x_aval_ = x_aval.update(dtype=out_aval.dtype) - y_aval_ = y_aval.update(dtype=out_aval.dtype) - [x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x) - [y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y) - ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_]) - return _nary_lower_hlo(hlo.power, ctx_, x_, y_) + if x_aval.dtype != y_aval.dtype: + out_aval, = ctx.avals_out + y_aval = y_aval.update(dtype=out_aval.dtype) + y = hlo.convert(mlir.aval_to_ir_type(y_aval), y) + ctx = ctx.replace(avals_in=[x_aval, y_aval]) + return _nary_lower_hlo(hlo.power, ctx, x, y) mlir.register_lowering(pow_p, _pow_lower) def _integer_pow_dtype_rule(x, *, y): From 27bf80a50617ff38aca19e1da7c0b7599e691ef6 Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Mon, 11 Nov 2024 22:59:24 -0800 Subject: [PATCH 026/112] Adds an env that can let users provide a custom version suffix for jax dev build. fix the local version update to what Jake suggested --- jax/version.py | 6 +++++- tests/version_test.py | 30 ++++++++++++++++++++---------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/jax/version.py b/jax/version.py index c27caf979ddb..3e8a8291ec8d 100644 --- a/jax/version.py +++ b/jax/version.py @@ -60,7 +60,11 @@ def _version_from_git_tree(base_version: str) -> str | None: except: return None else: - return f"{base_version}.dev{datestring}+{commit_hash}" + version = f"{base_version}.dev{datestring}+{commit_hash}" + suffix = os.environ.get("JAX_CUSTOM_VERSION_SUFFIX", None) + if suffix: + return version + "." + suffix + return version def _get_version_for_build() -> str: diff --git a/tests/version_test.py b/tests/version_test.py index 7ce98c8588e5..51297a9716b1 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -26,11 +26,11 @@ # This is a subset of the full PEP440 pattern; for example we skip pre & post releases VERSION_PATTERN = re.compile(r""" - ^ # start of string - (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' - (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' - (?:\+(?P[a-zA-Z0-9_]+))? # optional local version; like '+g6643af3c3' - $ # end of string + ^ # start of string + (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' + (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + (?:\+(?P[a-zA-Z0-9_.]+))? # optional local version; like '+g6643af3c3' + $ # end of string """, re.VERBOSE) @@ -61,11 +61,12 @@ def assert_no_subprocess_call(): @contextlib.contextmanager -def assert_subprocess_call(): +def assert_subprocess_call(stdout: bytes | None = None): """Run code, asserting that subprocess.Popen *is* called at least once.""" with mock.patch("subprocess.Popen") as mock_Popen: + mock_Popen.return_value.communicate.return_value = (stdout, b"") yield - mock_Popen.assert_called() + mock_Popen.return_value.communicate.assert_called() class JaxVersionTest(unittest.TestCase): @@ -126,7 +127,7 @@ def testBuildVersionFromEnvironment(self): self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() datestring = datetime.date.today().strftime("%Y%m%d") @@ -134,19 +135,28 @@ def testBuildVersionFromEnvironment(self): self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE="1", JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() self.assertEqual(version, base_version) self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE="1", - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() self.assertEqual(version, base_version) self.assertValidVersion(version) + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None, + JAX_CUSTOM_VERSION_SUFFIX="test"): + with assert_subprocess_call(stdout=b"1731433958-1c0f1076e"): + version = jax.version._get_version_for_build() + self.assertTrue(version.startswith(f"{base_version}.dev")) + self.assertTrue(version.endswith("test")) + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3") From 742cabc54724456397dee7fd4e92411aa57f16b4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 17 Nov 2024 14:19:00 -0800 Subject: [PATCH 027/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/58ea2935b4316b48979cb47f617ae06ce9f49638. PiperOrigin-RevId: 697425145 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index af2fab8ed55f..b35f9daa2144 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9ab7d704d7fe7e73fc3976adc2ccec070bc9a2ea" -XLA_SHA256 = "6944ceaa425cacd30a54cca3cd6c4cb88b79f219d421fb97fa87ffbf06007143" +XLA_COMMIT = "58ea2935b4316b48979cb47f617ae06ce9f49638" +XLA_SHA256 = "669eef5690be3e1059de8429cdfbf24bf0a15a5aa6e00b9aefd7a072d839d0aa" def repo(): tf_http_archive( From ed250b89831aab2e2ed672ad05e13a7eee818396 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 17 Nov 2024 23:58:46 -0800 Subject: [PATCH 028/112] [AutoPGLE] Temporary disable pgle_test in the OSS. PiperOrigin-RevId: 697517161 --- tests/pgle_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index fa574df18f29..609ca38fd7a5 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -43,6 +43,11 @@ @jtu.pytest_mark_if_available('multiaccelerator') +# TODO(patrios): Remove this skip once b/379267258 is fixed. +@jtu.skip_under_pytest( + 'This test requires specific XLA_FLAGS. However, pytest does not reload ' + 'modules between tests. So if another test is launched before this one ' + 'necessary XLA_FLAGS will not be re-used by the XLA.') class PgleTest(jtu.JaxTestCase): _dump_exit_stack: ExitStack | None = None From ccb331707e80b16d89de6e5c9f2f89b87c1682ed Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 18 Nov 2024 08:11:04 -0800 Subject: [PATCH 029/112] Add a GPU implementation of `lax.linalg.eig`. This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately). This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.) We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable. PiperOrigin-RevId: 697631402 --- CHANGELOG.md | 3 + jax/_src/config.py | 11 + jax/_src/lax/linalg.py | 133 ++++- jax/_src/numpy/linalg.py | 4 +- jaxlib/cpu/lapack_kernels.cc | 28 -- jaxlib/cpu/lapack_kernels.h | 29 ++ jaxlib/cuda/BUILD | 50 ++ jaxlib/gpu/BUILD | 3 + jaxlib/gpu/hybrid.cc | 60 +++ jaxlib/gpu/hybrid_kernels.cc | 631 ++++++++++++++++++++++++ jaxlib/gpu/hybrid_kernels.h | 55 +++ jaxlib/gpu_solver.py | 43 ++ jaxlib/jax.bzl | 1 + jaxlib/rocm/BUILD | 43 ++ jaxlib/tools/build_gpu_kernels_wheel.py | 2 + jaxlib/tools/build_wheel.py | 2 + tests/BUILD | 7 + tests/lax_numpy_test.py | 4 +- tests/linalg_test.py | 35 +- tests/magma_linalg_test.py | 125 +++++ 20 files changed, 1214 insertions(+), 55 deletions(-) create mode 100644 jaxlib/gpu/hybrid.cc create mode 100644 jaxlib/gpu/hybrid_kernels.cc create mode 100644 jaxlib/gpu/hybrid_kernels.h create mode 100644 tests/magma_linalg_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d670e43b6137..204df6a83e52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. declared inline via {func}`dataclasses.field`. See the function documentation for examples. * Added {func}`jax.numpy.put_along_axis`. + * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions + ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now + supported on GPU. See {jax-issue}`#24663` for more details. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/jax/_src/config.py b/jax/_src/config.py index 72f394dba76f..1c62f7125ee7 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1963,3 +1963,14 @@ def _update_garbage_collection_guard(state, key, val): ), include_in_jit_key=True, ) + +gpu_use_magma = enum_state( + name='jax_use_magma', + enum_values=['off', 'on', 'auto'], + default='auto', + help=( + 'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. ' + 'See the documentation for lax.linalg.eig for more details about how ' + 'to use this feature.' + ), +) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0e0390abc78f..62cb72c69fd7 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -121,16 +121,46 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, - compute_right_eigenvectors: bool = True) -> list[Array]: + compute_right_eigenvectors: bool = True, + use_magma: bool | None = None) -> list[Array]: """Eigendecomposition of a general matrix. - Nonsymmetric eigendecomposition is at present only implemented on CPU. + Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU, + the default implementation calls LAPACK directly on the host CPU, but an + experimental GPU implementation using `MAGMA `_ + is also available. The MAGMA implementation is typically slower than the + equivalent LAPACK implementation for small matrices (less than about 2048), + but it may perform better for larger matrices. + + To enable the MAGMA implementation, you must install MAGMA yourself (there + are Debian and conda-forge packages, or you can build from source). Then set + the ``use_magma`` argument to ``True``, or set the ``jax_use_magma`` + configuration variable to ``"on"`` or ``"auto"``: + + .. code-block:: python + + jax.config.update('jax_use_magma', 'on') + + JAX will try to ``dlopen`` the installed MAGMA shared library, raising an + error if it is not found. To explicitly specify the path to the MAGMA + library, set the environment variable `JAX_GPU_MAGMA_PATH` to the full + installation path. + + If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will + be used if the library can be found, and the input matrix is sufficiently + large (>= 2048x2048). Args: x: A batch of square matrices with shape ``[..., n, n]``. compute_left_eigenvectors: If true, the left eigenvectors will be computed. compute_right_eigenvectors: If true, the right eigenvectors will be computed. + use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the + eigendecomposition is computed using MAGMA. If ``False``, the computation + is done using LAPACK on to the host CPU. If ``None`` (default), the + behavior is controlled by the ``jax_use_magma`` flag. This argument + is only used on GPU. + Returns: The eigendecomposition of ``x``, which is a tuple of the form ``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left @@ -142,7 +172,8 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, for that batch element. """ return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma) def eigh( @@ -678,12 +709,14 @@ def _symmetric_product_jax_fn(a, c, *, alpha, beta): # Asymmetric eigendecomposition -def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): +def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): return dispatch.apply_primitive( eig_p, operand, compute_left_eigenvectors=compute_left_eigenvectors, compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma, ) def eig_lower(*args, **kw): @@ -692,7 +725,8 @@ def eig_lower(*args, **kw): "If your matrix is symmetric or Hermitian, you should use eigh instead.") def eig_abstract_eval(operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if isinstance(operand, ShapedArray): if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]: raise ValueError("Argument to nonsymmetric eigendecomposition must have " @@ -716,7 +750,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors, return tuple(output) def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused operand_aval, = ctx.avals_in out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] @@ -763,18 +798,94 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, return output +def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors, + compute_right_eigenvectors, use_magma): + gpu_solver.initialize_hybrid_kernels() + dtype = x.dtype + is_real = dtype == np.float32 or dtype == np.float64 + if is_real: + target_name = f"{target_name_prefix}hybrid_eig_real" + complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + else: + target_name = f"{target_name_prefix}hybrid_eig_comp" + assert dtype == np.complex64 or dtype == np.complex128 + complex_dtype = dtype + + batch_dims = x.shape[:-2] + n, m = x.shape[-2:] + assert n == m + num_batch_dims = len(batch_dims) + + layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims) + out_types = [ + api.ShapeDtypeStruct(batch_dims + (n,), dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims, np.int32), + ] + out_layouts = [None, layout, layout, None] + if is_real: + out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types + out_layouts = [None] + out_layouts + + magma = config.gpu_use_magma.value + if use_magma is not None: + magma = "on" if use_magma else "off" + fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout], + output_layouts=out_layouts) + *w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + assert len(w) == 2 + w = lax.complex(*w) + else: + assert len(w) == 1 + w = w[0] + ok = lax.eq(info, lax.zeros_like_array(info)) + ok = _broadcast_to(ok[..., None], w.shape) + w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j)) + ok = _broadcast_to(ok[..., None], x.shape) + output = [w] + if compute_left_eigenvectors: + vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j)) + output.append(vl) + if compute_right_eigenvectors: + vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j)) + output.append(vr) + return output + + +def _eig_gpu_lowering(target_name_prefix, ctx, operand, *, + compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): + if ctx.is_forward_compat(): + raise NotImplementedError( + "Export of nonsymmetric eigendecomposition on GPU is not supported " + "because of forward compatibility. The " + "'jax_export_ignore_forward_compatibility' configuration option can be " + "used to disable this check.") + rule = mlir.lower_fun(partial( + _eig_gpu_impl, target_name_prefix, + compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), multiple_results=True) + return rule(ctx, operand) + + def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors), + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), (0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors)) def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if compute_left_eigenvectors or compute_right_eigenvectors: raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' @@ -793,6 +904,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, eig_p.def_abstract_eval(eig_abstract_eval) mlir.register_lowering(eig_p, eig_lower) mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'), + platform='cuda') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'), + platform='rocm') batching.primitive_batchers[eig_p] = eig_batching_rule ad.primitive_jvps[eig_p] = eig_jvp_rule diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 03f864919887..76a4abff48ad 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -731,7 +731,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: - This differs from :func:`numpy.linalg.eig` in that the return type of :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128 for 64-bit input. - - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. + - At present, non-symmetric eigendecomposition is only implemented on the CPU and + GPU backends. For more details about the GPU implementation, see the + documentation for :func:`jax.lax.linalg.eig`. See also: - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix. diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 19b82a5ce149..ed815e1b1bd2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// LAPACK uses a packed representation to represent a mixture of real -// eigenvectors and complex conjugate pairs. This helper unpacks the -// representation into regular complex matrices. -template -static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag, - const T* packed, std::complex* unpacked) { - for (int j = 0; j < n;) { - if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { - // Real values in each row without imaginary part - // Second row of the imaginary part is not provided - for (int i = 0; i < n; ++i) { - unpacked[j * n + i] = {packed[j * n + i], 0.}; - } - ++j; - } else { - // Complex values where the real part is in the jth row - // and the imaginary part is in the next row (j + 1) - for (int i = 0; i < n; ++i) { - const T real_part = packed[j * n + i]; - const T imag_part = packed[(j + 1) * n + i]; - unpacked[j * n + i] = {real_part, imag_part}; - unpacked[(j + 1) * n + i] = {real_part, -imag_part}; - } - j += 2; - } - } -} - // lapack geev template diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 7d15e494fffc..cddcb1162120 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ #define JAXLIB_CPU_LAPACK_KERNELS_H_ +#include #include #include #include @@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian { // lapack geev +// LAPACK uses a packed representation to represent a mixture of real +// eigenvectors and complex conjugate pairs. This helper unpacks the +// representation into regular complex matrices. +template +static void UnpackEigenvectors(Int n, const T* eigenvals_imag, + const T* packed, std::complex* unpacked) { + for (int j = 0; j < n;) { + if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { + // Real values in each row without imaginary part + // Second row of the imaginary part is not provided + for (int i = 0; i < n; ++i) { + unpacked[j * n + i] = {packed[j * n + i], 0.}; + } + ++j; + } else { + // Complex values where the real part is in the jth row + // and the imaginary part is in the next row (j + 1) + for (int i = 0; i < n; ++i) { + const T real_part = packed[j * n + i]; + const T imag_part = packed[(j + 1) * n + i]; + unpacked[j * n + i] = {real_part, imag_part}; + unpacked[(j + 1) * n + i] = {real_part, -imag_part}; + } + j += 2; + } + } +} + template struct RealGeev { using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 34e40d12d5be..afce2c000ecc 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -476,6 +476,55 @@ pybind_extension( ], ) +cc_library( + name = "cuda_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + module_name = "_hybrid", + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_hybrid_kernels", + ":cuda_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + "@xla//xla/tsl/cuda:cudart", + ], +) + cc_library( name = "cuda_gpu_kernels", srcs = ["//jaxlib/gpu:gpu_kernels.cc"], @@ -633,6 +682,7 @@ py_library( name = "cuda_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_rnn", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 7d50a91cfcda..e888f6a42a9b 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -37,6 +37,9 @@ exports_files(srcs = [ "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", + "hybrid.cc", + "hybrid_kernels.cc", + "hybrid_kernels.h", "linalg.cc", "linalg_kernels.cc", "linalg_kernels.cu.cc", diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc new file mode 100644 index 000000000000..afe95a650d29 --- /dev/null +++ b/jaxlib/gpu/hybrid.cc @@ -0,0 +1,60 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/gpu/hybrid_kernels.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace { +namespace ffi = xla::ffi; +namespace nb = nanobind; + +void GetLapackKernelsFromScipy() { + static bool initialized = false; // Protected by GIL + if (initialized) return; + nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas"); + nb::module_ cython_lapack = + nb::module_::import_("scipy.linalg.cython_lapack"); + nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); + auto lapack_ptr = [&](const char* name) { + return nb::cast(lapack_capi[name]).data(); + }; + + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>(lapack_ptr("cgeev")); + AssignKernelFn>( + lapack_ptr("zgeev")); +} + +NB_MODULE(_hybrid, m) { + m.def("initialize", GetLapackKernelsFromScipy); + m.def("has_magma", []() { return MagmaLookup().FindMagmaInit().ok(); }); + m.def("registrations", []() { + nb::dict dict; + dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal); + dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp); + return dict; + }); +} + +} // namespace +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc new file mode 100644 index 000000000000..1ce2e547b11f --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.cc @@ -0,0 +1,631 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/hybrid_kernels.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +namespace ffi = ::xla::ffi; + +// This helper class is used to define a host buffer that can be copied to and +// from a device buffer. +template +class HostBuffer { + public: + HostBuffer(std::size_t size) : size_(size) { + data_ = std::unique_ptr(new T[size]); + } + + absl::Status CopyFromDevice(gpuStream_t stream, const T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(data_.get(), buffer, size_ * sizeof(T), + gpuMemcpyDeviceToHost, stream)); + } + + absl::Status CopyToDevice(gpuStream_t stream, T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(buffer, data_.get(), size_ * sizeof(T), + gpuMemcpyHostToDevice, stream)); + } + + T* get() const { return data_.get(); } + + private: + std::unique_ptr data_; + size_t size_; +}; + +// Forwarded from MAGMA for use as an input parameter. +typedef enum { + MagmaNoVec = 301, + MagmaVec = 302, +} magma_vec_t; + +// Compile time lookup of MAGMA function names depending on the data type. +template +struct always_false : std::false_type {}; + +template +struct MagmaGeev { + static_assert(always_false::value, "unsupported data type"); +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_sgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_dgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_cgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_zgeev"; +}; + +MagmaLookup::~MagmaLookup() { + if (initialized_) { + void* magma_finalize = dlsym(handle_, "magma_finalize"); + if (magma_finalize != nullptr) { + reinterpret_cast(magma_finalize)(); + } + } + if (handle_ != nullptr) { + dlclose(handle_); + } +} + +absl::StatusOr MagmaLookup::FindMagmaInit() { + void* magma_init = nullptr; + std::vector paths; + const char* magma_lib_path = std::getenv("JAX_GPU_MAGMA_PATH"); + if (magma_lib_path != nullptr) { + paths.push_back(magma_lib_path); + } else { + paths.push_back("libmagma.so.2"); + paths.push_back("libmagma.so"); + paths.push_back(nullptr); + } + for (const auto& path : paths) { + handle_ = dlopen(path, RTLD_LAZY); + if (handle_ != nullptr) { + magma_init = dlsym(handle_, "magma_init"); + if (magma_init != nullptr) { + if (path != nullptr) { + lib_path_ = std::string(path); + } + break; + } + } + } + if (handle_ == nullptr || magma_init == nullptr) { + return absl::InternalError( + "Unable to dlopen a MAGMA shared library that defines a magma_init " + "symbol. Use the JAX_GPU_MAGMA_PATH environment variable to " + "specify an explicit path to the library."); + } + return magma_init; +} + +absl::Status MagmaLookup::Initialize() { + if (failed_) { + return absl::InternalError("MAGMA initialization was unsuccessful."); + } + if (!initialized_) { + auto maybe_magma_init = FindMagmaInit(); + if (!maybe_magma_init.ok()) { + failed_ = true; + return maybe_magma_init.status(); + } + reinterpret_cast(maybe_magma_init.value())(); + initialized_ = true; + } + return absl::OkStatus(); +} + +absl::StatusOr MagmaLookup::Find(const char name[]) { + if (!initialized_) { + return absl::InternalError("MAGMA support has not been initialized."); + } + + auto it = symbols_.find(name); + if (it != symbols_.end()) return it->second; + + void* symbol = dlsym(handle_, name); + if (symbol == nullptr) { + if (lib_path_.has_value()) { + return absl::InternalError(absl::StrFormat( + "Unable to load the symbol '%s' from the MAGMA library at '%s'.", + name, lib_path_.value())); + + } else { + return absl::InternalError(absl::StrFormat( + "Unable to load a globally defined symbol called '%s'. Use the " + "JAX_GPU_MAGMA_PATH environment variable to specify an explicit " + "path to the library.", + name)); + } + } + + symbols_.insert({name, symbol}); + return symbol; +} + +// Lookup the MAGMA symbol for the given function name. This function only +// dlopen the MAGMA library once per process. +absl::StatusOr FindMagmaSymbol(const char name[]) { + static absl::Mutex mu; + static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu); + absl::MutexLock lock(&mu); + auto status = lookup.Initialize(); + if (!status.ok()) { + return status; + } + return lookup.Find(name); +} + +// Real-valued eigendecomposition + +template +class EigRealHost { + using Real = ffi::NativeType; + + public: + explicit EigRealHost() = default; + EigRealHost(EigRealHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecomposition::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + EigenvalueDecomposition::fn(&jobvl_, &jobvr_, &n_, x, &n_, wr, wi, + vl, &n_, vr, &n_, work, &lwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigRealMagma { + using Real = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Real*, int, Real*, Real*, Real*, + int, Real*, int, Real*, int, int*); + + public: + explicit EigRealMagma() = default; + EigRealMagma(EigRealMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Real query_host; + fn_(jobvl_, jobvr_, n, nullptr, n, nullptr, nullptr, nullptr, n, nullptr, n, + &query_host, -1, &query_info); + return static_cast(query_host); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, n_, wr, wi, vl, n_, vr, n_, work, lwork, info); + } + + private: + int n_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigReal(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto wr_host = HostBuffer(batch * cols); + auto wi_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto work_left = AllocateScratchMemory(cols * cols); + auto work_right = AllocateScratchMemory(cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), + [](auto value) { return std::isfinite(value); }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, wr_host.get() + i * cols, + wi_host.get() + i * cols, work_left.get(), work_right.get(), + work_host.get(), lwork, info_host.get() + i); + if (info_host.get()[i] == 0) { + if (left) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_left.get(), + vl_host.get() + i * cols * cols); + } + if (right) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_right.get(), + vr_host.get() + i * cols * cols); + } + } + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + wr_host.CopyToDevice(stream, wr->typed_data())); + FFI_RETURN_IF_ERROR_STATUS( + wi_host.CopyToDevice(stream, wi->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigRealDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != wr->element_type() || dataType != wi->element_type() || + ffi::ToComplex(dataType) != vl->element_type() || + ffi::ToComplex(dataType) != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(wr->dimensions(), {batch, cols}, "wr", "eig")); + FFI_RETURN_IF_ERROR(CheckShape(wi->dimensions(), {batch, cols}, "wi", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::F32: + if (use_magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + case ffi::F64: + if (use_magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_real", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigReal, EigRealDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // wr + .Ret() // wi + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +// Complex-valued eigendecomposition + +template +class EigCompHost { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + public: + explicit EigCompHost() = default; + EigCompHost(EigCompHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecompositionComplex::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + EigenvalueDecompositionComplex::fn(&jobvl_, &jobvr_, &n_, x, &n_, + w, vl, &n_, vr, &n_, work, + &lwork, rwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigCompMagma { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Complex*, int, Complex*, + Complex*, int, Complex*, int, Complex*, int, Real*, int*); + + public: + explicit EigCompMagma() = default; + EigCompMagma(EigCompMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + lda_ = std::max(n_, 1); + ldvl_ = left ? n_ : 1; + ldvr_ = right ? n_ : 1; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Complex query_host; + fn_(jobvl_, jobvr_, n_, nullptr, lda_, nullptr, nullptr, ldvl_, nullptr, + ldvr_, &query_host, -1, nullptr, &query_info); + return static_cast(query_host.real()); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, lda_, w, vl, ldvl_, vr, ldvr_, work, lwork, + rwork, info); + } + + private: + int n_, lda_, ldvl_, ldvr_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigComp(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto w_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto rwork_host = + AllocateScratchMemory(2 * cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { + return std::isfinite(z.real()) && std::isfinite(z.imag()); + }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, w_host.get() + i * cols, + vl_host.get() + i * cols * cols, + vr_host.get() + i * cols * cols, work_host.get(), lwork, + rwork_host.get(), info_host.get() + i); + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + w_host.CopyToDevice(stream, w->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigCompDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != w->element_type() || dataType != vl->element_type() || + dataType != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::C64: + if (use_magma) { + return EigComp(EigCompMagma(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + case ffi::C128: + if (use_magma) { + return EigComp(EigCompMagma(), batch, cols, + stream, left, right, x, w, vl, vr, info); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_comp", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigComp, EigCompDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // w + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h new file mode 100644 index 000000000000..2890837a2bd5 --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_HYBRID_KERNELS_H_ +#define JAXLIB_GPU_HYBRID_KERNELS_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +// The MagmaLookup class is used for dlopening the MAGMA shared library, +// initializing it, and looking up MAGMA symbols. +class MagmaLookup { + public: + explicit MagmaLookup() = default; + ~MagmaLookup(); + absl::StatusOr FindMagmaInit(); + absl::Status Initialize(); + absl::StatusOr Find(const char name[]); + + private: + bool initialized_ = false; + bool failed_ = false; + void* handle_ = nullptr; + std::optional lib_path_ = std::nullopt; + absl::flat_hash_map symbols_; +}; + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_HYBRID_KERNELS_H_ diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 03fd43e9ef89..59819f1fc914 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -56,6 +56,21 @@ xla_client.register_custom_call_target(_name, _value, platform="CUDA", api_version=api_version) +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: + try: + _cuhybrid = importlib.import_module( + f"{cuda_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _cuhybrid = None + else: + break + +if _cuhybrid: + for _name, _value in _cuhybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="CUDA", + api_version=1) + try: from .rocm import _blas as _hipblas # pytype: disable=import-error except ImportError: @@ -88,6 +103,34 @@ xla_client.register_custom_call_target(_name, _value, platform="ROCM", api_version=api_version) +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hiphybrid = importlib.import_module( + f"{rocm_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _hiphybrid = None + else: + break + +if _hiphybrid: + for _name, _value in _hiphybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM", + api_version=1) + +def initialize_hybrid_kernels(): + if _cuhybrid: + _cuhybrid.initialize() + if _hiphybrid: + _hiphybrid.initialize() + +def has_magma(): + if _cuhybrid: + return _cuhybrid.has_magma() + if _hiphybrid: + return _hiphybrid.has_magma() + return False + def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" return np.finfo(dtype).dtype diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index b5bfe733b992..2bae7ab2a203 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -66,6 +66,7 @@ _py_deps = { "filelock": ["@pypi_filelock//:pkg"], "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], + "magma": [], "matplotlib": ["@pypi_matplotlib//:pkg"], "mpmath": [], "opt_einsum": ["@pypi_opt_einsum//:pkg"], diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index c9b73a5785f1..1076f9a77bf8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -389,6 +389,48 @@ pybind_extension( ], ) +cc_library( + name = "hip_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_hybrid", + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_hybrid_kernels", + ":hip_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) + cc_library( name = "triton_kernels", srcs = ["//jaxlib/gpu:triton_kernels.cc"], @@ -456,6 +498,7 @@ py_library( name = "rocm_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_solver", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 5b3ac636303a..9a47c6ad5409 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -108,6 +108,7 @@ def prepare_wheel_cuda( f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", f"__main__/jaxlib/cuda_plugin_extension.{pyext}", f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", @@ -144,6 +145,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 438cebca2b06..4db36fa0ea97 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -231,6 +231,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) @@ -244,6 +245,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", ], ) diff --git a/tests/BUILD b/tests/BUILD index c80f63e6d7d6..bd4312e4aa24 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -664,6 +664,13 @@ jax_multiplatform_test( }, ) +jax_multiplatform_test( + name = "magma_linalg_test", + srcs = ["magma_linalg_test.py"], + enable_backends = ["gpu"], + deps = py_deps("magma"), +) + jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a1817f528f27..7aad5634775d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1492,8 +1492,8 @@ def testTrimZerosNotOneDArray(self): def testPoly(self, a_shape, dtype, rank): if dtype in (np.float16, jnp.bfloat16, np.int16): self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") - elif rank == 2 and not jtu.test_device_matches(["cpu"]): - self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") + elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]): + self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.") rng = jtu.rand_default(self.rng()) tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } if jtu.test_device_matches(["tpu"]): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d3fe8f476722..d0b109dda07e 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -34,6 +34,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -250,11 +251,11 @@ def testIssue1213(self): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] args_maker = lambda: [rng(shape, dtype)] @@ -293,12 +294,12 @@ def check_left_eigenvectors(a, w, vl): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, @@ -309,15 +310,15 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, @jtu.sample_product( shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types, - ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + ) + @jtu.run_on_devices("cpu", "gpu") def testEigvalsGrad(self, shape, dtype): # This test sometimes fails for large matrices. I (@j-towns) suspect, but # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -329,10 +330,10 @@ def testEigvalsGrad(self, shape, dtype): shape=[(4, 4), (5, 5), (50, 50)], dtype=float_types + complex_types, ) - # TODO: enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -340,9 +341,11 @@ def testEigvals(self, shape, dtype): w2 = jnp.linalg.eigvals(a) self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -350,8 +353,10 @@ def testEigvalsInf(self): shape=[(1, 1), (4, 4), (5, 5)], dtype=float_types + complex_types, ) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape args = rng(shape, dtype) diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py new file mode 100644 index 000000000000..d2abb9fe3a0b --- /dev/null +++ b/tests/magma_linalg_test.py @@ -0,0 +1,125 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np + +from absl.testing import absltest + +import jax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import linalg as lax_linalg +from jax._src.lib import gpu_solver +from jax._src.lib import version as jaxlib_version + +config.parse_flags_with_absl() + +float_types = jtu.dtypes.floating +complex_types = jtu.dtypes.complex + + +class MagmaLinalgTest(jtu.JaxTestCase): + + @jtu.sample_product( + shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEig(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + rng = jtu.rand_default(self.rng()) + n = shape[-1] + args_maker = lambda: [rng(shape, dtype)] + + # Norm, adjusted for dimension and type. + def norm(x): + norm = np.linalg.norm(x, axis=(-2, -1)) + return norm / ((n + 1) * jnp.finfo(dtype).eps) + + def check_right_eigenvectors(a, w, vr): + self.assertTrue( + np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100)) + + def check_left_eigenvectors(a, w, vl): + rank = len(a.shape) + aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2])) + wC = jnp.conj(w) + check_right_eigenvectors(aH, wC, vl) + + a, = args_maker() + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + w = results[0] + + if compute_left_eigenvectors: + check_left_eigenvectors(a, w, results[1]) + if compute_right_eigenvectors: + check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) + + self._CompileAndCheck(jnp.linalg.eig, args_maker, rtol=1e-3) + + @jtu.sample_product( + shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + a = jnp.full(shape, jnp.nan, dtype) + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + for result in results: + self.assertTrue(np.all(np.isnan(result))) + + def testEigMagmaConfig(self): + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + rng = jtu.rand_default(self.rng()) + a = rng((5, 5), np.float32) + with config.gpu_use_magma("on"): + hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text() + self.assertIn('magma = "on"', hlo) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 14187399d7d555410fcaf7e18a1d2cfb4ced8987 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 08:51:54 -0800 Subject: [PATCH 030/112] Add new CI script for running Bazel GPU presubmits PiperOrigin-RevId: 697643622 --- .github/workflows/bazel_gpu_rbe.yml | 39 ++++++++++++++++++++++ ci/run_bazel_test_gpu_rbe.sh | 51 +++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 .github/workflows/bazel_gpu_rbe.yml create mode 100755 ci/run_bazel_test_gpu_rbe.sh diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml new file mode 100644 index 000000000000..a7cf645b50b3 --- /dev/null +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -0,0 +1,39 @@ +name: CI - Bazel GPU tests (RBE) + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run_tests: + if: github.event.repository.fork == false + strategy: + matrix: + runner: ["linux-x86-n2-16"] + + runs-on: ${{ matrix.runner }} + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + + steps: + - uses: actions/checkout@v3 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel GPU Tests with RBE + run: ./ci/run_bazel_test_gpu_rbe.sh \ No newline at end of file diff --git a/ci/run_bazel_test_gpu_rbe.sh b/ci/run_bazel_test_gpu_rbe.sh new file mode 100755 index 000000000000..0c004c584300 --- /dev/null +++ b/ci/run_bazel_test_gpu_rbe.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one +# GPU apiece on RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece). +echo "Running RBE GPU tests..." + +bazel test --config=rbe_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64=0 \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file From e9864c69da9a9c10012d94b013f302e295434efb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 15 Nov 2024 09:02:09 -0800 Subject: [PATCH 031/112] Make logaddexp and logaddexp2 into ufuncs --- jax/_src/numpy/reductions.py | 32 ++++++++++++++++++++++++++++++ jax/_src/numpy/ufuncs.py | 6 ++---- jax/numpy/__init__.pyi | 4 ++-- tests/lax_numpy_ufuncs_test.py | 36 ++++++++++++++++++++++++---------- 4 files changed, 62 insertions(+), 16 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index bc85bc3e8761..69d6843f5155 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -723,6 +723,38 @@ def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) +def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Compute log(sum(exp(a))) while avoiding precision loss.""" + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") + dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") + a_arr, = promote_dtypes_inexact(a) + pos_dims, dims = _reduction_dims(a_arr, axis) + amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf) + amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))) + amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims) + exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) + sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) + result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype)) + return result if initial is None else lax.logaddexp(initial, result) + + +def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Compute log2(sum(2 ** a)) via logsumexp.""" + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") + dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") + ln2 = float(np.log(2)) + if initial is not None: + initial *= ln2 + return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims, + where=where, initial=initial) / ln2 + + @export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index bbbce9733aa5..de8688e491ba 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2687,8 +2687,7 @@ def _pow_int_int(x1, x2): return acc -@export -@jit +@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp) def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. @@ -2714,8 +2713,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) -@export -@jit +@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp2) def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index af7b056fcbb0..b71afebe921c 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -629,8 +629,8 @@ def log(x: ArrayLike, /) -> Array: ... def log10(x: ArrayLike, /) -> Array: ... def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... -def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logaddexp: BinaryUfunc +logaddex2: BinaryUfunc logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... logical_or: BinaryUfunc diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 61c86c0a05e4..20a1a58a9dbe 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -179,13 +179,15 @@ def test_unary_ufunc_call(self, name, dtype, shape): rhs_shape=broadcast_compatible_shapes, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_bimary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + def test_binary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( @@ -218,7 +220,9 @@ def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker, tol=tol) self._CompileAndCheck(jnp_fun.outer, args_maker) @jtu.sample_product( @@ -259,7 +263,9 @@ def test_binary_ufunc_reduce(self, name, shape, axis, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( @@ -315,7 +321,9 @@ def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): rng_where = jtu.rand_bool(self.rng()) args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)] - self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( @@ -356,8 +364,10 @@ def np_fun_accumulate(x): result = np_fun.accumulate(x, axis=axis) return result if x.dtype == bool else result.astype(x.dtype) - self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker) - self._CompileAndCheck(jnp_fun_accumulate, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_accumulate, args_maker, tol=tol) @jtu.sample_product( SCALAR_FUNCS, @@ -400,7 +410,9 @@ def np_fun_at(x, idx): np_fun.at(x_copy, idx) return x_copy - self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_at, args_maker) @jtu.sample_product( @@ -422,7 +434,9 @@ def np_fun_at(x, idx, y): np_fun.at(x_copy, idx, y) return x_copy - self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_at, args_maker) def test_frompyfunc_at_broadcasting(self): @@ -483,7 +497,9 @@ def np_fun_reduceat(x, i): # Numpy has different casting behavior. return np_fun.reduceat(x, i).astype(x.dtype) - self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker, tol=tol) self._CompileAndCheck(jnp_fun.reduceat, args_maker) From 6fe7b1713a5c6b2de3c7ab2fe04bc36beeb8f8f9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 18 Nov 2024 10:44:59 -0800 Subject: [PATCH 032/112] Return SingleDeviceSharding instead of GSPMDShardings when there is only 1 device during `compiled.input_shardings` call. PiperOrigin-RevId: 697683233 --- jax/_src/interpreters/pxla.py | 9 +++++---- tests/pjit_test.py | 8 ++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6c9e54441f8e..2164c1a914c9 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2747,11 +2747,11 @@ def _maybe_get_and_check_out_shardings( return new_out_shardings -def finalize_out_shardings(out_shardings, device_assignment): +def finalize_shardings(shardings, device_assignment): if len(device_assignment) == 1: return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind) - if isinstance(o, GSPMDSharding) else o for o in out_shardings] - return out_shardings + if isinstance(o, GSPMDSharding) else o for o in shardings] + return shardings @dataclasses.dataclass @@ -2892,7 +2892,8 @@ def from_hlo(name: str, in_shardings, out_shardings, global_in_avals, global_out_avals, intermediate_shardings, context_mesh) - out_shardings = finalize_out_shardings(out_shardings, da) + in_shardings = finalize_shardings(in_shardings, da) + out_shardings = finalize_shardings(out_shardings, da) return UnloadedMeshExecutable( xla_executable=xla_executable, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index be1f9cfc267a..0c1c28809062 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4624,6 +4624,14 @@ def f(x): jax.jit(f, out_shardings=s)(np.arange(8)) self.assertEqual(count[0], 1) + def test_input_shardings_single_device(self): + @jax.jit + def f(x): + return x * 2 + + ins, _ = f.lower(np.arange(8)).compile().input_shardings + self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0])) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From 5bebd0f6c40e152c90a610db80ae85e04773d088 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 18 Nov 2024 11:04:33 -0800 Subject: [PATCH 033/112] fix typo in numpy/__init__.pyi --- jax/numpy/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b71afebe921c..5d357ab1bb03 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -630,7 +630,7 @@ def log10(x: ArrayLike, /) -> Array: ... def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... logaddexp: BinaryUfunc -logaddex2: BinaryUfunc +logaddexp2: BinaryUfunc logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... logical_or: BinaryUfunc From 0ed6eaeb4a0c5ebf7679f3877b01bd7d6df29bae Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Mon, 18 Nov 2024 12:13:55 -0800 Subject: [PATCH 034/112] [SDY] fix JAX layouts tests for Shardy. PiperOrigin-RevId: 697715276 --- tests/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index bd4312e4aa24..a645a971a799 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -267,6 +267,9 @@ jax_multiplatform_test( backend_tags = { "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, + enable_configs = [ + "tpu_v3_2x2_shardy", + ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", From 461a2507f8b8e2a4da1d5de9a0c9fee98cfef245 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Nov 2024 17:04:52 -0500 Subject: [PATCH 035/112] Disable some complex function accuracy tests that fail on Mac ARM. --- tests/lax_test.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 14f453b38e7c..78bc5857acb7 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4398,14 +4398,34 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'tanh': regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj') + elif name == 'arcsin': + if is_arm_cpu and platform.system() == 'Darwin': + regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real') + else: + regions_with_inaccuracies.clear() + + elif name == 'arcsinh': + if is_arm_cpu and platform.system() == 'Darwin': + regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', + 'negj.imag', 'posj.imag') + else: + regions_with_inaccuracies.clear() + elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') - elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan', - 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh', 'square'}: + elif name == 'log1p': + if is_arm_cpu and platform.system() == 'Darwin': + regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', + 'posj.imag') + else: + regions_with_inaccuracies.clear() + + elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'tan', + 'arcsinh', 'arccosh', 'arctan', 'arctanh', 'square'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable From f32505169fe98dc3a8f9c66ebd343bd349c8798e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Nov 2024 14:07:13 -0800 Subject: [PATCH 036/112] Filter custom dtypes by supported_dtypes in `_LazyDtypes`. The other methods of `_LazyDtypes` filter by the supported dtypes, so it's strange that this property does not. Change in preparation for landing https://github.com/jax-ml/jax/pull/23585 without breaking existing tests. PiperOrigin-RevId: 697752034 --- jax/_src/test_util.py | 14 +++++++++++--- tests/api_test.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index e546ebd2a0f3..72154fd5871d 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -457,7 +457,15 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64} + np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64, + _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e5m2} + elif device_under_test() == "gpu": + types = {np.bool_, np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + _dtypes.bfloat16, np.float16, np.float32, np.float64, + np.complex64, np.complex128, + _dtypes.float8_e4m3fn, _dtypes.float8_e5m2} elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: @@ -1464,10 +1472,10 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ + return self.supported([ _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]) @_cached_property def floating(self): diff --git a/tests/api_test.py b/tests/api_test.py index 49cd33ee464c..ae38f50460ab 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4625,7 +4625,7 @@ def test_cache_miss_explanations_no_source_info(self): jax.jit(operator.add)(42, 24) @parameterized.named_parameters([ - {"testcase_name": f"{dtype}", "dtype": dtype} + {"testcase_name": f"{np.dtype(dtype)}", "dtype": dtype} for dtype in jtu.dtypes.custom_floats]) def test_jit_custom_floats(self, dtype): f = lambda x: x + 1 From a60ef6e9bb19a898ab9a87e62fe4ed73c44ede24 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 18 Nov 2024 14:08:04 -0800 Subject: [PATCH 037/112] [Pallas] Increase test coverage of pl.dot. PiperOrigin-RevId: 697752355 --- tests/pallas/ops_test.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 41670137c39f..df48da776e5f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1417,17 +1417,25 @@ def f(x_ref, o_ref): np.testing.assert_allclose(f(x), expected) @parameterized.product( - size=[16, 32, 64], - dtype=["float32", "float16"], + size=[16, 32, 64, 128, 256], + dtype=[jnp.float32, jnp.float16, jnp.bfloat16], trans_x=[False, True], trans_y=[False, True], ) def test_dot(self, size, dtype, trans_x, trans_y): - if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: - self.skipTest("16-bit types are not supported on TPU") + if jtu.test_device_matches(["tpu"]): + if dtype == jnp.float16: + self.skipTest("float16 type is not supported on TPU") + if dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4): + self.skipTest("bfloat16 matmul is supported on TPUv4+") + if trans_x: + self.skipTest("Not implemented: Transposed LHS") - if jtu.test_device_matches(["tpu"]) and trans_x: - self.skipTest("Not implemented: Transposed LHS") + if jtu.test_device_matches(["gpu"]): + if dtype == jnp.bfloat16: + self.skipTest("bfloat16 type are not supported on GPU") + if size > 128: + self.skipTest("Shared memory size limit exceeded") @functools.partial( self.pallas_call, @@ -1444,7 +1452,12 @@ def dot(x_ref, y_ref, o_ref): y = random.normal(k2, (size, size), dtype=dtype) out = dot(x, y) expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y) - np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected.astype(jnp.float32), + atol=0.05, + rtol=0.05, + ) @parameterized.product( size=[1, 2, 64, 129, 1021], From b3ca6c47cc30cdf6e9e3ff3de1a12b9ee1b4ad81 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 18 Nov 2024 14:21:17 -0800 Subject: [PATCH 038/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/082a7014706f67bb8a42fb1c90051bc4990f2fd3. PiperOrigin-RevId: 697756717 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b35f9daa2144..71fb2a8e9757 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "58ea2935b4316b48979cb47f617ae06ce9f49638" -XLA_SHA256 = "669eef5690be3e1059de8429cdfbf24bf0a15a5aa6e00b9aefd7a072d839d0aa" +XLA_COMMIT = "082a7014706f67bb8a42fb1c90051bc4990f2fd3" +XLA_SHA256 = "f1ca797df8e95bf13419d20520d2b783f075d80d1c5ddf1506ba427c934de849" def repo(): tf_http_archive( From d4316b5760a824bf044622073812e2f4a094a29d Mon Sep 17 00:00:00 2001 From: barnesjoseph Date: Mon, 18 Nov 2024 14:46:10 -0800 Subject: [PATCH 039/112] Adds font fallbacks --- docs/_static/style.css | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/_static/style.css b/docs/_static/style.css index 32033940e8c4..d801c2a412a6 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -103,7 +103,7 @@ body:has(.hero) .bd-container { display: flex; align-items: center; justify-content: center; - font: 500 24px 'Roboto'; + font: 500 24px 'Roboto', sans-serif; } .getting-started .sd-card-title, @@ -165,13 +165,13 @@ body:has(.hero) .bd-container { } .hero-left h2 { - font: 500 32px 'Google Sans'; + font: 500 32px 'Google Sans', 'Roboto', sans-serif; color: white; margin-top: 0; } .hero-left p { - font: 400 16px 'Roboto'; + font: 400 16px 'Roboto', sans-serif; color: white; } @@ -200,7 +200,7 @@ body:has(.hero) .bd-container { } .product-offerings .sd-card-title { - font: 400 24px 'Google Sans'; + font: 400 24px 'Google Sans', 'Roboto', sans-serif; } .color-cards { From e904c177f7644f0a733501bc548d1f5b237396af Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 18 Nov 2024 15:34:55 -0800 Subject: [PATCH 040/112] Delete _normalized_spec from NamedSharding PiperOrigin-RevId: 697779844 --- jax/_src/array.py | 2 +- jax/_src/core.py | 2 +- jax/_src/sharding_impls.py | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index cf346067ea31..d8182976254e 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1035,7 +1035,7 @@ def _get_aval_array(self): if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): return self.aval.update(sharding=NamedSharding( self.sharding.mesh.abstract_mesh, - self.sharding._normalized_spec(self.ndim))) + self.sharding.spec._normalized_spec(self.ndim))) else: return self.aval api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array diff --git a/jax/_src/core.py b/jax/_src/core.py index a1fcdac65df0..cbf3282fb2cc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1704,7 +1704,7 @@ def _get_abstract_sharding(val): if (config.sharding_in_types.value and hasattr(val, 'sharding') and isinstance(val.sharding, NamedSharding)): return NamedSharding(val.sharding.mesh.abstract_mesh, - val.sharding._normalized_spec(val.ndim)) + val.sharding.spec._normalized_spec(val.ndim)) return None def primal_dtype_to_tangent_dtype(primal_dtype): diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 8957a6186339..dc4171eec146 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -363,9 +363,6 @@ def is_fully_replicated(self) -> bool: def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) - def _normalized_spec(self, ndim: int) -> PartitionSpec: - return self.spec._normalized_spec(ndim) - def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) From 2c68569af05d54a66a3c47b28bc1c20317f9e560 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 18 Nov 2024 16:20:21 -0800 Subject: [PATCH 041/112] Fix a bug where mesh checking was not correct PiperOrigin-RevId: 697792885 --- jax/_src/lax/lax.py | 2 +- tests/pjit_test.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f434413834f7..ff9ac0a49578 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2096,11 +2096,11 @@ def broadcasting_sharding_rule(name, *avals): mesh = None for a in avals: if a.sharding is not None: - mesh = a.sharding.mesh if mesh is not None and mesh != a.sharding.mesh: raise ValueError( f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') + mesh = a.sharding.mesh assert mesh is not None shapes = [aval.shape for aval in avals if aval.shape] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0c1c28809062..6df011419513 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4979,6 +4979,21 @@ def f(x): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) + def test_broadcasting_nary_error(self): + mesh1 = Mesh([jax.devices()[0]], 'x') + mesh2 = Mesh([jax.devices()[0]], 'y') + + arr1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) + arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P())) + + @jax.jit + def f(x, y): + return x + y + + with self.assertRaisesRegex( + ValueError, "Mesh for all inputs should be equal"): + f(arr1, arr2) + def test_sin_unop(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16.).reshape(8, 2) From 45c9c0a585704c0c139a33b838d6827b8d16df5e Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 18 Nov 2024 17:09:28 -0800 Subject: [PATCH 042/112] [pallas] Minor simplifications to Pallas interpreter. BlockMappings are always present now. PiperOrigin-RevId: 697807120 --- jax/_src/pallas/pallas_call.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index f7bd0dd4e4d7..729d0e617a87 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -72,10 +72,6 @@ pallas_call_p.multiple_results = True def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): - if start_idx is None: - assert is_indexing is None - return value - assert is_indexing is not None start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, @@ -84,10 +80,6 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): def _maybe_dynamic_update_slice(start_idx, block_shape, value, update, is_indexing): - if start_idx is None: - assert is_indexing is None - return update - assert is_indexing is not None start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) broadcast_dims = tuple(i for i, b in enumerate(is_indexing) if not b) @@ -234,8 +226,7 @@ def _pallas_call_impl_interpret( for bm in grid_mapping.block_mappings ] block_shapes = [ - None if iid is None - else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) + tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) ] @@ -284,8 +275,9 @@ def body(carry): aval = jax_core.get_aval(s) s.aval = aval.update(dtype=jnp.int32) start_indices = [ - None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) - for bm in grid_mapping.block_mappings] + bm.compute_start_indices_interpret(loop_idx, *scalars) + for bm in grid_mapping.block_mappings + ] blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry_consts_ins, is_indexing_dim) with pallas_core.grid_env(local_grid_env): From c5e8ae80f9949c69bd6b99d245bf599be2644d7b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Nov 2024 09:46:22 -0500 Subject: [PATCH 043/112] Update jax.scipy.special.gamma and gammasgn to return NaN for negative integer inputs. Change to match upstream scipy: https://github.com/scipy/scipy/pull/21827. Fixes #24875 --- CHANGELOG.md | 3 ++ jax/_src/scipy/special.py | 26 ++++++++++++++-- tests/lax_scipy_special_functions_test.py | 37 +++++++++++++++++------ 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 204df6a83e52..9082399c8695 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` on the function inputs. + * {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now + return NaN for negative integer inputs, to match the behavior of SciPy from + https://github.com/scipy/scipy/pull/21827. * `jax.clear_backends` was removed after being deprecated in v0.4.26. * New Features diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 605cde19b1e7..2fffe6381b97 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -66,6 +66,7 @@ def gammaln(x: ArrayLike) -> Array: return lax.lgamma(x) +@jit def gammasgn(x: ArrayLike) -> Array: r"""Sign of the gamma function. @@ -81,6 +82,13 @@ def gammasgn(x: ArrayLike) -> Array: Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Because :math:`\Gamma(x)` is never zero, no condition is required for this case. + * if :math:`x = -\infty`, NaN is returned. + * if :math:`x = \pm 0`, :math:`\pm 1` is returned. + * if :math:`x` is a negative integer, NaN is returned. The sign of gamma + at a negative integer depends on from which side the pole is approached. + * if :math:`x = \infty`, :math:`1` is returned. + * if :math:`x` is NaN, NaN is returned. + Args: x: arraylike, real valued. @@ -92,8 +100,14 @@ def gammasgn(x: ArrayLike) -> Array: - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function """ x, = promote_args_inexact("gammasgn", x) + typ = x.dtype.type floor_x = lax.floor(x) - return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0) + x_negative = x < 0 + return jnp.select( + [(x_negative & (x == floor_x)) | jnp.isnan(x), + (x_negative & (floor_x % 2 != 0)) | ((x == 0) & jnp.signbit(x))], + [typ(np.nan), typ(-1.0)], + typ(1.0)) def gamma(x: ArrayLike) -> Array: @@ -115,6 +129,13 @@ def gamma(x: ArrayLike) -> Array: \Gamma(n) = (n - 1)! + * if :math:`z = -\infty`, NaN is returned. + * if :math:`x = \pm 0`, :math:`\pm \infty` is returned. + * if :math:`x` is a negative integer, NaN is returned. The sign of gamma + at a negative integer depends on from which side the pole is approached. + * if :math:`x = \infty`, :math:`\infty` is returned. + * if :math:`x` is NaN, NaN is returned. + Args: x: arraylike, real valued. @@ -127,7 +148,8 @@ def gamma(x: ArrayLike) -> Array: - :func:`jax.scipy.special.gammasgn`: the sign of the gamma function Notes: - Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs. + Unlike the scipy version, JAX's ``gamma`` does not support complex-valued + inputs. """ x, = promote_args_inexact("gamma", x) return gammasgn(x) * lax.exp(lax.lgamma(x)) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index cb40ae291e76..5753628957c7 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -20,9 +20,11 @@ from absl.testing import parameterized import numpy as np +import scipy import scipy.special as osp_special import jax +import jax.numpy as jnp from jax._src import test_util as jtu from jax.scipy import special as lsp_special @@ -214,7 +216,7 @@ def partial_lax_op(*vals): n=[0, 1, 2, 3, 10, 50] ) def testScipySpecialFunBernoulli(self, n): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. scipy_op = lambda: osp_special.bernoulli(n).astype(dtype) lax_op = functools.partial(lsp_special.bernoulli, n) args_maker = lambda: [] @@ -222,16 +224,33 @@ def testScipySpecialFunBernoulli(self, n): self._CompileAndCheck(lax_op, args_maker, atol=0, rtol=1E-5) def testGammaSign(self): - # Test that the sign of `gamma` matches at integer-valued inputs. - dtype = jax.numpy.zeros(0).dtype # default float dtype. - args_maker = lambda: [np.arange(-10, 10).astype(dtype)] - rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 - self._CheckAgainstNumpy(osp_special.gamma, lsp_special.gamma, args_maker, rtol=rtol) - self._CompileAndCheck(lsp_special.gamma, args_maker, rtol=rtol) + dtype = jnp.zeros(0).dtype # default float dtype. + typ = dtype.type + testcases = [ + (np.arange(-10, 0).astype(dtype), np.array([np.nan] * 10, dtype=dtype)), + (np.nextafter(np.arange(-5, 0).astype(dtype), typ(-np.inf)), + np.array([1, -1, 1, -1, 1], dtype=dtype)), + (np.nextafter(np.arange(-5, 0).astype(dtype), typ(np.inf)), + np.array([-1, 1, -1, 1, -1], dtype=dtype)), + (np.arange(0, 10).astype(dtype), np.ones((10,), dtype)), + (np.nextafter(np.arange(0, 10).astype(dtype), typ(np.inf)), + np.ones((10,), dtype)), + (np.nextafter(np.arange(1, 10).astype(dtype), typ(-np.inf)), + np.ones((9,), dtype)), + (np.array([-np.inf, -0.0, 0.0, np.inf, np.nan]), + np.array([np.nan, -1.0, 1.0, 1.0, np.nan])) + ] + for inp, out in testcases: + self.assertArraysEqual(out, lsp_special.gammasgn(inp)) + self.assertArraysEqual(out, jnp.sign(lsp_special.gamma(inp))) + if jtu.parse_version(scipy.__version__) >= (1, 15): + self.assertArraysEqual(out, osp_special.gammasgn(inp)) + self.assertAllClose(osp_special.gammasgn(inp), + lsp_special.gammasgn(inp)) def testNdtriExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.arange(-10, 10).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) @@ -239,7 +258,7 @@ def testNdtriExtremeValues(self): def testRelEntrExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype), np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 From 0fe77bc9f0e5a7c78e3de6371cbbbc9a3a43bf5a Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 18 Nov 2024 18:06:36 -0800 Subject: [PATCH 044/112] [Mosaic TPU] Support relayout for mask vector We cast i1 vector (mask) to i32 vector before relayout and then cast back to i1 vector (mask) after relayout is finished. PiperOrigin-RevId: 697823543 --- .../tpu/transforms/apply_vector_layout.cc | 49 ++++++++++++++++--- tests/pallas/tpu_ops_test.py | 18 +++++++ 2 files changed, 59 insertions(+), 8 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 2732b63d7638..8292a770a1c3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6314,6 +6314,14 @@ FailureOr> relayout(RewriteContext &ctx, return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); } VectorType vty = v.getType(); + const bool is_mask = vty.getElementTypeBitWidth() == 1; + if (is_mask) { + if (src.bitwidth() != 32 || dst.bitwidth() != 32) { + return emitError(v.getLoc(), + "Not implemented: mask relayout with non-32 bitwidth in " + "vector layout"); + } + } { // Replication imposes a replication constraint on the *logical* value of // the vector: When moving along a replicated axis, all elements must be @@ -6347,6 +6355,31 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); + if (is_mask) { + auto new_tile_ty = + getNativeVregOrVmaskType(builder.getI32Type(), 32, target_shape); + src_tiles.Each([&](const absl::Span idx, Value *tile) { + *tile = + builder.create(tile->getLoc(), new_tile_ty, *tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI32Type()); + } + auto assemble_with_mask_check = [&](xla::Array &tiles, + bool use_implicit_shape = false) { + if (is_mask) { + auto zeros_tile = builder.create( + tiles.begin()->getLoc(), + DenseElementsAttr::get(cast(tiles.begin()->getType()), + builder.getI32IntegerAttr(0))); + tiles.Each([&](const absl::Span idx, Value *tile) { + *tile = builder.create( + tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI1Type()); + } + return assemble(builder, vty, dst, tiles, target_shape, use_implicit_shape) + .getResult(); + }; // Two easy cases: source is more general, or is replicated. if (src.generalizes(dst, vty.getShape(), target_shape)) { // A value with a replicated offset might use fewer vregs than a value with @@ -6397,9 +6430,8 @@ FailureOr> relayout(RewriteContext &ctx, .getResult(); } src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() && !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) { @@ -6410,8 +6442,7 @@ FailureOr> relayout(RewriteContext &ctx, xla::Array dst_tiles( /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape), /*value=*/src_tiles.data()[0]); - return assemble(builder, vty, dst, std::move(dst_tiles), target_shape) - .getResult(); + return assemble_with_mask_check(dst_tiles); } // Consider (1,128),-2 -> (8,128). In this case we can change the implicit @@ -6449,9 +6480,8 @@ FailureOr> relayout(RewriteContext &ctx, dst.offsets())); CHECK_EQ(src, dst); // At this point we've should be done. - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } // TODO(apaszke): Implement a debug mode that inserts additional assertions. @@ -6491,6 +6521,9 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); + if (*lo == *li) { + continue; + } OpBuilder builder(&op); FAILUREOR_ASSIGN_OR_RETURN( Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index ca5361a70051..8843c6a58064 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -233,6 +233,24 @@ def run(cond, lhs, rhs): assert (run(cond, lhs, rhs) == lhs).all() + def test_logical_and_relayouted_mask(self): + def get_mask(x_ref): + x = x_ref[...] == 1 + iota = jax.lax.broadcasted_iota(jnp.int32, x_ref.shape, 1) + iota = iota > 7 + return jnp.logical_and(x, iota) + + def body(x_ref, y_ref): + y_ref[...] = jnp.where(get_mask(x_ref), 0.0, -1.0) + + shape = (2, 512) + out = jax.ShapeDtypeStruct(shape, jnp.float32) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(shape) + result = self.pallas_call(body, out_shape=out)(x) + expected = jnp.ones(x.shape, dtype=jnp.float32) + expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0)) + np.testing.assert_array_equal(result, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True From d397dd968468dc054b91aacd1958a8586c409878 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 18 Nov 2024 23:58:40 -0800 Subject: [PATCH 045/112] Implement lax.pad in Pallas. PiperOrigin-RevId: 697897093 --- jax/_src/pallas/mosaic/lowering.py | 65 +++++++++++++++++++ .../tpu/transforms/apply_vector_layout.cc | 10 ++- tests/pallas/ops_test.py | 58 +++++++++++++++-- 3 files changed, 127 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 3dbb410be29f..be4102dff716 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3243,3 +3243,68 @@ def _lower_fun(shape): lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering + + +def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): + operand, padding_value = args + padding_config = kwargs["padding_config"] + + out_type: ir.VectorType = aval_to_ir_type(ctx.avals_in[0]) + if not isinstance(out_type, ir.VectorType): + raise NotImplementedError("Only vector types are supported.") + + for axis, (low, high, interior) in enumerate(padding_config): + if low == 0 and high == 0 and interior == 0: + continue + + def _pad(val): + shape = list(operand.type.shape) + shape[axis] = val + pad_vec_type = ir.VectorType.get( + shape, + operand.type.element_type, + ) + + if isinstance(padding_value, ir.OpResult): + pad = vector.BroadcastOp( + pad_vec_type, + padding_value, + ).result + else: + scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value) + pad = arith.ConstantOp( + pad_vec_type, + ir.DenseElementsAttr.get_splat( + pad_vec_type, + scalar_attr, + ), + ).result + return pad + + if low != 0: + pad_low = _pad(low) + new_shape = out_type.shape + new_shape[axis] += low + out_type = ir.VectorType.get( + new_shape, + out_type.element_type, + ) + operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis) + + if high != 0: + pad_high = _pad(high) + new_shape = out_type.shape + new_shape[axis] += high + out_type = ir.VectorType.get( + new_shape, + out_type.element_type, + ) + operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis) + + if interior > 0: + raise NotImplementedError("Not implemented: interior padding") + + return operand + + +lowering_rules[lax.pad_p] = _pad_lowering_rule diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8292a770a1c3..4a344fa9d427 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2674,6 +2674,13 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, for (size_t i = 0; i < operand_vregs.size(); ++i) { auto &vreg = operand_vregs[i]; const auto &layout = layouts_in[i]; + const int packing = res_layout->packing(); + + if (layout->tiling()[0] % packing != 0) { + return op.emitOpError( + "Illegal tiling: Non-native tiling in concat - this should " + "have been caught earlier!"); + } const int64_t operand_offset = *layout->offsets()[tiling_dim.value()]; if (operand_offset != 0) { @@ -2685,7 +2692,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, } const auto bitwidth = res_ty.getElementTypeBitWidth(); - const int packing = res_layout->packing(); SmallVector out_idx; vreg.Each([&](absl::Span idx, Value *v) { out_idx.assign(idx.begin(), idx.end()); @@ -2716,7 +2722,7 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, mask = builder.create( op.getLoc(), vmask_ty, ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(layout->tiling()[0]), + ArrayRef{boundIdxConst(layout->tiling()[0] / packing), boundIdxConst(operand_offset)}); } // Blend the current value with the existing value in the output. diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index df48da776e5f..9f0b9aef5af3 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -21,12 +21,9 @@ from typing import Any import unittest -import numpy as np from absl.testing import absltest from absl.testing import parameterized - import jax -import jax.numpy as jnp from jax import lax from jax import random from jax._src import config @@ -34,8 +31,10 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu -from jax.interpreters import partial_eval as pe from jax.experimental import pallas as pl +from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import numpy as np if sys.platform != "win32": from jax.experimental.pallas import triton as plgpu @@ -1980,6 +1979,57 @@ def convert(x_ref, y_ref): y_ref = jax.lax.bitcast_convert_type(x, out_dtype) np.testing.assert_array_equal(y, y_ref) + @parameterized.product( + array_shapes=[(4, 128), (10, 100), (8, 128), (17, 257)], + padding=[ + ((5, 8), (0, 0)), + ((0, 0), (5, 100)), + ((1, 1), (1, 1)), + ((0, 0), (0, 0)), + ], + pad_type=["constant", "wrap"], + dtype=( + jnp.float32, + jnp.bfloat16, + ), + ) + def test_arbitrary_padding_jnp_pad( + self, array_shapes, padding, pad_type, dtype + ): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not implemented on GPU") + + x = jnp.arange(np.prod(array_shapes), dtype=dtype).reshape(array_shapes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.pad(x_ref[...], padding, mode=pad_type) + + ref = jnp.pad(x, padding, mode=pad_type) + + out_shape = jax.ShapeDtypeStruct(ref.shape, x.dtype) + try: + out = self.pallas_call( + kernel, + out_shape=out_shape, + )(x) + np.testing.assert_array_equal(out, jnp.pad(x, padding, mode=pad_type)) + except Exception as e: + self.assertEqual( + dtype, + jnp.bfloat16, + "some bfloat16 combinations can fail with not implemented", + ) + # The first two options are expected to fail due to current limitations + # in the Pallas TPU lowering. However, the last one is unexpected, and + # should be fixed, it is a pjrt bug. + # b/379787665 + acceptable_errors = ( + "Only 32-bit types supported" in str(e) + or "Not implemented" in str(e) + or "Expected mask vector type" in str(e) + ) + self.assertTrue(acceptable_errors, "Failed with error: " + str(e)) + class OpsInterpretTest(OpsTest): INTERPRET = True From da50ad7ee395eec84930d7a1c87346a547b0ae07 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 19 Nov 2024 01:47:13 -0800 Subject: [PATCH 046/112] [AutoPGLE] Use compile options to override debug options instead of XLA_FLAGS. PiperOrigin-RevId: 697924164 --- tests/pgle_test.py | 327 +++++++++++++++++++++------------------------ 1 file changed, 153 insertions(+), 174 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 609ca38fd7a5..46146abfc7c6 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import ExitStack from functools import partial import glob import logging @@ -43,41 +42,7 @@ @jtu.pytest_mark_if_available('multiaccelerator') -# TODO(patrios): Remove this skip once b/379267258 is fixed. -@jtu.skip_under_pytest( - 'This test requires specific XLA_FLAGS. However, pytest does not reload ' - 'modules between tests. So if another test is launched before this one ' - 'necessary XLA_FLAGS will not be re-used by the XLA.') class PgleTest(jtu.JaxTestCase): - _dump_exit_stack: ExitStack | None = None - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._dump_exit_stack = ExitStack() - - cls.dump_dir = cls._dump_exit_stack.enter_context(tempfile.TemporaryDirectory()) - if 'XLA_FLAGS' in os.environ: - cls.old_xla_flags = os.environ['XLA_FLAGS'] - else: - cls.old_xla_flags = None - - os.environ['XLA_FLAGS'] = ( - f'--xla_dump_to={cls.dump_dir}' - ' --xla_gpu_experimental_dump_fdo_profiles=true' - ' --xla_gpu_enable_latency_hiding_scheduler=true' - # TODO(patrios): Remove this flag once b/376647494 is fixed. - ' --xla_gpu_graph_level=0' - ) - if cls.old_xla_flags: - os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags - - @classmethod - def tearDownClass(cls): - if cls.old_xla_flags: - os.environ['XLA_FLAGS'] = cls.old_xla_flags - cls._dump_exit_stack.close() - super().tearDownClass() def setUp(self): super().setUp() @@ -85,12 +50,6 @@ def setUp(self): cc.reset_cache() def tearDown(self): - # Cleanup dump directory - for file in os.listdir(self.dump_dir): - file_path = os.path.join(self.dump_dir, file) - if os.path.isfile(file_path): - os.remove(file_path) - cc.set_cache_dir(None) super().tearDown() @@ -101,6 +60,7 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y @@ -130,6 +90,11 @@ def testPGLEProfilerGetFDOProfileLarge(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + }, ) def f(x): agg = x @@ -154,6 +119,11 @@ def testAutoPgle(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + }, ) def f(x): return x * 2 @@ -172,7 +142,7 @@ def f(x): # Run 2: Second PGLE run should not recompile the module with jtu.count_cached_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 0) + self.assertLess(cache_miss_count[0], 2) # Run 3: The module should be recompiled with FDO profiles with jtu.count_cached_compilation_cache_miss() as cache_miss_count: @@ -182,7 +152,7 @@ def f(x): # Run 4: Fast-path should be used after PGLE is done with jtu.count_cached_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 0) + self.assertLess(cache_miss_count[0], 2) def testAutoPgleWithAot(self): @jax.jit @@ -211,145 +181,154 @@ def testAutoPgleWithPersistentCache(self): its = 50 mesh = jtu.create_mesh((2,), ('x',)) - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - ) - def f(x): - agg = x - for _ in range(its): - agg = agg @ x - return agg - - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - - with (config.enable_compilation_cache(True), - config.enable_pgle(True), - config.raise_persistent_cache_errors(True), - config.raise_persistent_cache_errors(True), - config.persistent_cache_min_entry_size_bytes(0), - config.persistent_cache_min_compile_time_secs(0), - config.pgle_profiling_runs(2), - tempfile.TemporaryDirectory() as cache_dir): - cc.reset_cache() - cc.set_cache_dir(cache_dir) - # Run 1: Module should be compiled without FDO - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with tempfile.TemporaryDirectory() as dump_dir: + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True' + }, + ) + def f(x): + agg = x + for _ in range(its): + agg = agg @ x + return agg + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + + with (config.enable_compilation_cache(True), + config.enable_pgle(True), + config.raise_persistent_cache_errors(True), + config.raise_persistent_cache_errors(True), + config.persistent_cache_min_entry_size_bytes(0), + config.persistent_cache_min_compile_time_secs(0), + config.pgle_profiling_runs(2), + tempfile.TemporaryDirectory() as cache_dir): + cc.reset_cache() + cc.set_cache_dir(cache_dir) + # Run 1: Module should be compiled without FDO + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertGreater(cache_miss_count[0], 0) + + # Non-pgle profiled version of module should be saved + non_pgle_profiled_files = os.listdir(cache_dir) + self.assertNotEmpty(non_pgle_profiled_files) + + # Run 2: Compilation should not be called + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertLess(cache_miss_count[0], 2) + + module_before_pgle = os.listdir(dump_dir) + self.assertNotEmpty(module_before_pgle) + # Run 3: Module should be compiled with FDO and stored to persistent cache + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertGreater(cache_miss_count[0], 0) + + # Check if FDO profile file of the biggest module is not empty + module_after_pgle = [ + x + for x in os.listdir(dump_dir) + if x not in module_before_pgle + ] + self.assertNotEmpty(module_after_pgle) + biggest_module_after_pgle = max( + module_after_pgle, + key=lambda x: os.path.getsize( + os.path.join(dump_dir, x) + ), + ) + base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) + + # Check if FDO profile file in dump directory is not empty + for module in module_after_pgle: + if module.startswith(base_module_name) and module.endswith( + '.fdo_profile' + ): + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, module)), 0 + ) + + for pgle_profiler in pjit._pgle_profiler_dict.values(): + self.assertTrue(pgle_profiler.is_enabled()) + self.assertTrue(pgle_profiler.is_fdo_consumed()) + + files_after_pgle_profile = os.listdir(cache_dir) + self.assertGreater( + len(files_after_pgle_profile), len(non_pgle_profiled_files) + ) + + # Removing non-pgle profiled module from cache to check that later pgle + # profiled version will be used. + for non_pgle_file in non_pgle_profiled_files: + path = os.path.join(cache_dir, non_pgle_file) + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + # Run 4: Persistent compilation cache should be hit PGLE profiler should + # be disabled + cache_hit = 0 + def check_if_cache_hit(event): + nonlocal cache_hit + if event == '/jax/compilation_cache/cache_hits': + cache_hit += 1 + + monitoring.register_event_listener(check_if_cache_hit) f(x) - self.assertGreater(cache_miss_count[0], 0) + monitoring._unregister_event_listener_by_callback(check_if_cache_hit) - # Non-pgle profiled version of module should be saved - non_pgle_profiled_files = os.listdir(cache_dir) - self.assertNotEmpty(non_pgle_profiled_files) + self.assertGreater(cache_hit, 0) - # Run 2: Compilation should not be called - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - f(x) - self.assertEqual(cache_miss_count[0], 0) + def testPassingFDOProfile(self): + mesh = jtu.create_mesh((2,), ('x',)) - module_before_pgle = os.listdir(self.dump_dir) - self.assertNotEmpty(module_before_pgle) - # Run 3: Module should be compiled with FDO and stored to persistent cache - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - f(x) - self.assertGreater(cache_miss_count[0], 0) - - # Check if FDO profile file of the biggest module is not empty - module_after_pgle = [ - x - for x in os.listdir(self.dump_dir) - if x not in module_before_pgle - ] - self.assertNotEmpty(module_after_pgle) - biggest_module_after_pgle = max( - module_after_pgle, - key=lambda x: os.path.getsize( - os.path.join(self.dump_dir, x) - ), - ) - base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) - - # Check if FDO profile file in dump directory is not empty - for module in module_after_pgle: - if module.startswith(base_module_name) and module.endswith( - '.fdo_profile' - ): - self.assertGreater( - os.path.getsize(os.path.join(self.dump_dir, module)), 0 - ) - - for pgle_profiler in pjit._pgle_profiler_dict.values(): - self.assertTrue(pgle_profiler.is_enabled()) - self.assertTrue(pgle_profiler.is_fdo_consumed()) - - files_after_pgle_profile = os.listdir(cache_dir) - self.assertGreater( - len(files_after_pgle_profile), len(non_pgle_profiled_files) + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) + def f(x, y): + return x @ y - # Removing non-pgle profiled module from cache to check that later pgle - # profiled version will be used. - for non_pgle_file in non_pgle_profiled_files: - path = os.path.join(cache_dir, non_pgle_file) - if os.path.isfile(path): - os.remove(path) - elif os.path.isdir(path): - shutil.rmtree(path) - - api.clear_caches() - pjit._pgle_profiler_dict.clear() - - # Run 4: Persistent compilation cache should be hit PGLE profiler should - # be disabled - cache_hit = 0 - def check_if_cache_hit(event): - nonlocal cache_hit - if event == '/jax/compilation_cache/cache_hits': - cache_hit += 1 - - monitoring.register_event_listener(check_if_cache_hit) - f(x) - monitoring._unregister_event_listener_by_callback(check_if_cache_hit) - - self.assertGreater(cache_hit, 0) - - def testPassingFDOProfile(self): - mesh = jtu.create_mesh((2,), ('x',)) - - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - ) - def f(x, y): - return x @ y - - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - y = x + 1 + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + y = x + 1 - with config.pgle_profiling_runs(0): - f_lowered = f.lower(x, y) - compiled = f_lowered.compile() + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() - with tempfile.TemporaryDirectory() as cache_dir: - jax.profiler.start_trace(cache_dir) - compiled(x, y) - jax.profiler.stop_trace() - directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) - directories = [d for d in directories if os.path.isdir(d)] - rundir = directories[-1] - logging.info('rundir: %s', rundir) - fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir) - - if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda(): - self.assertIn(b'custom', fdo_profile) - - logging.info('fdo_profile: %s', fdo_profile) - # Test pass fdo_profile as compiler_options API works. - f_lowered.compile(compiler_options={'fdo_profile': fdo_profile}) + with tempfile.TemporaryDirectory() as cache_dir: + jax.profiler.start_trace(cache_dir) + compiled(x, y) + jax.profiler.stop_trace() + directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) + directories = [d for d in directories if os.path.isdir(d)] + rundir = directories[-1] + logging.info('rundir: %s', rundir) + fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir) + + if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda(): + self.assertIn(b'custom', fdo_profile) + + logging.info('fdo_profile: %s', fdo_profile) + # Test pass fdo_profile as compiler_options API works. + f_lowered.compile(compiler_options={'fdo_profile': fdo_profile}) if __name__ == '__main__': From 1458d3dd562c6c1004c9ac1162de731fed91ec68 Mon Sep 17 00:00:00 2001 From: nireekshak Date: Tue, 19 Nov 2024 15:04:55 +0000 Subject: [PATCH 047/112] Fix some typos --- docs/Custom_Operation_for_GPUs.md | 6 +++--- docs/advanced-autodiff.md | 5 ++--- docs/autodidax.ipynb | 2 +- docs/autodidax.md | 2 +- docs/autodidax.py | 2 +- docs/notebooks/shard_map.ipynb | 2 +- docs/notebooks/shard_map.md | 2 +- 7 files changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index f4b61cbcf7dc..2163272e2542 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -623,16 +623,16 @@ be used with the custom_partitioning registration and for the gradient. (And if you implement the interface to support vmat, it will also be on the outer primitive). -JAX custom_partitioning implementation are callbacks from XLA to Python during XLA sharding logic. +JAX custom_partitioning implementations are callbacks from XLA to Python during XLA sharding logic. XLA sharding goes in two phases: a sharding propagation phase and a partition phase. -The propagation phase is when XLA plan the sharding to be created. It is the partition phase that create the sharded graph. +The propagation phase is when XLA plan the sharding to be created. It is the partition phase that creates the sharded graph. For XLA to be able to shard our custom operations, it needs us to define 2 extra functions: infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively. The infer_sharding_from_operands() function must do what its name say: infer the output sharding from the input sharding. The partition() function will do a few things: -- tell which input sharding will be expected. XLA will reshad if needed. +- tell which input sharding will be expected. XLA will reshard if needed. - tell the final version of the output sharding. - give a function that will create the new instruction from the sharded inputs. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index 023dc8040954..c56e82c77450 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -350,7 +350,7 @@ This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \math and so on. -To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. +To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. ## How it's made: Two foundational autodiff functions @@ -475,7 +475,7 @@ where we use `CT a` to denote the type for the cotangent space for `a`. In words This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. -There's a cost, though: though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). +There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/). @@ -1762,7 +1762,6 @@ print(grad(app, 1)(lambda x: x ** 2, 4.)) Refer to `fixed_point` above for another usage example. **You don't need to use** `nondiff_argnums` **with array-valued arguments**, such as, for example, ones with the integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments. -s ## Next steps diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 8b418b16f878..e620967de4b7 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2797,7 +2797,7 @@ "representing unknown outputs, we need avals, which we get from the abstract\n", "eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n", "`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n", - "weakrefs.)\n", + "`weakref`s.)\n", "\n", "That `process_primitive` logic applies to most primitives, but `xla_call_p`\n", "requires recursive treatment. So we special-case its rule in a\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 9e726e5ed82e..1c16db80f608 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -2195,7 +2195,7 @@ output. If instead any input is unknown then we instead stage out into a representing unknown outputs, we need avals, which we get from the abstract eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using -weakrefs.) +`weakref`s.) That `process_primitive` logic applies to most primitives, but `xla_call_p` requires recursive treatment. So we special-case its rule in a diff --git a/docs/autodidax.py b/docs/autodidax.py index f57af2cd96f2..f74617f31416 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -2187,7 +2187,7 @@ def full_lower(self): # representing unknown outputs, we need avals, which we get from the abstract # eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and # `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using -# weakrefs.) +# `weakref`s.) # # That `process_primitive` logic applies to most primitives, but `xla_call_p` # requires recursive treatment. So we special-case its rule in a diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 37c27ce2728a..d73b0d4c0f3e 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -864,7 +864,7 @@ "Indeed, this implementation is often used on both TPU and GPU!\n", "\n", "The reason `psum_scatter` can require about half the communication as a full\n", - "`psum` is illustrated the `ppermute` section.\n", + "`psum` is illustrated in the `ppermute` section.\n", "\n", "Another intuition is that we can use `psum_scatter` to implement a distributed\n", "matrix multiplication with inputs and outputs sharded over the same axis. In\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 47b11079e27d..c52cf0e6d22b 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -627,7 +627,7 @@ def psum(x, axis_name): Indeed, this implementation is often used on both TPU and GPU! The reason `psum_scatter` can require about half the communication as a full -`psum` is illustrated the `ppermute` section. +`psum` is illustrated in the `ppermute` section. Another intuition is that we can use `psum_scatter` to implement a distributed matrix multiplication with inputs and outputs sharded over the same axis. In From d912034cb5e6c5584255621b34f958f2846d1d11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 19 Nov 2024 16:42:19 +0100 Subject: [PATCH 048/112] fix(docs): typos in macro name chore(docs): sync .md file --- docs/ffi.ipynb | 4 ++-- docs/ffi.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index f1a699b5c56c..72a2a6914fc0 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -139,8 +139,8 @@ "}\n", "\n", "// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare\n", - "// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`\n", - "// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.\n", + "// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`\n", + "// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.\n", "XLA_FFI_DEFINE_HANDLER_SYMBOL(\n", " RmsNorm, RmsNormImpl,\n", " ffi::Ffi::Bind()\n", diff --git a/docs/ffi.md b/docs/ffi.md index dbe901237ed4..96b627675004 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -134,8 +134,8 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x, } // Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare -// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL` -// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`. +// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL` +// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`. XLA_FFI_DEFINE_HANDLER_SYMBOL( RmsNorm, RmsNormImpl, ffi::Ffi::Bind() From 3556a8333443228d341245ee59278c7c93e22238 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 19 Nov 2024 09:52:15 -0800 Subject: [PATCH 049/112] Add missing version guard in GPU tests for jnp.poly. jaxlib v0.4.35 is required for running `jnp.linalg.eig` on GPU which is required for `poly`. PiperOrigin-RevId: 698052642 --- tests/lax_numpy_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7aad5634775d..ef80e368c9c7 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,6 +51,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal +from jax._src.lib import version as jaxlib_version from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace config.parse_flags_with_absl() @@ -1494,6 +1495,8 @@ def testPoly(self, a_shape, dtype, rank): self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.") + if rank == 2 and jaxlib_version <= (0, 4, 35) and jtu.test_device_matches(["gpu"]): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } if jtu.test_device_matches(["tpu"]): From 6c31efa3f324a810461389f728ab848abffd767f Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 19 Nov 2024 10:32:28 -0800 Subject: [PATCH 050/112] [Mosaic TPU] Add general tpu.vector_store and support masked store. This cl introduces a general store op called tpu.vector_stores which aims to unify vector::store, tpu::strided_load, vector::masked_store. The tpu.vector_stores should also provide general interface for lowering for both TensorCore and SparseCore. This cl also adds the support for (dynamic) masked store. PiperOrigin-RevId: 698067741 --- jaxlib/mosaic/dialect/tpu/tpu.td | 16 +++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 27 ++++++- .../tpu/transforms/apply_vector_layout.cc | 71 ++++++++++++++----- .../tpu/transforms/infer_vector_layout.cc | 15 +++- 4 files changed, 107 insertions(+), 22 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index b312bca7a7d3..4fd960063dc4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -214,6 +214,22 @@ def TPU_LoadOp : TPU_Op<"load"> { }]; } +// TODO(jevinjiang): migrate tpu.strided_store to general vector store op. +def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { + let arguments = (ins + AnyVector:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask // Elementwise mask. + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) + }]; + let hasVerifier = 1; +} + def TPU_StridedLoadOp : TPU_Op<"strided_load"> { let arguments = (ins AnyMemRef:$base, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 6f690f6a0fcb..96b78c8caf37 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -440,6 +440,31 @@ LogicalResult StridedStoreOp::verify() { getValueToStore().getType()); } +LogicalResult VectorStoreOp::verify() { + if (!getStrides().empty()) { + return emitError("Not implemented: general vector store with strides."); + } + VectorType value_ty = getValueToStore().getType(); + MemRefType ref_ty = getBase().getType(); + + if (value_ty.getElementType() != ref_ty.getElementType()) { + return emitOpError( + "Expected base and valueToStore element type should match"); + } + if (llvm::size(getIndices()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " indices"; + } + if (getMask()) { + if (value_ty.getElementTypeBitWidth() != 32) { + return emitError( + "Not implemented: masked store with non-32-bit element type"); + } + if (value_ty.getShape() != getMask().getType().getShape()) + return emitOpError("Expected valueToStore shape to match mask shape"); + } + return success(); +} + LogicalResult ReinterpretCastOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); @@ -468,7 +493,7 @@ LogicalResult verifyRotateOp(Op op) { } if (op.getStride().has_value() != op.getStrideDimension().has_value()) { op.emitOpError( - "Expected either none or both stride and stride dimension are " + "Expected either none or both stride and stride dimension are " "present"); return failure(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 4a344fa9d427..8ade7450881a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4200,18 +4200,15 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, shape_cast_op->erase(); return success(); } -LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - MLIRContext *const mlir_ctx = op.getContext(); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), - [&](const Layout &l) { return l.has_value(); })); + +template +LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, + const VectorLayout &to_store_layout, + TypedValue store_mask = nullptr) { + Operation &op = *(store_op.getOperation()); + MLIRContext *const mlir_ctx = store_op.getContext(); ImplicitLocOpBuilder builder(op.getLoc(), &op); - vector::StoreOp store_op = cast(op); const VectorType ty = store_op.getValueToStore().getType(); - const VectorLayout &to_store_layout = *layouts_in.front(); const auto memref_ty = getMemRefType(store_op.getBase()); if (!ty.getRank()) { return op.emitOpError("Not implemented: scalar stores to vmem"); @@ -4308,10 +4305,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, } else { // Convert dynamic store to dynamic slice + static store. This saves us a // bunch of scalar core work. - auto slice_result = - sliceRef(builder, store_op.getBase(), - store_op.getVectorType().getShape(), store_op.getIndices(), - ArrayRef(memref_tiling).take_back(tiled_dims)); + auto slice_result = sliceRef( + builder, store_op.getBase(), ty.getShape(), store_op.getIndices(), + ArrayRef(memref_tiling).take_back(tiled_dims)); if (failed(slice_result)) { return failure(); } @@ -4332,6 +4328,13 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, xla::Array tiles, disassemble(builder, to_store_layout, store_op.getValueToStore(), ctx.target_shape)); + std::optional> tile_masks; + if (store_mask) { + FAILUREOR_ASSIGN_OR_RETURN( + tile_masks, + disassemble(builder, to_store_layout, store_mask, ctx.target_shape)); + TPU_ASSERT_EQ_OP(tile_masks->dimensions(), tiles.dimensions()); + } const int64_t ndims = ty.getRank(); const auto base_s = is_1d ? IdxConst(0, builder, op.getLoc()) : tile_base_idxs.front(); @@ -4353,6 +4356,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, const absl::Status status = tiles.EachStatus([&](const absl::Span idx, const Value tile) -> absl::Status { + const auto tile_mask = store_mask ? (*tile_masks)(idx) : nullptr; const std::unique_ptr bounds = to_store_layout.tileDataBounds(mlir_ctx, stored_shape, toArrayRef(idx), ctx.target_shape); @@ -4412,19 +4416,19 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, updated = builder.create(mask, tile, data); } builder.create( - updated, base_addr, indices, sublane_mask, - /*mask=*/nullptr, + updated, base_addr, indices, sublane_mask, tile_mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } else { builder.create( tile, base_addr, indices, sublane_mask, - /*mask=*/mask, + tile_mask + ? builder.create(mask, tile_mask).getResult() + : mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } } else { builder.create( - tile, base_addr, indices, sublane_mask, - /*mask=*/nullptr, + tile, base_addr, indices, sublane_mask, tile_mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } return absl::OkStatus(); @@ -4434,7 +4438,35 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, } store_op->erase(); return success(); +} + +LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto store_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), + [&](const Layout &l) { return l.has_value(); })); + return vector_store_impl(ctx, store_op, *layouts_in.front()); +} + +LogicalResult tpu_vector_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto store_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + TPU_ASSERT_OP(layouts_in.front().has_value()); + auto other_layouts_in = layouts_in.drop_front(); + if (store_op.getMask()) { + TPU_ASSERT_EQ_OP(layouts_in.front(), layouts_in.back()); + other_layouts_in = other_layouts_in.drop_back(); } + TPU_ASSERT_OP(llvm::none_of(other_layouts_in, + [&](const Layout &l) { return l.has_value(); })); + return vector_store_impl(ctx, store_op, *layouts_in.front(), + store_op.getMask()); +} LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, @@ -4648,6 +4680,7 @@ const llvm::StringMap &rules() { {tpu::StoreOp::getOperationName(), tpu_store_rule}, {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, + {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, {tpu::RegionOp::getOperationName(), tpu_region_rule}, {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 30486b6e995c..d84e4b883172 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -321,8 +321,14 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op)) { + if (inferStore(op, + /*has_mask=*/op.getMask() != nullptr) + .failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { + if (inferStore(op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -1540,7 +1546,8 @@ class VectorLayoutInferer { return failure(); } - LogicalResult infer(vector::StoreOp op) { + template + LogicalResult inferStore(Op op, bool has_mask = false) { auto ref_ty = getMemRefType(op.getBase()); auto store_ty = op.getValueToStore().getType(); TPU_CHECK_OP(ref_ty.getRank() == store_ty.getRank(), @@ -1648,6 +1655,10 @@ class VectorLayoutInferer { } SmallVector in_layout{store_layout}; in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout); + if (has_mask) { + // Mask layout should be the same as the layout of value to store. + in_layout.push_back(store_layout); + } setInLayout(op, in_layout); return success(); } From c44f11d15e60ccb27d9c21a13a5e789ebded7713 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 19 Nov 2024 11:25:51 -0800 Subject: [PATCH 051/112] Add alternate implementation of threefry as a pallas kernel. Current restrictions: 1) Dynamic grid sizes are not supported yet. This could in theory allow us to not recompile the kernel for different shapes. 2) fold_in and split still use the original rules. But there isn't a huge benefit to using the kernel right now since the input is so small and we can't avoid re-compilation due to (1). 3) Currently doesn't support high bits on the counter, meaning we can generate at max 4B numbers in one call. This is a fringe use-case since we only support 32-bit, and generating 4B 32-bit numbers would consume 16GB of HBM (an entire TPU v5p worth of HBM). PiperOrigin-RevId: 698086352 --- .../pallas/ops/tpu/random/threefry.py | 156 ++++++++++++++++++ tests/pallas/BUILD | 4 + tests/pallas/tpu_pallas_random_test.py | 51 ++++++ 3 files changed, 211 insertions(+) create mode 100644 jax/experimental/pallas/ops/tpu/random/threefry.py diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py new file mode 100644 index 000000000000..d1e6bf1fd93d --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/random/threefry.py @@ -0,0 +1,156 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the Threefry PRNG as a Pallas kernel.""" +from typing import Sequence +import jax +from jax import lax +from jax._src import prng +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + +Shape = Sequence[int] + +BLOCK_SIZE = (256, 256) + +_round_up = lambda x, y: (x + y - 1) // y * y + + +def blocked_iota(block_shape: Shape, + total_shape: Shape): + """Computes a sub-block of a larger shaped iota. + + Args: + block_shape: The output block shape of the iota. + total_shape: The total shape of the input tensor. + Returns: + Result of the blocked iota. + """ + iota_data = jnp.zeros(block_shape, dtype=jnp.uint32) + multiplier = 1 + for dim in range(len(block_shape)-1, -1, -1): + block_mult = 1 + counts_lo = lax.broadcasted_iota( + dtype=jnp.uint32, shape=block_shape, dimension=dim + ) + iota_data += counts_lo * multiplier * block_mult + multiplier *= total_shape[dim] + return iota_data + + +def _compute_scalar_offset(iteration_index, + total_size: Shape, + block_size: Shape): + ndims = len(iteration_index) + dim_size = 1 + total_idx = 0 + for i in range(ndims-1, -1, -1): + dim_idx = iteration_index[i] * block_size[i] + total_idx += dim_idx * dim_size + dim_size *= total_size[i] + return total_idx + + +def threefry_2x32_count(key, + shape: Shape, + unpadded_shape: Shape, + block_size: tuple[int, int]): + """Generates random bits using the Threefry hash function. + + This function is a fusion of prng.shaped_iota and prng.threefry_2x32 from + the JAX core library. + + Args: + key: A threefry key of shape (2,). + shape: The shape of the output. Must be divisible by `block_size`. + unpadded_shape: If `shape` is padded, then this is the shape of the + output tensor if it were not padded. This is important for indexing + calculations within the kernel. If `shape` is not padded, then this + should be equal to `shape`. + block_size: The block size of the kernel. + + Returns: + A tensor of random bits of shape `shape`. + """ + shape = tuple(shape) + if np.prod(shape) > jnp.iinfo(jnp.uint32).max: + raise ValueError( + f"Shape too large: {np.prod(shape)} > {np.iinfo(jnp.uint32).max}") + + if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0): + raise ValueError( + f"Shape dimension {shape[-2:]} must be divisible by {block_size}") + grid_dims = shape[:-2] + ( + shape[-2] // block_size[-2], shape[-1] // block_size[1],) + + def kernel(key_ref, out_ref): + counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims))) + offset = _compute_scalar_offset(counts_idx, unpadded_shape, block_shape) + counts_lo = blocked_iota(block_size, unpadded_shape) + counts_lo = counts_lo + offset + counts_lo = counts_lo.astype(jnp.uint32) + # TODO(justinfu): Support hi bits on count. + counts_hi = jnp.zeros_like(counts_lo) + k1 = jnp.reshape(key_ref[0, 0], (1, 1)) + k2 = jnp.reshape(key_ref[0, 1], (1, 1)) + o1, o2 = prng.threefry2x32_p.bind( + k1, k2, counts_hi, counts_lo) + out_bits = o1 ^ o2 + out_ref[...] = out_bits.reshape(out_ref.shape) + + key = key.reshape((1, 2)) + out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32) + block_shape = (1,) * (len(shape)-2) + block_size + result = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs), + grid=grid_dims, + out_shape=out, + )(key) + return result + +def plthreefry_random_bits(key, bit_width: int, shape: Shape): + if bit_width != 32: + raise ValueError("Only 32-bit PRNG supported.") + if len(shape) == 0: + return plthreefry_random_bits(key, bit_width, (1, 1))[0, 0] + elif len(shape) == 1: + return plthreefry_random_bits(key, bit_width, (1, *shape))[0] + + requires_pad = ( + shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0) + if requires_pad: + padded_shape = tuple(shape[:-2]) + ( + _round_up(shape[-2], BLOCK_SIZE[-2]), + _round_up(shape[-1], BLOCK_SIZE[-1]), + ) + padded_result = threefry_2x32_count( + key, padded_shape, shape, block_size=BLOCK_SIZE) + return padded_result[..., :shape[-2], :shape[-1]] + else: + return threefry_2x32_count(key, shape, shape, block_size=BLOCK_SIZE) + + +plthreefry_prng_impl = prng.PRNGImpl( + key_shape=(2,), + seed=prng.threefry_seed, + split=prng.threefry_split, + random_bits=plthreefry_random_bits, + fold_in=prng.threefry_fold_in, + name="pallas_threefry2x32", + tag="plfry") + +prng.register_prng(plthreefry_prng_impl) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 92cab875df7d..50c1054ba9fd 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -394,9 +394,13 @@ jax_multiplatform_test( "tpu_pallas_random_test.py", ], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5p_2x2", + ], deps = [ "//jax:pallas", "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", "//jax/_src/pallas/mosaic:random", ] + py_deps("absl/testing") + py_deps("numpy"), ) diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index 2b5c315263c9..88c33a020ce9 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -20,10 +20,14 @@ from jax._src import test_util as jtu from jax._src.pallas.mosaic import random as plrandom from jax.experimental import pallas as pl +from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.random import threefry # pylint: disable=unused-import # noqa: F401 import jax.numpy as jnp import numpy as np +P = jax.sharding.PartitionSpec + jax.config.parse_flags_with_absl() @@ -253,6 +257,53 @@ def body(key_ref, o_ref): ) np.testing.assert_array_equal(result, jax_result) + @parameterized.parameters( + ((512, 512),), + ((137, 275),), # Non block-aligned shape + ((4, 512, 512),), # Greater than 2D shape + ((34,),), # 1D + (tuple(),), # 0D + ) + def test_threefry_kernel_matches_jax_threefry(self, shape): + with jax.threefry_partitionable(True): + key_jax = jax_random.key(0, impl="threefry2x32") + jax_gen = jax_random.bits(key_jax, shape=shape) + key_pl = jax_random.key(0, impl="pallas_threefry2x32") + pl_gen = jax_random.bits(key_pl, shape=shape) + + np.testing.assert_array_equal(jax_gen, pl_gen) + + @parameterized.parameters( + ((256, 256),), + ((35, 113),), # Non block-aligned shape + ((331,),), # 1D + ) + def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): + if jax.device_count() < 2: + self.skipTest("Need at least 2 devices") + num_devices = jax.device_count() + partition = P("x") + mesh = jax.make_mesh((num_devices,), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + with jax.threefry_partitionable(True): + key_jax = jax_random.split( + jax_random.key(0, impl="threefry2x32"), num_devices) + key_pallas = jax_random.split( + jax_random.key(0, impl="pallas_threefry2x32"), num_devices) + key_jax = jax.device_put(key_jax, sharding) + key_pallas = jax.device_put(key_pallas, sharding) + generate = shard_map.shard_map( + lambda x: jax_random.bits(x[0], shape=shape), + mesh=mesh, + in_specs=partition, + out_specs=partition, + ) + jax_gen = generate(key_jax) + pl_gen = generate(key_pallas) + + np.testing.assert_array_equal(jax_gen, pl_gen) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From a59bbb7cd721cc146a499e9ef37577e94a7357fb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 19 Nov 2024 11:59:59 -0800 Subject: [PATCH 052/112] Add test utility for accessing jaxlib version tuple. We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests. PiperOrigin-RevId: 698097487 --- jax/_src/test_util.py | 5 +++++ tests/compilation_cache_test.py | 5 ++--- tests/linalg_test.py | 13 ++++++------- tests/magma_linalg_test.py | 7 +++---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 527d7a46ed13..c5a713743fb8 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -44,6 +44,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes as _dtypes +from jax._src import lib as _jaxlib from jax._src import linear_util as lu from jax._src import monitoring from jax._src import pjit as pjit_lib @@ -451,6 +452,10 @@ def assert_num_jit_and_pmap_compilations(times): f"but executed {count[0]}") +def jaxlib_version() -> tuple[int, ...]: + return _jaxlib.version + + def device_under_test(): return _TEST_DUT.value or xla_bridge.get_backend().platform diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index d10558afbe16..0f949aaf1490 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -41,7 +41,6 @@ from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface from jax._src.lib import xla_client as xc -from jax._src.lib import version as jaxlib_version from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -538,7 +537,7 @@ def test_backend_serialization_deserialization(self): executable.fingerprint, deserialized_executable.fingerprint) def test_persistent_cache_enable_xla_caches(self): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("Test requires AutotuneCacheMode bindings") with config.compilation_cache_dir("jax-cache"): with config.persistent_cache_enable_xla_caches("none"): @@ -609,7 +608,7 @@ def test_tasks_disable_cache_metric(self): self.assertEqual(count_after_second_use, count_after_first_use) def test_persistent_cache_enable_xla_caches_disabled(self): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("Test requires AutotuneCacheMode bindings") with config.enable_compilation_cache(False): compile_options = compiler.get_compile_options( diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d0b109dda07e..7c135b4ffeca 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -34,7 +34,6 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -254,7 +253,7 @@ def testIssue1213(self): @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] @@ -298,7 +297,7 @@ def check_left_eigenvectors(a, w, vl): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( @@ -317,7 +316,7 @@ def testEigvalsGrad(self, shape, dtype): # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -332,7 +331,7 @@ def testEigvalsGrad(self, shape, dtype): ) @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -344,7 +343,7 @@ def testEigvals(self, shape, dtype): @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -355,7 +354,7 @@ def testEigvalsInf(self): ) @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py index d2abb9fe3a0b..bf9c0fb6b51d 100644 --- a/tests/magma_linalg_test.py +++ b/tests/magma_linalg_test.py @@ -24,7 +24,6 @@ from jax._src import test_util as jtu from jax._src.lax import linalg as lax_linalg from jax._src.lib import gpu_solver -from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -43,7 +42,7 @@ class MagmaLinalgTest(jtu.JaxTestCase): @jtu.run_on_devices("gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") @@ -94,7 +93,7 @@ def check_left_eigenvectors(a, w, vl): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") @@ -111,7 +110,7 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, self.assertTrue(np.all(np.isnan(result))) def testEigMagmaConfig(self): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") From 2c80d1af50ed580d2fb34bb45a471cce11679d99 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Nov 2024 10:57:07 -0500 Subject: [PATCH 053/112] Add a new API jax.lax.split. This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently. Before: ``` In [1]: import jax.numpy as jnp, jax In [2]: x = jnp.ones((3,)) In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr Out[3]: { lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let f:f32[5,3] = pjit[ name=unstack jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let l:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] k m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0 n:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] j o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0 p:f32[5,3] = add_any m o q:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] i r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0 s:f32[5,3] = add_any p r t:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] h u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0 v:f32[5,3] = add_any s u w:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] g x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0 y:f32[5,3] = add_any v x in (y,) } ] a b c d e in (f,) } ``` Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents. After: ``` In [1]: import jax.numpy as jnp, jax In [2]: x = jnp.ones((3,)) In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr Out[3]: { lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let f:f32[5,3] = pjit[ name=unstack jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let l:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] k m:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] j n:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] i o:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] h p:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] g q:f32[5,3] = concatenate[dimension=0] p o n m l in (q,) } ] a b c d e in (f,) } ``` --- CHANGELOG.md | 3 + docs/jax.lax.rst | 1 + jax/_src/lax/lax.py | 96 ++++++++++++++++++++++++++---- jax/_src/numpy/array_methods.py | 3 +- jax/_src/numpy/lax_numpy.py | 31 +++++----- jax/_src/pallas/mosaic/lowering.py | 21 +++++++ jax/experimental/jax2tf/jax2tf.py | 6 ++ jax/experimental/jet.py | 1 + jax/lax/__init__.py | 2 + tests/lax_autodiff_test.py | 18 ++++++ tests/lax_test.py | 27 +++++++++ tests/lax_vmap_test.py | 18 ++++++ 12 files changed, 197 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9082399c8695..a0901e87ccfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now supported on GPU. See {jax-issue}`#24663` for more details. + * Added {func}`jax.lax.split`. This is a primitive version of + {func}`jax.numpy.split`, added because it yields a more compact + transpose in automatic differentiation. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 065127718c54..d8a28bc399c8 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -154,6 +154,7 @@ Operators slice_in_dim sort sort_key_val + split sqrt square squeeze diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ff9ac0a49578..e97427445aef 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -654,6 +654,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: return concatenate_p.bind(*operands, dimension=dimension) +def split(operand: ArrayLike, sizes: Sequence[int], + axis: int = 0) -> Sequence[Array]: + """Splits an array along ``axis``. + + Args: + operand: an array to split + sizes: the sizes of the split arrays. The sum of the sizes must be equal + to the size of the ``axis`` dimension of ``operand``. + axis: the axis along which to split the array. + + Returns: + A sequence of ``len(sizes)`` arrays. If ``sizes`` is + ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``, + taken along ``axis``. + """ + operand = asarray(operand) + return split_p.bind(operand, sizes=tuple(sizes), + axis=canonicalize_axis(axis, operand.ndim)) + + _precision_strings: dict[Any, Precision] = {} class Precision(enum.Enum): @@ -4373,18 +4393,8 @@ def _concatenate_transpose_rule(t, *operands, dimension): return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None for o in operands] else: - limit_points = np.cumsum( - [shape[dimension] for shape in operand_shapes]).tolist() - starts = np.zeros((len(operands), t.ndim), dtype=int).tolist() - limits = np.tile(t.shape, (len(operands), 1)).tolist() - - for i, s in enumerate(starts[1:]): - s[dimension] = limit_points[:-1][i] - for i, l in enumerate(limits): - l[dimension] = limit_points[i] - - return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o) - else None for o, start, limit in zip(operands, starts, limits)] + return split(t, tuple(shape[dimension] for shape in operand_shapes), + axis=dimension) def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims) @@ -4413,6 +4423,68 @@ def _concatenate_lower(ctx, *xs, dimension): mlir.register_lowering(concatenate_p, _concatenate_lower) +def _split_shape_rule(operand, *, sizes, axis): + offset = 0 + shapes = [] + shape = list(operand.shape) + if any(s < 0 for s in sizes): + raise ValueError( + f"Sizes passed to split must be nonnegative, got {list(sizes)}") + if operand.shape[axis] != np.sum(sizes): + raise ValueError( + f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the " + f"operand shape {list(operand.shape)}") + for size in sizes: + shape[axis] = size + shapes.append(tuple(shape)) + return shapes + +def _split_dtype_rule(operand, *, sizes, axis): + return (operand.dtype,) * len(sizes) + +def _split_weak_type_rule(operand, *, sizes, axis): + return (operand.weak_type,) * len(sizes) + +def _split_transpose_rule(cotangents, operand, *, sizes, axis): + assert ad.is_undefined_primal(operand) + if all(type(t) is ad_util.Zero for t in cotangents): + return ad_util.Zero(operand.aval), + cotangents = [ + _zeros(t.aval) if type(t) is ad_util.Zero else t + for t in cotangents + ] + return concatenate(cotangents, dimension=axis), + +def _split_batch_rule(batched_args, batch_dims, *, sizes, axis): + operand, = batched_args + bdim, = batch_dims + new_bdims = (bdim,) * len(sizes) + out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis) + return out, new_bdims + +def _split_lower(ctx, x, *, sizes, axis): + x_aval, = ctx.avals_in + start_indices = [0] * x_aval.ndim + limit_indices = list(x_aval.shape) + strides = (1,) * x_aval.ndim + outs = [] + for aval_out in ctx.avals_out: + limit_indices[axis] = start_indices[axis] + aval_out.shape[axis] + outs.append(mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, + limit_indices=limit_indices, strides=strides)) + start_indices[axis] = limit_indices[axis] + return outs + +split_p = core.Primitive('split') +split_p.multiple_results = True +split_p.def_abstract_eval( + partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, + _split_dtype_rule, _split_weak_type_rule)) +split_p.def_impl(partial(dispatch.apply_primitive, split_p)) +ad.deflinear2(split_p, _split_transpose_rule) +batching.primitive_batchers[split_p] = _split_batch_rule +mlir.register_lowering(split_p, _split_lower) + def _pad_dtype_rule(operand, padding_value, *, padding_config): if operand.dtype != padding_value.dtype: msg = "pad operand and padding_value must be same dtype: got {} and {}." diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 4768a8126c72..617213ca03de 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -629,7 +629,8 @@ def _multi_slice(self: Array, # avoid circular imports. @jax.jit def _unstack(x: Array) -> list[Array]: - return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] + dims = (0,) + return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] def _chunk_iter(x, size): if size > x.shape[0]: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 898e4255dd8e..d256c97a9957 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, + ceil_of_ratio, partition_list, safe_zip, set_module, unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) @@ -3280,10 +3280,10 @@ def _split(op: str, ary: ArrayLike, if (isinstance(indices_or_sections, (tuple, list)) or isinstance(indices_or_sections, (np.ndarray, Array)) and indices_or_sections.ndim > 0): - indices_or_sections = [ + split_indices = np.asarray([0] + [ core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1") - for i_s in indices_or_sections] - split_indices = [0] + list(indices_or_sections) + [size] + for i_s in indices_or_sections] + [size]) + sizes = list(np.diff(split_indices)) else: if core.is_symbolic_dim(indices_or_sections): raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is " @@ -3292,21 +3292,14 @@ def _split(op: str, ary: ArrayLike, f"in jax.numpy.{op} argument 1") part_size, r = divmod(size, num_sections) if r == 0: - split_indices = [i * part_size - for i in range(num_sections + 1)] + sizes = [part_size] * num_sections elif op == "array_split": - split_indices = ( - [i * (part_size + 1) for i in range(r + 1)] + - [i * part_size + ((r + 1) * (part_size + 1) - 1) - for i in range(num_sections - r)]) + sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r) else: raise ValueError(f"array split does not result in an equal division: rest is {r}") - split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] - for i in split_indices] - starts, ends = [0] * ndim(ary), shape(ary) - _subval = lambda x, i, v: subvals(x, [(i, v)]) - return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) - for start, end in zip(split_indices[:-1], split_indices[1:])] + sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] + for i in sizes] + return list(lax.split(ary, sizes, axis=axis)) @export @@ -4669,7 +4662,11 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: "Unstack requires arrays with rank > 0, however a scalar array was " "passed." ) - return tuple(moveaxis(x, axis, 0)) + dimensions = (axis,) + return tuple( + lax.squeeze(t, dimensions) + for t in lax.split(x, (1,) * x.shape[axis], axis=axis) + ) @export diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index be4102dff716..f0286c156e45 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1871,6 +1871,27 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule +def _split_lowering_rule( + ctx: LoweringRuleContext, x, *, sizes, axis +): + (x_aval,) = ctx.avals_in + slice_size = np.array(x_aval.shape, dtype=np.int64) + starts = np.zeros_like(slice_size) + strides = np.ones_like(slice_size) + outs = [] + for size, aval_out in zip(sizes, ctx.avals_out): + slice_size[axis] = size + outs.append( + vector.extract_strided_slice( + aval_to_ir_type(aval_out), x, starts, slice_size, strides + ) + ) + starts[axis] += size + return outs + +lowering_rules[lax.split_p] = _split_lowering_rule + + def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): out_type = aval_to_ir_type(ctx.avals_out[0]) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c41eda693d7f..2cc670ef6a43 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2087,6 +2087,12 @@ def _concatenate(*operands, dimension): tf_impl[lax.concatenate_p] = _concatenate +def _split(operand, *, sizes, axis): + return tf.split(operand, sizes, axis=axis) + +tf_impl[lax.split_p] = _split + + def _conv_general_dimension_numbers_proto(dimension_numbers): """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.""" assert isinstance(dimension_numbers, lax.ConvDimensionNumbers) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 2681ad1a2a7b..29ec21319361 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -323,6 +323,7 @@ def linear_prop(prim, primals_in, series_in, **params): deflinear(lax.convert_element_type_p) deflinear(lax.broadcast_in_dim_p) deflinear(lax.concatenate_p) +deflinear(lax.split_p) deflinear(lax.pad_p) deflinear(lax.reshape_p) deflinear(lax.squeeze_p) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index d569ed641138..dc9c69d97795 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -203,6 +203,8 @@ sort as sort, sort_key_val as sort_key_val, sort_p as sort_p, + split as split, + split_p as split_p, sqrt as sqrt, sqrt_p as sqrt_p, square as square, diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 78d90cb8a072..c7cbde069cc8 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -273,6 +273,24 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs): concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(base_shape)) + ], + num_pieces=range(3), + dtype=float_dtypes, + ) + def testSplitGrad(self, axis, base_shape, dtype, num_pieces): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + operands = (rng(shape, dtype),) + split = lambda x: lax.split(x, sizes, axis) + check_grads(split, operands, 2, ["fwd", "rev"], eps=1.) + + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides) for lhs_shape, rhs_shape, all_strides in itertools.chain( diff --git a/tests/lax_test.py b/tests/lax_test.py index 78bc5857acb7..48f70baa1e32 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -283,6 +283,33 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs): numpy_op = lambda *args: lax_reference.concatenate(args, dim) self._CheckAgainstNumpy(numpy_op, op, args_maker) + @jtu.sample_product( + [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(shape))], + num_pieces=range(3), + dtype=lax_test_util.default_dtypes, + ) + def testSplit(self, axis, base_shape, dtype, num_pieces): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + op = lambda x: lax.split(x, sizes, axis=axis) + def numpy_op(x): + return np.split(x, np.cumsum(sizes[:-1]), axis=axis) + self._CompileAndCheck(op, args_maker) + self._CheckAgainstNumpy(numpy_op, op, args_maker) + + def testSplitErrors(self): + with self.assertRaisesRegex(ValueError, + "Sizes passed to split must be nonnegative"): + lax.split(np.arange(5), [-1]) + with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"): + lax.split(np.arange(5), [6]) + with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"): + lax.split(np.arange(5), sizes=(), axis=1) + @jtu.sample_product( [ dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 83d4d657751b..49e06e17be15 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -344,6 +344,24 @@ def testSlice(self, shape, dtype, starts, limits, strides, bdims): op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis, bdims=bdims) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(base_shape)) + for bdims in lax_test_util.all_bdims(base_shape) + ], + num_pieces=range(3), + dtype=lax_test_util.default_dtypes, + ) + def testSplit(self, base_shape, dtype, num_pieces, axis, bdims): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + op = lambda x: lax.split(x, sizes, axis) + self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng, + multiple_results=True) + @jtu.sample_product( [dict(shape=shape, perm=perm, bdims=bdims) for shape, perm in [ From 1bf70fbbc42b28c4d929e70bc949347b9b5732ae Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 19 Nov 2024 13:01:53 -0800 Subject: [PATCH 054/112] [pallas:mosaic_gpu] `copy_gmem_to_smem` no longer requires `barrier` to be a keyword argument ... because there really isn't any reason to require that. PiperOrigin-RevId: 698116984 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 1 - .../pallas/ops/gpu/attention_mgpu.py | 10 +++++----- tests/pallas/mosaic_gpu_test.py | 18 +++++++++--------- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 91e1e1c45429..9b6adc86f981 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -63,7 +63,7 @@ def copy_in(self, slot, grid_indices, barrier_ref): gpu_primitives.copy_gmem_to_smem( self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands self.smem_ref.at[slot], - barrier=barrier_ref.at[slot], + barrier_ref.at[slot], ) def copy_out(self, slot, grid_indices, predicate=None): diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 5fc4ed5e7afc..36dcba5d15d0 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -218,7 +218,6 @@ def _copy_gmem_to_smem_lowering( def copy_gmem_to_smem( src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef, - *, barrier: pallas_core.AbstractMemoryRef, ) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 56db5379d5e2..1c5b4d9f741b 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -89,7 +89,7 @@ def _compute_wg(): plgpu.copy_gmem_to_smem( q_ref.at[pl.ds(q_seq_base, block_q), q_head], qo_smem, - barrier=q_barriers.at[wg_idx], + q_barriers.at[wg_idx], ) plgpu.barrier_wait(q_barriers.at[wg_idx]) @@ -166,17 +166,17 @@ def _memory_wg(): kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) for i in range(max_concurrent_steps): s = (pl.ds(i * block_kv, block_kv), kv_head) - plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], barrier=k_barriers.at[i]) - plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], barrier=v_barriers.at[i]) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) s = (pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barrier) - plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], barrier=k_barriers.at[tma_slot]) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barrier) - plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], barrier=v_barriers.at[tma_slot]) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) def kv_epilogue(i, _): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 83202937503d..fe52a33c1637 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -263,7 +263,7 @@ def test_copy_gmem_to_smem(self, indexer): ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.copy_gmem_to_smem( - x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier=barrier_ref + x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier_ref ) plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] + 1 @@ -284,7 +284,7 @@ def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.copy_gmem_to_smem( - x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer] + x_ref_gmem, scratch_ref, barrier_ref.at[indexer] ) plgpu.barrier_wait(barrier_ref.at[indexer]) o_ref[...] = scratch_ref[...] + 1 @@ -296,7 +296,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_copy_with_transforms(self, to_smem): def kernel(x_ref, o_ref, barrier_ref): if to_smem: - plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) else: plgpu.commit_smem() @@ -329,7 +329,7 @@ def test_scoped_copy_with_transforms(self): ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): - plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) o_ref[...] = tmp_ref[...] * 2 pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) @@ -351,7 +351,7 @@ def body(tmp_ref): def test_copy_with_transforms_and_indexing(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): - plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) @@ -379,7 +379,7 @@ def test_indexing_before_transpose(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem( - x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier=barrier_ref + x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier_ref ) plgpu.barrier_wait(barrier_ref) @@ -407,7 +407,7 @@ def test_copy_gmem_to_smem_in_run_scoped(self): def kernel(x_ref_gmem, o_ref): def body(barrier_ref): def inner_body(scratch_ref): - plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] + 1 pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32)) @@ -1092,7 +1092,7 @@ def body(step, _): lambda: plgpu.copy_gmem_to_smem( x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)], x_smem.at[fetch_slot], - barrier=barrier.at[fetch_slot], + barrier.at[fetch_slot], ), lambda: None, ) @@ -1103,7 +1103,7 @@ def body(step, _): plgpu.copy_gmem_to_smem( x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)], x_smem.at[slot], - barrier=barrier.at[slot], + barrier.at[slot], ) jax.lax.fori_loop(0, num_steps, body, ()) From 0d36b0b433a93c707f86dac89b0c05d40302775a Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Tue, 19 Nov 2024 13:39:38 -0800 Subject: [PATCH 055/112] [Mosaic] Add target core type parameter to tpu.sem_signal Adds the optional core type parameter to `tpu.sem_signal` for cross-core signalling. If the target core type is not provided, the target core type is assumed to be that of the core issuing the signal. The issuing core type is determined based on the core type annotation of the parent function; if the annotation is not provided, the issuing core type is assumed to be TensorCore. PiperOrigin-RevId: 698129842 --- jaxlib/mosaic/BUILD | 1 + jaxlib/mosaic/dialect/tpu/tpu.td | 12 ++++++--- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 10 +++++++ jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 5 ++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 34 ++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 14f3ee13c0f5..da7498ed437d 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -62,6 +62,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 4fd960063dc4..55d2e1ec975e 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -653,12 +653,18 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { MemRefOf<[TPU_SemaphoreType]>:$semaphore, I32:$amount, Optional:$device_id, // For remote DMAs - Optional:$core_id // For megacore + Optional:$core_id, // For megacore + OptionalAttr:$core_type ); - let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? attr-dict `:` type($semaphore) +let assemblyFormat = [{ + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; + let builders = [ + // A backward-compatible builder that sets `core_type` to nullptr. + OpBuilder<(ins "Value":$semaphore, "Value":$amount, + "Value":$device_id, "Value":$core_id)>, + ]; } def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 10ab154b7c10..92e8953837e3 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "absl/hash/hash.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" @@ -81,6 +82,15 @@ void TPUDialect::initialize() { return mlir::cast(attr).getValue(); } +FailureOr> GetCoreTypeOfParentFunc(Operation &op) { + mlir::Operation *func_op = op.getParentOfType(); + if (func_op == nullptr) { + return op.emitError() << "Operation " << op.getName() + << " is not inside a func.func"; + } + return TPUDialect::GetCoreTypeAttr(func_op); +} + void VectorLayoutAttr::print(AsmPrinter &printer) const { printer << '<'; printer << getLayout(); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index dbb2ddaa5853..a8569acc6239 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -94,6 +95,10 @@ std::unique_ptr> createDebugAssertInsertionPass(); #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" +// Determine the core type of the given op based on the `tpu.core_type` +// annotation of its parent function. +FailureOr> GetCoreTypeOfParentFunc(Operation &op); + // Changes the memory space of the value and propagates it through the program. LogicalResult specializeMemorySpace(TypedValue value, MemorySpace memory_space); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 96b78c8caf37..b4dcca66f7dc 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -28,9 +28,12 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "mlir/include/mlir/IR/Builders.h" #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/IRMapping.h" +#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -837,11 +840,42 @@ LogicalResult GetBarrierSemaphoreOp::verify() { return success(); } +void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state, + Value semaphore, Value amount, Value device_id, + Value core_id) { + build(builder, state, semaphore, amount, device_id, core_id, + /*core_type=*/nullptr); +} + LogicalResult SemaphoreSignalOp::verify() { auto sem_type = getMemRefType(getSemaphore()); if (sem_type.getRank() != 0) { return emitOpError("Semaphore reference must be rank 0"); } + + FailureOr> issuing_core_type_maybe = + GetCoreTypeOfParentFunc(**this); + if (failed(issuing_core_type_maybe)) { + return issuing_core_type_maybe; + } + CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc); + CoreType target_core_type = getCoreType().value_or(issuing_core_type); + + if (getCoreId() == nullptr && getDeviceId() == nullptr) { + if (target_core_type != issuing_core_type) { + return emitOpError( + absl::StrFormat("Target core type (%s) must match source core type " + "(%s) when device_id and core_id are not specified", + stringifyCoreType(target_core_type), + stringifyCoreType(issuing_core_type))); + } + } + if ((issuing_core_type == CoreType::kTc && + target_core_type == CoreType::kScScalarSubcore) || + (issuing_core_type == CoreType::kScScalarSubcore && + target_core_type == CoreType::kTc)) { + return emitOpError("Signalling between TC and SC is not implemented"); + } return success(); } From 3161a28424995d231b56c38ac43b89c0807d683a Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 19 Nov 2024 14:00:40 -0800 Subject: [PATCH 056/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/229f376e046b9a51039dc1566d1e388ee7c1ca6d. PiperOrigin-RevId: 698136955 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 71fb2a8e9757..99c2af75f3ad 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "082a7014706f67bb8a42fb1c90051bc4990f2fd3" -XLA_SHA256 = "f1ca797df8e95bf13419d20520d2b783f075d80d1c5ddf1506ba427c934de849" +XLA_COMMIT = "229f376e046b9a51039dc1566d1e388ee7c1ca6d" +XLA_SHA256 = "895b39b5cb298460185f29df3ecc8882f4ee151b0f7dc93e5387ef81ea32e374" def repo(): tf_http_archive( From 42fbd301fc7bed57386423722c1a2ddae11f91ec Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Tue, 19 Nov 2024 14:18:32 -0800 Subject: [PATCH 057/112] Move JAX example to public XLA:CPU API PiperOrigin-RevId: 698143471 --- examples/jax_cpp/BUILD | 7 ++++++- examples/jax_cpp/main.cc | 12 ++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index 6e4647b5e491..b3cb995aae21 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -26,8 +26,13 @@ cc_binary( "@tsl//tsl/platform:platform_port", "@xla//xla:literal", "@xla//xla:literal_util", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt/cpu:cpu_client", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/service:hlo_module_config", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 2a8f8d4debba..ceac2cd2d7c9 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -36,15 +36,21 @@ limitations under the License. // } // ) +#include #include #include #include #include "third_party/absl/status/statusor.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" @@ -66,8 +72,10 @@ int main(int argc, char** argv) { // Run it using JAX C++ Runtime (PJRT). // Get a CPU client. + xla::CpuClientOptions options; + options.asynchronous = true; std::unique_ptr client = - xla::GetTfrtCpuClient(/*asynchronous=*/true).value(); + xla::GetXlaPjrtCpuClient(options).value(); // Compile XlaComputation to PjRtExecutable. xla::XlaComputation xla_computation(test_module_proto); From 525b646c0ebd5205f4fa0639c94adb2de47e1cf0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Nov 2024 14:46:44 -0800 Subject: [PATCH 058/112] Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a PiperOrigin-RevId: 698152759 --- CHANGELOG.md | 3 - docs/jax.lax.rst | 1 - jax/_src/lax/lax.py | 96 ++++-------------------------- jax/_src/numpy/array_methods.py | 3 +- jax/_src/numpy/lax_numpy.py | 31 +++++----- jax/_src/pallas/mosaic/lowering.py | 21 ------- jax/experimental/jax2tf/jax2tf.py | 6 -- jax/experimental/jet.py | 1 - jax/lax/__init__.py | 2 - tests/lax_autodiff_test.py | 18 ------ tests/lax_test.py | 27 --------- tests/lax_vmap_test.py | 18 ------ 12 files changed, 30 insertions(+), 197 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0901e87ccfc..9082399c8695 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,9 +59,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now supported on GPU. See {jax-issue}`#24663` for more details. - * Added {func}`jax.lax.split`. This is a primitive version of - {func}`jax.numpy.split`, added because it yields a more compact - transpose in automatic differentiation. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index d8a28bc399c8..065127718c54 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -154,7 +154,6 @@ Operators slice_in_dim sort sort_key_val - split sqrt square squeeze diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e97427445aef..ff9ac0a49578 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -654,26 +654,6 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: return concatenate_p.bind(*operands, dimension=dimension) -def split(operand: ArrayLike, sizes: Sequence[int], - axis: int = 0) -> Sequence[Array]: - """Splits an array along ``axis``. - - Args: - operand: an array to split - sizes: the sizes of the split arrays. The sum of the sizes must be equal - to the size of the ``axis`` dimension of ``operand``. - axis: the axis along which to split the array. - - Returns: - A sequence of ``len(sizes)`` arrays. If ``sizes`` is - ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``, - taken along ``axis``. - """ - operand = asarray(operand) - return split_p.bind(operand, sizes=tuple(sizes), - axis=canonicalize_axis(axis, operand.ndim)) - - _precision_strings: dict[Any, Precision] = {} class Precision(enum.Enum): @@ -4393,8 +4373,18 @@ def _concatenate_transpose_rule(t, *operands, dimension): return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None for o in operands] else: - return split(t, tuple(shape[dimension] for shape in operand_shapes), - axis=dimension) + limit_points = np.cumsum( + [shape[dimension] for shape in operand_shapes]).tolist() + starts = np.zeros((len(operands), t.ndim), dtype=int).tolist() + limits = np.tile(t.shape, (len(operands), 1)).tolist() + + for i, s in enumerate(starts[1:]): + s[dimension] = limit_points[:-1][i] + for i, l in enumerate(limits): + l[dimension] = limit_points[i] + + return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o) + else None for o, start, limit in zip(operands, starts, limits)] def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims) @@ -4423,68 +4413,6 @@ def _concatenate_lower(ctx, *xs, dimension): mlir.register_lowering(concatenate_p, _concatenate_lower) -def _split_shape_rule(operand, *, sizes, axis): - offset = 0 - shapes = [] - shape = list(operand.shape) - if any(s < 0 for s in sizes): - raise ValueError( - f"Sizes passed to split must be nonnegative, got {list(sizes)}") - if operand.shape[axis] != np.sum(sizes): - raise ValueError( - f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the " - f"operand shape {list(operand.shape)}") - for size in sizes: - shape[axis] = size - shapes.append(tuple(shape)) - return shapes - -def _split_dtype_rule(operand, *, sizes, axis): - return (operand.dtype,) * len(sizes) - -def _split_weak_type_rule(operand, *, sizes, axis): - return (operand.weak_type,) * len(sizes) - -def _split_transpose_rule(cotangents, operand, *, sizes, axis): - assert ad.is_undefined_primal(operand) - if all(type(t) is ad_util.Zero for t in cotangents): - return ad_util.Zero(operand.aval), - cotangents = [ - _zeros(t.aval) if type(t) is ad_util.Zero else t - for t in cotangents - ] - return concatenate(cotangents, dimension=axis), - -def _split_batch_rule(batched_args, batch_dims, *, sizes, axis): - operand, = batched_args - bdim, = batch_dims - new_bdims = (bdim,) * len(sizes) - out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis) - return out, new_bdims - -def _split_lower(ctx, x, *, sizes, axis): - x_aval, = ctx.avals_in - start_indices = [0] * x_aval.ndim - limit_indices = list(x_aval.shape) - strides = (1,) * x_aval.ndim - outs = [] - for aval_out in ctx.avals_out: - limit_indices[axis] = start_indices[axis] + aval_out.shape[axis] - outs.append(mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, - limit_indices=limit_indices, strides=strides)) - start_indices[axis] = limit_indices[axis] - return outs - -split_p = core.Primitive('split') -split_p.multiple_results = True -split_p.def_abstract_eval( - partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, - _split_dtype_rule, _split_weak_type_rule)) -split_p.def_impl(partial(dispatch.apply_primitive, split_p)) -ad.deflinear2(split_p, _split_transpose_rule) -batching.primitive_batchers[split_p] = _split_batch_rule -mlir.register_lowering(split_p, _split_lower) - def _pad_dtype_rule(operand, padding_value, *, padding_config): if operand.dtype != padding_value.dtype: msg = "pad operand and padding_value must be same dtype: got {} and {}." diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 617213ca03de..4768a8126c72 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -629,8 +629,7 @@ def _multi_slice(self: Array, # avoid circular imports. @jax.jit def _unstack(x: Array) -> list[Array]: - dims = (0,) - return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] + return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] def _chunk_iter(x, size): if size > x.shape[0]: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d256c97a9957..898e4255dd8e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, set_module, unzip2, + ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) @@ -3280,10 +3280,10 @@ def _split(op: str, ary: ArrayLike, if (isinstance(indices_or_sections, (tuple, list)) or isinstance(indices_or_sections, (np.ndarray, Array)) and indices_or_sections.ndim > 0): - split_indices = np.asarray([0] + [ + indices_or_sections = [ core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1") - for i_s in indices_or_sections] + [size]) - sizes = list(np.diff(split_indices)) + for i_s in indices_or_sections] + split_indices = [0] + list(indices_or_sections) + [size] else: if core.is_symbolic_dim(indices_or_sections): raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is " @@ -3292,14 +3292,21 @@ def _split(op: str, ary: ArrayLike, f"in jax.numpy.{op} argument 1") part_size, r = divmod(size, num_sections) if r == 0: - sizes = [part_size] * num_sections + split_indices = [i * part_size + for i in range(num_sections + 1)] elif op == "array_split": - sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r) + split_indices = ( + [i * (part_size + 1) for i in range(r + 1)] + + [i * part_size + ((r + 1) * (part_size + 1) - 1) + for i in range(num_sections - r)]) else: raise ValueError(f"array split does not result in an equal division: rest is {r}") - sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] - for i in sizes] - return list(lax.split(ary, sizes, axis=axis)) + split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] + for i in split_indices] + starts, ends = [0] * ndim(ary), shape(ary) + _subval = lambda x, i, v: subvals(x, [(i, v)]) + return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) + for start, end in zip(split_indices[:-1], split_indices[1:])] @export @@ -4662,11 +4669,7 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: "Unstack requires arrays with rank > 0, however a scalar array was " "passed." ) - dimensions = (axis,) - return tuple( - lax.squeeze(t, dimensions) - for t in lax.split(x, (1,) * x.shape[axis], axis=axis) - ) + return tuple(moveaxis(x, axis, 0)) @export diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f0286c156e45..be4102dff716 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1871,27 +1871,6 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule -def _split_lowering_rule( - ctx: LoweringRuleContext, x, *, sizes, axis -): - (x_aval,) = ctx.avals_in - slice_size = np.array(x_aval.shape, dtype=np.int64) - starts = np.zeros_like(slice_size) - strides = np.ones_like(slice_size) - outs = [] - for size, aval_out in zip(sizes, ctx.avals_out): - slice_size[axis] = size - outs.append( - vector.extract_strided_slice( - aval_to_ir_type(aval_out), x, starts, slice_size, strides - ) - ) - starts[axis] += size - return outs - -lowering_rules[lax.split_p] = _split_lowering_rule - - def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): out_type = aval_to_ir_type(ctx.avals_out[0]) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 2cc670ef6a43..c41eda693d7f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2087,12 +2087,6 @@ def _concatenate(*operands, dimension): tf_impl[lax.concatenate_p] = _concatenate -def _split(operand, *, sizes, axis): - return tf.split(operand, sizes, axis=axis) - -tf_impl[lax.split_p] = _split - - def _conv_general_dimension_numbers_proto(dimension_numbers): """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.""" assert isinstance(dimension_numbers, lax.ConvDimensionNumbers) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 29ec21319361..2681ad1a2a7b 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -323,7 +323,6 @@ def linear_prop(prim, primals_in, series_in, **params): deflinear(lax.convert_element_type_p) deflinear(lax.broadcast_in_dim_p) deflinear(lax.concatenate_p) -deflinear(lax.split_p) deflinear(lax.pad_p) deflinear(lax.reshape_p) deflinear(lax.squeeze_p) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index dc9c69d97795..d569ed641138 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -203,8 +203,6 @@ sort as sort, sort_key_val as sort_key_val, sort_p as sort_p, - split as split, - split_p as split_p, sqrt as sqrt, sqrt_p as sqrt_p, square as square, diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index c7cbde069cc8..78d90cb8a072 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -273,24 +273,6 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs): concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(len(base_shape)) - ], - num_pieces=range(3), - dtype=float_dtypes, - ) - def testSplitGrad(self, axis, base_shape, dtype, num_pieces): - sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) - shape = list(base_shape) - shape[axis] = np.sum(sizes) - rng = jtu.rand_default(self.rng()) - operands = (rng(shape, dtype),) - split = lambda x: lax.split(x, sizes, axis) - check_grads(split, operands, 2, ["fwd", "rev"], eps=1.) - - @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides) for lhs_shape, rhs_shape, all_strides in itertools.chain( diff --git a/tests/lax_test.py b/tests/lax_test.py index 48f70baa1e32..78bc5857acb7 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -283,33 +283,6 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs): numpy_op = lambda *args: lax_reference.concatenate(args, dim) self._CheckAgainstNumpy(numpy_op, op, args_maker) - @jtu.sample_product( - [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(len(shape))], - num_pieces=range(3), - dtype=lax_test_util.default_dtypes, - ) - def testSplit(self, axis, base_shape, dtype, num_pieces): - sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) - shape = list(base_shape) - shape[axis] = np.sum(sizes) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - op = lambda x: lax.split(x, sizes, axis=axis) - def numpy_op(x): - return np.split(x, np.cumsum(sizes[:-1]), axis=axis) - self._CompileAndCheck(op, args_maker) - self._CheckAgainstNumpy(numpy_op, op, args_maker) - - def testSplitErrors(self): - with self.assertRaisesRegex(ValueError, - "Sizes passed to split must be nonnegative"): - lax.split(np.arange(5), [-1]) - with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"): - lax.split(np.arange(5), [6]) - with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"): - lax.split(np.arange(5), sizes=(), axis=1) - @jtu.sample_product( [ dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 49e06e17be15..83d4d657751b 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -344,24 +344,6 @@ def testSlice(self, shape, dtype, starts, limits, strides, bdims): op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis, bdims=bdims) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(len(base_shape)) - for bdims in lax_test_util.all_bdims(base_shape) - ], - num_pieces=range(3), - dtype=lax_test_util.default_dtypes, - ) - def testSplit(self, base_shape, dtype, num_pieces, axis, bdims): - sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) - shape = list(base_shape) - shape[axis] = np.sum(sizes) - rng = jtu.rand_default(self.rng()) - op = lambda x: lax.split(x, sizes, axis) - self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng, - multiple_results=True) - @jtu.sample_product( [dict(shape=shape, perm=perm, bdims=bdims) for shape, perm in [ From c04aec9d525dd2e767495e41b98e82dd79315f37 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Tue, 19 Nov 2024 15:22:27 -0800 Subject: [PATCH 059/112] [Mosaic] Extend tpu.sem_signal with subcore_id This change: - Bumps up the version of Mosaic to 4 in `serde.cc`. - Adds optional `subcore_id` parameter to `tpu.sem_signal` for signalling specific subcores. - Extends deserialization to correctly parse the older versions of Mosaic without the new parameter `subcore_id` of `tpu.sem_signal`. PiperOrigin-RevId: 698163836 --- jaxlib/mosaic/dialect/tpu/tpu.td | 5 ++- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 36 +++++++++++---- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 44 +++++++++++++------ 3 files changed, 61 insertions(+), 24 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 55d2e1ec975e..590c27ac2099 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -654,14 +654,15 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { I32:$amount, Optional:$device_id, // For remote DMAs Optional:$core_id, // For megacore + Optional:$subcore_id, // For the SC vector subcore OptionalAttr:$core_type ); let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`subcore_id` $subcore_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; let builders = [ - // A backward-compatible builder that sets `core_type` to nullptr. + // A backward-compatible builder that sets `subcore_id` and `core_type` to nullptr. OpBuilder<(ins "Value":$semaphore, "Value":$amount, "Value":$device_id, "Value":$core_id)>, ]; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index b4dcca66f7dc..a103cda7dae2 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -844,7 +844,7 @@ void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state, Value semaphore, Value amount, Value device_id, Value core_id) { build(builder, state, semaphore, amount, device_id, core_id, - /*core_type=*/nullptr); + /*subcore_id=*/nullptr, /*core_type=*/nullptr); } LogicalResult SemaphoreSignalOp::verify() { @@ -861,21 +861,39 @@ LogicalResult SemaphoreSignalOp::verify() { CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc); CoreType target_core_type = getCoreType().value_or(issuing_core_type); - if (getCoreId() == nullptr && getDeviceId() == nullptr) { + if (getCoreId() == nullptr && getDeviceId() == nullptr && + getSubcoreId() == nullptr) { if (target_core_type != issuing_core_type) { - return emitOpError( - absl::StrFormat("Target core type (%s) must match source core type " - "(%s) when device_id and core_id are not specified", - stringifyCoreType(target_core_type), - stringifyCoreType(issuing_core_type))); + return emitOpError(absl::StrFormat( + "Target core type (%s) must match source core type " + "(%s) when device_id, core_id and subcore_id are not specified", + stringifyCoreType(target_core_type), + stringifyCoreType(issuing_core_type))); } } + if (target_core_type == CoreType::kScVectorSubcore && + issuing_core_type != CoreType::kScVectorSubcore && + getSubcoreId() == nullptr) { + return emitOpError( + "Subcore ID must be specified for the SC vector subcore"); + } + if (target_core_type != CoreType::kScVectorSubcore && + getSubcoreId() != nullptr) { + return emitOpError( + "Subcore ID must be specified only for the SC vector subcore"); + } if ((issuing_core_type == CoreType::kTc && - target_core_type == CoreType::kScScalarSubcore) || - (issuing_core_type == CoreType::kScScalarSubcore && + (target_core_type == CoreType::kScScalarSubcore || + target_core_type == CoreType::kScVectorSubcore)) || + ((issuing_core_type == CoreType::kScScalarSubcore || + issuing_core_type == CoreType::kScVectorSubcore) && target_core_type == CoreType::kTc)) { return emitOpError("Signalling between TC and SC is not implemented"); } + if (target_core_type == CoreType::kScVectorSubcore && + (getCoreId() != nullptr || getDeviceId() != nullptr)) { + return emitOpError("Signalling remote SC vector subcores is not supported"); + } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index fd68c9e6c95e..27a886ebeb7e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -15,19 +15,21 @@ limitations under the License. // We need to keep some extra headers for the code in tpu_passes.h.inc. +#include #include // IWYU pragma: keep #include #include #include +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" +#include "absl/strings/str_format.h" #include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" @@ -43,7 +45,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 3; +constexpr int kVersion = 4; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -86,21 +88,37 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) { LogicalResult semaphore_signal_rule(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. + // Added subcore_id in version 4. if (version < 2) { if (op->getNumOperands() == 2) { // Local signal. - op->setAttr(OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0, 0})); } else if (op->getNumOperands() == 3) { // Remote signal. - // Hardcoding that one optional value is device_id, not core_id. This - // could misinterpret sem_signals where core_id is specified, but - // device_id isn't. - op->setAttr(OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); - } else { - return op->emitError("Unexpected operand count in tpu.semaphore_signal"); + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0, 0})); + } + return op->emitError("Unexpected operand count in tpu.semaphore_signal"); + } else if (version < 4) { + ArrayRef operand_segment_sizes = + op->getAttrOfType( + OpTrait::AttrSizedOperandSegments< + SemaphoreSignalOp>::getOperandSegmentSizeAttr()); + if (operand_segment_sizes.size() != 4) { + return op->emitError(absl::StrFormat( + "Expected operand count to be 4 in tpu.semaphore_signal. Got %d", + operand_segment_sizes.size())); } + SmallVector new_operand_segment_sizes( + operand_segment_sizes.begin(), operand_segment_sizes.end()); + new_operand_segment_sizes.push_back(0); + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), + new_operand_segment_sizes)); } return success(); } From 8c71d1ad6d543f95db1b191505150dd19b0b6e69 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 19 Nov 2024 18:30:57 -0800 Subject: [PATCH 060/112] Make deprecated jax.experimental.array_api module visibility internal-only This is in preparation for the module to be removed. PiperOrigin-RevId: 698215225 --- jax/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 0da99677dc7b..26694fec2ad3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1047,7 +1047,7 @@ pytype_library( "experimental/array_api/*.py", ], ), - visibility = [":internal"] + jax_visibility("array_api"), + visibility = [":internal"], deps = [ ":jax", ], From 867a36189bf6c9d19f0f4a6522e91306dec5945f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Nov 2024 18:59:08 -0800 Subject: [PATCH 061/112] Fix a bug where constant deduplication used an inappropriate inequality. We need to compare constants for bitwise equality, not, e.g., floating point equality. The change that added deduplication caused us to conflate +0.0 and -0.0, which led a downstream test not to terminate. PiperOrigin-RevId: 698221147 --- jax/_src/interpreters/mlir.py | 34 ++++++++++++++++++++-------------- tests/pjit_test.py | 11 +++++++++++ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 23325a9d7e26..102e4f490b5c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1755,33 +1755,39 @@ def _emit_lowering_rule_as_fun(lowering_rule, class HashableLiteral: """Hashable wrapper of core.Literal, used for deduplicating IR constants.""" - __slots__ = ["value"] + __slots__ = ["value", "data"] value: core.Literal + # Copy of the value suitable for an equality comparison. We are careful to + # avoid floating point comparisons here, because in particular we don't want + # 0.0 and -0.0 to be considered equal, but we are fine with NaNs being equal. + data: bytes | int | bool | None + def __init__(self, value): self.value = value + if isinstance(value.val, (np.generic, np.ndarray)): + self.data = value.val.tobytes() + elif isinstance(value.val, (bool, int)): + self.data = value.val + elif isinstance(value.val, float): + self.data = np.float64(value.val).tobytes() + elif isinstance(value.val, complex): + self.data = np.complex128(value.val).tobytes() + else: + self.data = None # Unhandled case. def __hash__(self): - h = self.value.hash - return id(self.value.val) if h is None else h + return hash(self.data) def __eq__(self, other): - if self is other: - return True if type(self.value.val) != type(other.value.val): return False if self.value.aval != other.value.aval: return False - if isinstance(self.value.val, (bool, int, float, complex)): - return self.value == other.value - if isinstance(self.value.val, (np.generic, np.ndarray)): - return np.array_equal( - self.value.val, other.value.val, - equal_nan=np.issubdtype(self.value.val.dtype, np.inexact)) - # Since the use case is constant deduplication, it's safe to return - # False in unhandled cases. - return False + if self.data is None: + return id(self) == id(other) + return self.data == other.data def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6df011419513..e32424cfdded 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1312,6 +1312,17 @@ def under_jvp(f): ans2 = jnp.sin(x0) # cpp_pjit hit with bad cache entry assert(ans1.devices() == ans2.devices()) + def test_zero_literal_equality(self): + # This test verifies that we don't accidentally conflate positive and + # negative zeros when deduplicating literals in the IR. + f = jax.jit(lambda x: (x / np.float32(-0.0), x / np.float32(0.0))) + a, b = f(np.float32(1.0)) + self.assertEqual(a, -np.inf) + self.assertEqual(b, np.inf) + ir = f.lower(np.float32(1.0)).as_text() + self.assertIn("stablehlo.constant dense<0.000000e+00>", ir) + self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) + @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): From 6c291d67b7a9dfbc0517c0ab7828e80dc88bdc01 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Tue, 19 Nov 2024 19:03:55 -0800 Subject: [PATCH 062/112] [Mosaic] Add `tpu.log` verification on SC Guards against using formatting and targeting vector subcores on SC. PiperOrigin-RevId: 698222100 --- jaxlib/mosaic/dialect/tpu/tpu.td | 1 + jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 590c27ac2099..de5e3514fc1d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -761,6 +761,7 @@ def TPU_LogOp : TPU_Op<"log"> { ); let results = (outs); let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; + let hasVerifier = 1; } def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index a103cda7dae2..8586e2a16c8a 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1053,6 +1053,30 @@ LogicalResult ConcatenateOp::verify() { return success(); } +LogicalResult LogOp::verify() { + FailureOr> logging_core_type_maybe = + GetCoreTypeOfParentFunc(**this); + if (failed(logging_core_type_maybe)) { + return failure(); + } + CoreType logging_core_type = logging_core_type_maybe->value_or(CoreType::kTc); + if ((logging_core_type == CoreType::kScScalarSubcore || + logging_core_type == CoreType::kScVectorSubcore) && + getFormattedAttr() != nullptr && getFormattedAttr().getValue()) { + return emitOpError("Formatted logging is not supported on SC"); + } + switch (logging_core_type) { + case CoreType::kTc: + case CoreType::kScScalarSubcore: + return success(); + case CoreType::kScVectorSubcore: + return emitOpError("Log op is not supported on the SC vector subcore"); + } + return emitOpError( + absl::StrFormat("Unexpected core type: %s", + stringifyCoreType(logging_core_type_maybe->value()))); +} + } // namespace tpu } // namespace mlir From 4bb81075bcc8c5ac8ea9d5993f9d877bb16f3a13 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 29 Oct 2024 12:46:08 -0700 Subject: [PATCH 063/112] represent `random.key_impl` of builtin RNGs by canonical string name We do not have great reason to return specs here, and sticking to strings instead can help with simple serialization. --- jax/_src/random.py | 9 +++++---- tests/extend_test.py | 30 ++++++++++++++++++------------ tests/random_test.py | 4 ++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index dc9fc18aff38..6c04b0620080 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -293,14 +293,15 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: return _return_prng_keys(wrapped, _split(typed_key, num)) -def _key_impl(keys: KeyArray) -> PRNGImpl: +def _key_impl(keys: KeyArray) -> str | PRNGSpec: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) keys_dtype = typing.cast(prng.KeyTy, keys.dtype) - return keys_dtype._impl + impl = keys_dtype._impl + return impl.name if impl.name in prng.prngs else PRNGSpec(impl) -def key_impl(keys: KeyArrayLike) -> PRNGSpec: +def key_impl(keys: KeyArrayLike) -> str | PRNGSpec: typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) - return PRNGSpec(_key_impl(typed_keys)) + return _key_impl(typed_keys) def _key_data(keys: KeyArray) -> Array: diff --git a/tests/extend_test.py b/tests/extend_test.py index 84a907c7331d..42196a940a76 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -70,35 +70,41 @@ def test_symbols(self): class RandomTest(jtu.JaxTestCase): - def test_key_make_with_custom_impl(self): - shape = (4, 2, 7) - + def make_custom_impl(self, shape, seed=False, split=False, fold_in=False, + random_bits=False): + assert not split and not fold_in and not random_bits # not yet implemented def seed_rule(_): return jnp.ones(shape, dtype=jnp.dtype('uint32')) def no_rule(*args, **kwargs): assert False, 'unreachable' - impl = jex.random.define_prng_impl( - key_shape=shape, seed=seed_rule, split=no_rule, fold_in=no_rule, - random_bits=no_rule) + return jex.random.define_prng_impl( + key_shape=shape, seed=seed_rule if seed else no_rule, split=no_rule, + fold_in=no_rule, random_bits=no_rule) + + def test_key_make_with_custom_impl(self): + impl = self.make_custom_impl(shape=(4, 2, 7), seed=True) k = jax.random.key(42, impl=impl) self.assertEqual(k.shape, ()) self.assertEqual(impl, jax.random.key_impl(k)) def test_key_wrap_with_custom_impl(self): - def no_rule(*args, **kwargs): - assert False, 'unreachable' - shape = (4, 2, 7) - impl = jex.random.define_prng_impl( - key_shape=shape, seed=no_rule, split=no_rule, fold_in=no_rule, - random_bits=no_rule) + impl = self.make_custom_impl(shape=shape) data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32')) k = jax.random.wrap_key_data(data, impl=impl) self.assertEqual(k.shape, (3,)) self.assertEqual(impl, jax.random.key_impl(k)) + def test_key_impl_is_spec(self): + # this is counterpart to random_test.py: + # KeyArrayTest.test_key_impl_builtin_is_string_name + spec_ref = self.make_custom_impl(shape=(4, 2, 7), seed=True) + key = jax.random.key(42, impl=spec_ref) + spec = jax.random.key_impl(key) + self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})") + class FfiTest(jtu.JaxTestCase): diff --git a/tests/random_test.py b/tests/random_test.py index fed12792d5c6..f9167b22b4ea 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1125,10 +1125,10 @@ class A: pass jax.random.key(42, impl=A()) @jtu.sample_product(name=[name for name, _ in PRNG_IMPLS]) - def test_key_spec_repr(self, name): + def test_key_impl_builtin_is_string_name(self, name): key = jax.random.key(42, impl=name) spec = jax.random.key_impl(key) - self.assertEqual(repr(spec), f"PRNGSpec({name!r})") + self.assertEqual(spec, name) def test_keyarray_custom_vjp(self): # Regression test for https://github.com/jax-ml/jax/issues/18442 From 4d60db17413208cf4ff829242d21eaa46c9586c4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 19 Nov 2024 21:32:44 -0800 Subject: [PATCH 064/112] Add test_compute_on_host_shared_sharding in memories_test PiperOrigin-RevId: 698250352 --- tests/memories_test.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/memories_test.py b/tests/memories_test.py index da4239338c02..ca676a2b1993 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -808,6 +808,46 @@ def h(x): self.assertArraysEqual(out2, inp * 6) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_on_host_shared_sharding(self): + mesh = jtu.create_mesh((2,), ("x")) + device_sharding = NamedSharding(mesh, P("x")) + host_sharding = device_sharding.with_memory_kind("pinned_host") + + @compute_on("device_host") + @functools.partial( + jax.jit, + in_shardings=(host_sharding, device_sharding), + out_shardings=(host_sharding, device_sharding), + donate_argnums=(0, 1), + ) + def host_func(x, y): + return (x * y), ((x**2) * (y**2)) + + @functools.partial( + jax.jit, + in_shardings=(host_sharding, device_sharding), + out_shardings=(host_sharding, device_sharding), + donate_argnums=(0), + ) + def device_func(host_data, device_data): + host_data, device_data = host_func(host_data, device_data) + device_data = device_data * 2 + host_data, device_data = host_func(host_data, device_data) + return (host_data, device_data) + + input_x = jnp.ones(8) + input_host = jax.device_put(input_x, host_sharding) + + input_device = jnp.arange(8) + input_device = jnp.where(input_device < 4, 0, 1) + input_device = jax.device_put(input_device, device_sharding) + + output_host, output_device = device_func(input_host, input_device) + self.assertEqual(output_host.sharding.memory_kind, 'pinned_host') + self.assertEqual(output_device.sharding.memory_kind, 'device') + self.assertArraysEqual(output_host, [0., 0., 0., 0., 2., 2., 2., 2.]) + self.assertArraysEqual(output_device, [0., 0., 0., 0., 4., 4., 4., 4.]) + def test_compute_on_basic_inline(self): @compute_on('device_host') @jax.jit From 1afb05e2e2341362a9107a6726721f4f617db46c Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Wed, 20 Nov 2024 03:01:11 -0800 Subject: [PATCH 065/112] [mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise. Only propagate signedness from operands when the output type of `op` is an `ir.IntegerType`. PiperOrigin-RevId: 698324596 --- jax/experimental/mosaic/gpu/fragmented_array.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index fd989d052917..e45202386b47 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -623,10 +623,6 @@ def to_layout(self, new_layout: FragmentedLayout): ) def _pointwise(self, op, *other, output_is_signed: bool | None = None): - is_signed = ( - output_is_signed if output_is_signed is not None else self.is_signed - ) - other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): @@ -636,7 +632,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): raise NotImplementedError(o) o = FragmentedArray.splat( - o, shape=self.shape, layout=self.layout, is_signed=is_signed + o, shape=self.shape, layout=self.layout, is_signed=self.is_signed ) if isinstance(o.layout, WGSplatFragLayout): @@ -646,7 +642,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): o.registers.flat[0], shape=self.shape, layout=self.layout, - is_signed=is_signed, + is_signed=self.is_signed, ) else: if self.layout != o.layout: @@ -659,8 +655,13 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) + reg_ty = new_regs.flat[0].type + if ir.VectorType.isinstance(reg_ty): + reg_ty = ir.VectorType(reg_ty).element_type + if output_is_signed is None and ir.IntegerType.isinstance(reg_ty): + output_is_signed = self.is_signed return FragmentedArray( - _registers=new_regs, _layout=self.layout, _is_signed=is_signed + _registers=new_regs, _layout=self.layout, _is_signed=output_is_signed ) def __pos__(self): From 14da7ebb76d5a97b9955822e8781d1f45505cb9e Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Wed, 20 Nov 2024 03:40:40 -0800 Subject: [PATCH 066/112] [pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type. Only handles the case where operand type and target type have the same bitwidth. PiperOrigin-RevId: 698332564 --- jax/_src/pallas/mosaic_gpu/lowering.py | 25 +++++++++++++++++++ .../mosaic/gpu/fragmented_array.py | 8 ++++-- tests/pallas/mosaic_gpu_test.py | 24 ++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6d30cdb0d4a3..5b5a6f4ace83 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1501,6 +1501,31 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): return list(switch_op.results) +@register_lowering_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, operand, *, new_dtype +): + # TODO(petebu) Handle case where src and dst types have different bitwidths + [operand_aval] = ctx.avals_in + operand = _ensure_fa(operand, operand_aval.dtype) + src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype) + dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype) + assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) + assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) + if src_elem_type.width != dst_elem_type.width: + raise NotImplementedError( + f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they" + " have different widths" + ) + if ir.IntegerType.isinstance(dst_elem_type): + output_is_signed = mgpu_utils.is_signed(new_dtype) + else: + output_is_signed = None + return mgpu.FragmentedArray.bitcast( + operand, dst_elem_type, output_is_signed=output_is_signed + ) + + def _bcast( x: ir.Value, y: ir.Value, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index e45202386b47..2b985ff5c9b8 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -929,7 +929,9 @@ def fast_instr(x): raise NotImplementedError(x.type) return fast_instr - def bitcast(self, elt: ir.Type): + def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): + if elt == self.mlir_dtype: + return self reg_type = self.registers.flat[0].type if ir.VectorType.isinstance(reg_type): reg_shape = ir.VectorType(reg_type).shape @@ -937,7 +939,9 @@ def bitcast(self, elt: ir.Type): else: ty = elt - return self._pointwise(lambda x: arith.bitcast(ty, x)) + return self._pointwise( + lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed + ) def __getitem__(self, idx): if self.layout != WGMMA_LAYOUT: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index fe52a33c1637..b8098f40eccf 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1052,6 +1052,30 @@ def kernel(x_ref, o_ref): self.assertEqual(data.count('"name": "store"'), 2) np.testing.assert_array_equal(y, x + x) + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ) + def test_bitcast_convert_type(self, in_dtype, out_dtype): + m, n = 16, 8 + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () + + @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) + + x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + class PipelineTest(PallasTest): From c76e5fe9a0d1c4b67fdc844a824f1bd53821653d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 04:28:56 -0800 Subject: [PATCH 067/112] [pallas:mosaic_gpu] `copy_smem_to_gmem` now supports `wait_read_only` PiperOrigin-RevId: 698343812 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 4 +++- jax/_src/pallas/mosaic_gpu/primitives.py | 24 +++++++++++++++++------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 9b6adc86f981..90c00765e8b1 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -207,7 +207,9 @@ def loop_body(step, carry): # Wait for the current GMEM->SMEM copy to complete. gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. - gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) + gpu_primitives.wait_smem_to_gmem( + max_concurrent_steps - 1, wait_read_only=True + ) with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): body( diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 36dcba5d15d0..0f25f9808ac1 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -363,20 +363,30 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None: @wait_smem_to_gmem_p.def_effectful_abstract_eval -def _wait_smem_to_gmem_abstract_eval(n): - del n # Unused. +def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): + del n, wait_read_only # Unused. return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(wait_smem_to_gmem_p) -def _wait_smem_to_gmem_lowering(ctx: lowering.LoweringRuleContext, n): - ctx.launch_ctx.await_async_copy(allow_groups=n) +def _wait_smem_to_gmem_lowering( + ctx: lowering.LoweringRuleContext, n, *, wait_read_only +): + ctx.launch_ctx.await_async_copy( + allow_groups=n, await_read_only=wait_read_only + ) return () -def wait_smem_to_gmem(n: int) -> None: - """Waits until there are no more than ``n`` SMEM->GMEM copies in flight.""" - wait_smem_to_gmem_p.bind(n) +def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None: + """Waits until there are no more than ``n`` SMEM->GMEM copies in flight. + + Args: + n: The maximum number of copies in flight to wait for. + wait_read_only: If ``True``, wait for the in flight copies to finish + reading from SMEM. The writes to GMEM are not waited for. + """ + wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only) # WGMMA on an accumulator reference From f442d40f926f801135cea7637a64cce47f05eae1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 04:29:12 -0800 Subject: [PATCH 068/112] [mosaic_gpu] Fixed `FragmentedArray` comparisons with literals PiperOrigin-RevId: 698343858 --- tests/mosaic/gpu_test.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 157f682f5eef..ab2a00c730d6 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1318,18 +1318,21 @@ def kernel(ctx, dst, _): operator.ne, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], + rhs_is_literal=[False, True] ) - def test_comparison(self, op, dtype, m=64, n=32): + def test_comparison(self, op, dtype, rhs_is_literal, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + rhs = 0 if rhs_is_literal else iota + 1 + op(iota, rhs).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) - np.testing.assert_array_equal(result, op(iota, iota + 1)) + rhs = rhs = 0 if rhs_is_literal else iota + 1 + np.testing.assert_array_equal(result, op(iota, rhs)) @parameterized.product( op=[operator.and_, operator.or_, operator.xor], From 04e4c69f7f72e3aabee726315f370c8182045b49 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 05:05:14 -0800 Subject: [PATCH 069/112] [mosaic_gpu] Handle older `jaxlib`s in the profiler module `measure` now raises a `RuntimeError` if the available `jaxlib` does not have the required custom calls. PiperOrigin-RevId: 698351662 --- jax/experimental/mosaic/gpu/profiler.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 337581c54b86..0594e9239be7 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -36,12 +36,15 @@ try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: - pass + has_registrations = False else: - for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): - xla_client.register_custom_call_target( - name, handler, platform="CUDA", api_version=1 - ) + # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36. + has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations") + if has_registrations: + for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): + xla_client.register_custom_call_target( + name, handler, platform="CUDA", api_version=1 + ) # ruff: noqa: F405 # mypy: ignore-errors @@ -80,6 +83,11 @@ def measure( Returns: The return value of ``f`` and the elapsed time in milliseconds. """ + if not has_registrations: + raise RuntimeError( + "This function requires jaxlib >=0.4.36 with CUDA support." + ) + if not (args or kwargs): # We require at least one argument and at least one output to ensure # that there is a data dependency between `_event_record` calls in From 1df4b5f79885e0d9fb0d8e097b7a526e577f04ef Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 05:06:50 -0800 Subject: [PATCH 070/112] [pallas] Do not skip vmap tests on GPU when x64 is enabled PiperOrigin-RevId: 698351984 --- tests/pallas/pallas_vmap_test.py | 37 +++++++++++++++----------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index fefccfe7eb4f..ffa6195625dd 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -22,6 +22,7 @@ import jax from jax import random from jax._src import config +from jax._src import dtypes from jax._src import test_util as jtu from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl @@ -35,6 +36,10 @@ config.parse_flags_with_absl() +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -42,8 +47,6 @@ class PallasBaseTest(jtu.JaxTestCase): def setUp(self): if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled: - self.skipTest("On GPU the test works only in 32-bit") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") @@ -67,7 +70,7 @@ def setUp(self): def test_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -77,7 +80,7 @@ def add_one(x_ref, o_ref): def test_vmap_of_simple_kernel_with_in_axes_None(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add(x_ref, y_ref, o_ref): o_ref[()] = x_ref[()] + y_ref[()] @@ -87,7 +90,7 @@ def add(x_ref, y_ref, o_ref): def test_double_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -97,7 +100,7 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -108,7 +111,7 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_batched_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), intx), grid=(7,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -120,7 +123,7 @@ def add_one(x_ref, o_ref): def test_vmap_of_slicing_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -151,7 +154,7 @@ def kernel(src, dst): def test_vmap_of_kernel_with_input_output_aliases(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), input_output_aliases={1:0}, grid=()) def add(x_ref, _, o_ref): @@ -163,7 +166,7 @@ def add(x_ref, _, o_ref): def test_vmap_of_kernel_with_input_output_aliases_different_axes(self): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + out_shape=jax.ShapeDtypeStruct((4,), intx), input_output_aliases={0: 0}, grid=(), ) @@ -176,7 +179,7 @@ def add(x_ref, o_ref): def test_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -194,7 +197,7 @@ def add_one(x_ref, o_ref): def test_double_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx), grid=(4,)) def sin(x_ref, o_ref): i = pl.program_id(0) @@ -211,7 +214,7 @@ def sin(x_ref, o_ref): def test_small_large_vmap(self): # Catches https://github.com/jax-ml/jax/issues/18361 @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -230,7 +233,7 @@ def add_one(x_ref, o_ref): def test_small_small_large_vmap(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -249,12 +252,6 @@ def add_one(x_ref, o_ref): class PallasCallVmapInterpretTest(PallasCallVmapTest): INTERPRET = True - def setUp(self): - super().setUp() - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") - if __name__ == "__main__": absltest.main() From a582df02971337dba2834c5a3953f4af067caaa0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 20 Nov 2024 06:38:33 -0800 Subject: [PATCH 071/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/fcee07f619a765db815d9ed4e2bc229275818a2b. PiperOrigin-RevId: 698371906 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 99c2af75f3ad..a554cfd03687 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "229f376e046b9a51039dc1566d1e388ee7c1ca6d" -XLA_SHA256 = "895b39b5cb298460185f29df3ecc8882f4ee151b0f7dc93e5387ef81ea32e374" +XLA_COMMIT = "fcee07f619a765db815d9ed4e2bc229275818a2b" +XLA_SHA256 = "1dd144e64e2c2dcc20a2130e10607fec7b3a810926ba912918dd5437698a3375" def repo(): tf_http_archive( From a4266b5e31853a62a06281b211024cb8c2581876 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 08:23:19 -0800 Subject: [PATCH 072/112] Mention python 3.13 in docs & package metadata --- docs/deprecation.md | 3 +++ jaxlib/setup.py | 1 + setup.py | 1 + 3 files changed, 5 insertions(+) diff --git a/docs/deprecation.md b/docs/deprecation.md index 385d31271421..603a027f5efc 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -18,6 +18,7 @@ This means we support at least: * **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**. * **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**. * **Python 3.12** was released October 2023, and will be supported in new JAX releases at least until **July 2027**. + * **Python 3.13** was released October 2024, and will be supported in new JAX releases at least until **July 2028**. * All NumPy feature releases in the 24 months prior to each JAX release. For example: @@ -25,6 +26,7 @@ This means we support at least: * **NumPy 1.25** was released June 2023, and will be supported in new JAX releases at least until **June 2025** * **NumPy 1.26** was released September 2023, and will be supported in new JAX releases at least until **September 2025** * **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026** + * **NumPy 2.1** was released August 2024, and will be supported in new JAX releases at least until **August 2026** * All SciPy feature releases in the 24 months prior to each JAX release. For example: @@ -32,6 +34,7 @@ This means we support at least: * **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**. * **Scipy 1.12** was released January 2024, and will be supported in new JAX releases at least until **January 2026**. * **Scipy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**. + * **Scipy 1.14** was released June 2024, and will be supported in new JAX releases at least until **June 2026**. JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed diff --git a/jaxlib/setup.py b/jaxlib/setup.py index dea9503c7c00..989a8314eb92 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -72,6 +72,7 @@ def has_ext_modules(self): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={ 'jaxlib': [ diff --git a/setup.py b/setup.py index 98d509375d62..a3b54f7aa94f 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ def load_version_module(pkg_path): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], zip_safe=False, ) From 1e9e85a39eee20f7362c7aa6e79a8f345bbef748 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 20 Nov 2024 08:26:12 -0800 Subject: [PATCH 073/112] Simplify handling of `DotAlgorithmPreset` output types. Create a clear distinction between the type used for accumulation and possible output types. PiperOrigin-RevId: 698399447 --- jax/_src/lax/lax.py | 84 ++++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ff9ac0a49578..39c5bca5819c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -879,11 +879,11 @@ def __str__(self) -> str: def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: match self: case ( - DotAlgorithmPreset.DEFAULT | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.DEFAULT + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): return None case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32: @@ -906,14 +906,26 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: return self.lhs_precision_type @property - def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: + def accumulation_type(self) -> DTypeLike | None: match self: case ( - DotAlgorithmPreset.DEFAULT | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.DEFAULT + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): return None + case DotAlgorithmPreset.F16_F16_F16: + return np.float16 + case DotAlgorithmPreset.BF16_BF16_BF16: + return dtypes.bfloat16 + case DotAlgorithmPreset.F64_F64_F64: + return np.float64 + case _: + return np.float32 + + @property + def supported_output_types(self) -> tuple[DTypeLike, ...] | None: + match self: case ( DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM @@ -921,16 +933,11 @@ def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz) - case DotAlgorithmPreset.F16_F16_F16: - return np.float16 case DotAlgorithmPreset.F16_F16_F32: return (np.float32, np.float16) - case DotAlgorithmPreset.BF16_BF16_BF16: - return dtypes.bfloat16 - case DotAlgorithmPreset.F64_F64_F64: - return np.float64 case _: - return np.float32 + accumulation_type = self.accumulation_type + return None if accumulation_type is None else (accumulation_type,) def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: @@ -941,16 +948,18 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, tf32 = ir.FloatTF32Type.get() match self: case ( - DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): - fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), - np.dtype(dtypes.float8_e4m3fn), - np.dtype(dtypes.float8_e4m3fnuz), - np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)] + fp8_dtypes = [ + np.dtype(dtypes.float8_e4m3b11fnuz), + np.dtype(dtypes.float8_e4m3fn), + np.dtype(dtypes.float8_e4m3fnuz), + np.dtype(dtypes.float8_e5m2), + np.dtype(dtypes.float8_e5m2fnuz), + ] if dtypes.float8_e3m4 is not None: fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] if dtypes.float8_e4m3 is not None: @@ -958,13 +967,20 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " - f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.") + f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.' + ) lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype)) rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype)) acc = ir.F32Type.get() return hlo.DotAlgorithm.get( - lhs, rhs, acc, 1, 1, 1, - self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM) + lhs, + rhs, + acc, + 1, + 1, + 1, + self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, + ) case DotAlgorithmPreset.F16_F16_F16: return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False) case DotAlgorithmPreset.F16_F16_F32: @@ -3649,9 +3665,8 @@ def maybe_convert_dtype(input_dtype, target_dtype): return input_dtype if not isinstance(target_dtype, tuple): target_dtype = (target_dtype,) - if any(input_dtype == d for d in target_dtype): - return input_dtype - return target_dtype[0] + return input_dtype if input_dtype in target_dtype else target_dtype[0] + if algorithm == DotAlgorithmPreset.BF16_BF16_F32: lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type) rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type) @@ -3662,10 +3677,15 @@ def maybe_convert_dtype(input_dtype, target_dtype): out_dtype = maybe_convert_dtype(out_dtype, np.float32) return lhs_dtype, rhs_dtype, out_dtype else: + if isinstance(algorithm, DotAlgorithmPreset): + supported_output_types = algorithm.supported_output_types + else: + supported_output_types = (algorithm.accumulation_type,) + return ( maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type), maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type), - maybe_convert_dtype(out_dtype, algorithm.accumulation_type), + maybe_convert_dtype(out_dtype, supported_output_types), ) From 85e2969aea15141bedd6d4ec0548cc02ef45b069 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 08:48:26 -0800 Subject: [PATCH 074/112] Deprecate several private APIs in jax.lib --- CHANGELOG.md | 6 ++++++ jax/lib/xla_client.py | 7 ++++++- jax/lib/xla_extension.py | 25 +++++++++++++++++++++++-- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9082399c8695..37fd68bce39d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. result in an indexing overflow for batch sizes close to int32 max. See {jax-issue}`#24843` for more details. +* Deprecations + * `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated; + use `jax.Array` instead. + * `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError` + instead. + ## jax 0.4.35 (Oct 22, 2024) * Breaking Changes diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index aaf3791037d0..cd3696d8838c 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -18,7 +18,6 @@ get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile mlir_api_version = _xc.mlir_api_version -ArrayImpl = _xc.ArrayImpl Client = _xc.Client CompileOptions = _xc.CompileOptions DeviceAssignment = _xc.DeviceAssignment @@ -95,6 +94,11 @@ "XlaComputation is deprecated; use StableHLO instead.", _xc.XlaComputation, ), + # Added Nov 20 2024 + "ArrayImpl": ( + "jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.", + _xc.ArrayImpl, + ), } import typing as _typing @@ -106,6 +110,7 @@ ops = _xc.ops register_custom_call_target = _xc.register_custom_call_target shape_from_pyval = _xc.shape_from_pyval + ArrayImpl = _xc.ArrayImpl Device = _xc.Device FftType = _FftType PaddingType = _xc.PaddingType diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 20ce459685aa..52fe94e231d1 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -24,7 +24,6 @@ pmap_lib = _xe.pmap_lib profiler = _xe.profiler pytree = _xe.pytree -ArrayImpl = _xe.ArrayImpl Device = _xe.Device DistributedRuntimeClient = _xe.DistributedRuntimeClient HloModule = _xe.HloModule @@ -33,6 +32,28 @@ PjitFunctionCache = _xe.PjitFunctionCache PjitFunction = _xe.PjitFunction PmapFunction = _xe.PmapFunction -XlaRuntimeError = _xe.XlaRuntimeError +_deprecations = { + # Added Nov 20 2024 + "ArrayImpl": ( + "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", + _xe.ArrayImpl, + ), + "XlaRuntimeError": ( + "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", + _xe.XlaRuntimeError, + ), +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + ArrayImpl = _xe.ArrayImpl + XlaRuntimeError = _xe.XlaRuntimeError +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing del _xe From 62225926253474c6e5e4b202d5c9cf3363a02a03 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 20 Nov 2024 17:46:06 +0000 Subject: [PATCH 075/112] Fix KeyError recently introduced in cloud_tpu_init.py This fixes a bug introduced in https://github.com/jax-ml/jax/pull/24889 --- jax/_src/cloud_tpu_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 8ff52bd2f559..a2f137686dae 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,7 @@ def cloud_tpu_init() -> None: os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']: + if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''): os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true' # this makes tensorstore serialization work better on TPU From 8d84f2837346b29b52d1b797f672af10df05df41 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 20 Nov 2024 09:59:39 -0800 Subject: [PATCH 076/112] [pallas mgpu] Lowering for while loops as long as they are secretly for loops. PiperOrigin-RevId: 698427307 --- jax/_src/pallas/mosaic_gpu/lowering.py | 38 ++++++++++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 16 +++++++++++ 2 files changed, 54 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5b5a6f4ace83..66437839cce2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1473,6 +1473,44 @@ def _scan_lowering_rule( return for_out +@register_lowering_rule(lax.while_p) +def _while_lowering_rule( + ctx: LoweringRuleContext, + *args, + cond_jaxpr, + body_jaxpr, + cond_nconsts, + body_nconsts, +): + # First try to lower via a simpler fori loop, which may optimize better. + fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop( + cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts + ) + del cond_jaxpr, body_jaxpr + if fori_jaxpr is None: + raise NotImplementedError(err) + + if fori_jaxpr.constvars: + raise NotImplementedError + + lb_aval, ub_aval, *_ = ctx.avals_in[body_nconsts:] + # Reflect the changes of the pattern matcher to the context. + avals_in = ( + *ctx.avals_in[cond_nconsts:body_nconsts], + ctx.avals_in[body_nconsts], # the index + *ctx.avals_in[body_nconsts + 2:], + ) + + avals_out = tuple(ctx.avals_out[2:]) + ctx = ctx.replace(avals_in=avals_in, avals_out=avals_out) + _, consts, (lb, ub, *args) = util.split_list(args, [cond_nconsts, body_nconsts]) + + lb, ub = _ensure_ir_value(lb, lb_aval.dtype), _ensure_ir_value(ub, ub_aval.dtype) + length = arith_dialect.subi(ub, lb) + + for_out = _lower_jaxpr_to_for_loop(ctx, fori_jaxpr, lb, length, consts, *args, has_loop_index=True) + return (ub, ub, *for_out) + @register_lowering_rule(lax.cond_p) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b8098f40eccf..48c047697911 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -676,6 +676,22 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) + def test_fori_loop_dynamic_bounds(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + grid=(1,) + ) + def kernel(o_ref): + zero = pl.program_id(0) + # Equivalent to 2 + 3. + o_ref[...] = jax.lax.broadcast( + jax.lax.fori_loop(2 + zero, 4 + zero, lambda i, x: x + i, 0), o_ref.shape + ) + + np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) + def test_fori_loop_tuple(self): @functools.partial( pl.pallas_call, From d0f17c0c04bec626a5e03cbf33a4dae43cfc8443 Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 4 Nov 2024 13:33:19 -0500 Subject: [PATCH 077/112] Make a direct linearize trace. This is an alternative to doing JVP followed by partial eval. The linearize trace has two parent traces, one for the primal computation and one for the tangent computation. If we make the tangent trace a DynamicJaxprTrace then we get staged linearization. If we make it the same as the primal trace then we get primal and tangent computations occurring in step (JVP). This is a neat trick enabled by stackless which now lives up to its name. With two parent traces we have a tree of traces not a linked list stack. Primitive ops can have their own linearization rules but as a fallback we can derive a linearization rule for a single op using jvp/partial-eval. For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can make this the default for linearize/grad. It should help with remat and AD through state which are awkward to express via partial eval. --- jax/_src/config.py | 10 ++++ jax/_src/interpreters/ad.py | 101 ++++++++++++++++++++++++++++++++++-- tests/api_test.py | 15 ++++++ 3 files changed, 123 insertions(+), 3 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 1c62f7125ee7..eff9b757b95b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -219,6 +219,7 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, sharding_in_types.value, + use_direct_linearize.value, softmax_custom_jvp.value, enable_memories.value, disable_jit.value, @@ -263,6 +264,7 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, sharding_in_types.value, + use_direct_linearize.value, softmax_custom_jvp.value, enable_memories.value, disable_jit.value, @@ -983,6 +985,7 @@ class _GlobalExtraJitContext(NamedTuple): threefry_partitionable: bool = False threefry_gpu_kernel_lowering: bool = False sharding_in_types: bool = False + use_direct_linearize: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 pgle_profiling_runs: int = 0 @@ -1025,6 +1028,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): threefry_partitionable: bool | None = None threefry_gpu_kernel_lowering: bool | None = None sharding_in_types: bool | None = None + use_direct_linearize: bool | None = None softmax_custom_jvp: bool | None = None xla_profile_version: int | None = None pgle_profiling_runs: int | None = None @@ -1318,6 +1322,12 @@ def _update_jax_memories_thread_local(val): 'avals have sharding on them.'), include_in_jit_key=True) +use_direct_linearize = bool_state( + name='jax_use_direct_linearize', + default=False, + help=('Use direct linearization instead JVP followed by partial eval'), + include_in_jit_key=True) + data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', default=False, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 99340e728545..91f061fd2210 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -39,7 +39,6 @@ as_hashable_function, weakref_lru_cache, partition_list) - zip = safe_zip map = safe_map def identity(x): return x @@ -106,7 +105,29 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents): store.store(aux_primals) return out_primals, out_tangents +def direct_linearize(traceable, *primals, **kwargs): + has_aux = kwargs.pop('has_aux', False) + assert not has_aux + with core.take_current_trace() as parent_trace: + frame = pe.JaxprStackFrame() + tangent_trace = pe.DynamicJaxprTrace(frame) + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] + tag = core.TraceTag() + linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag) + tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] + with core.set_current_trace(linearize_trace): + ans = traceable.call_wrapped(*tracers) + + out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans)) + out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) + jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents) + out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents] + del attrs_tracked # TODO: attrs + return out_primals, out_tangents_pvals, jaxpr, consts + def linearize(traceable, *primals, **kwargs): + if config.use_direct_linearize.value: + return direct_linearize(traceable, *primals, **kwargs) has_aux = kwargs.pop('has_aux', False) if not has_aux: jvpfun = jvp(traceable) @@ -444,15 +465,89 @@ def _primal_tangent_shapes_match(primal, tangent): call_param_updaters: dict[core.Primitive, Callable] = {} call_transpose_param_updaters: dict[core.Primitive, Callable] = {} +# -------------------- Linearize trace -------------------- + +class LinearizeTrace(Trace): + + def __init__(self, parent_trace, tangent_trace, tag): + self.tag = tag + self.parent_trace = parent_trace + self.tangent_trace = tangent_trace + + def to_primal_tangent_pair(self, val): + if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag: + return (val.primal, val.tangent) + else: + tangent_zero = Zero.from_primal_value(val) + return (val, tangent_zero) + + def process_primitive(self, primitive, args, params): + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) + if all(type(t) is Zero for t in tangents_in): + return primitive.bind_with_trace(self.parent_trace, primals_in, params) + lin = primitive_linearizations.get(primitive) + if lin is None: + lin = partial(fallback_linearize_rule, primitive) + with core.set_current_trace(self.parent_trace): + primal_out, linearized = lin(*primals_in, **params) + with core.set_current_trace(self.tangent_trace): + tangent_out = linearized(*tangents_in) + if primitive.multiple_results: + return [maybe_linearize_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)] + else: + return maybe_linearize_tracer(self, primal_out, tangent_out) + +def maybe_linearize_tracer(trace, primal, tangent): + if type(tangent) is Zero: + return primal + else: + return LinearizeTracer(trace, primal, tangent) + +def fallback_linearize_rule(prim, *args, **kwargs): + def call_prim(*args_): + return prim.bind(*args_, **kwargs) + with config.use_direct_linearize(False): + out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize( + lu.wrap_init(call_prim), *args, **kwargs) + def linearized(*tangents): + tangents_out = iter(core.eval_jaxpr(jaxpr, consts, *tangents)) + full_out = [pval.get_known() if pval.is_known() else next(tangents_out) + for pval in out_tangents_pvals] + assert next(tangents_out, None) is None + return full_out + return out_primals, linearized + +class LinearizeTracer(Tracer): + __slots__ = ['primal', 'tangent'] + + def __init__(self, trace, primal, tangent): + if config.enable_checks.value: + _primal_tangent_shapes_match(primal, tangent) + self._trace = trace + self.primal = primal + self.tangent = tangent + + @property + def aval(self): + return get_aval(self.primal) + + def full_lower(self): + if type(self.tangent) is Zero: + return core.full_lower(self.primal) + else: + return self + + def to_concrete_value(self): + return core.to_concrete_value(self.primal) + # -------------------- Primitives -------------------- primitive_jvps : dict[core.Primitive, Callable] = {} - primitive_transposes: dict[core.Primitive, Callable] = {} # transpose rules that internally perform reductions over the given named axes reducing_transposes: dict[core.Primitive, Callable] = {} - +primitive_linearizations: dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive) diff --git a/tests/api_test.py b/tests/api_test.py index ae38f50460ab..ff7855b68991 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4807,6 +4807,21 @@ def add_one_and_dupe(x: int) -> tuple[int, int]: jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True) jax.eval_shape(jit_add_one_dupe, 0) # don't crash + def test_use_direct_linearize(self): + + def check_invariant_to_use_direct_linearize(f): + with config.use_direct_linearize(False): + ans1 = f() + with config.use_direct_linearize(True): + ans2 = f() + + self.assertEqual(ans1, ans2) + + def sin_of_sin(x): + return jnp.sin(jnp.sin(x)) + + check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) + class RematTest(jtu.JaxTestCase): From fee272e550109e7409e8ae6e992bbde7bd1f1b90 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 10:30:12 -0800 Subject: [PATCH 078/112] Remove internal KeyArray alias This was useful during the transition to typed PRNG keys, but is no longer necessary. It also makes generated HTML docs confusing: it's better to just use Array as we expect users to. --- jax/_src/blocked_sampler.py | 4 +- jax/_src/nn/initializers.py | 25 +++++---- jax/_src/random.py | 104 ++++++++++++++++++------------------ 3 files changed, 65 insertions(+), 68 deletions(-) diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py index 16da61d75b3f..3bc592d88246 100644 --- a/jax/_src/blocked_sampler.py +++ b/jax/_src/blocked_sampler.py @@ -23,7 +23,7 @@ Shape = random.Shape class SampleFn(Protocol): - def __call__(self, key: random.KeyArrayLike, *args, shape: Shape, + def __call__(self, key: ArrayLike, *args, shape: Shape, **kwargs) -> Array: ... @@ -43,7 +43,7 @@ def _compute_scalar_index(iteration_index: Sequence[int], def blocked_fold_in( - global_key: random.KeyArrayLike, + global_key: ArrayLike, total_size: Shape, block_size: Shape, tile_size: Shape, diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index eb1bb1609bbf..8086a97a3748 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -36,7 +36,6 @@ export = set_module('jax.nn.initializers') -KeyArray = Array # TODO: Import or define these to match # https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py. DTypeLikeFloat = Any @@ -48,13 +47,13 @@ @typing.runtime_checkable class Initializer(Protocol): @staticmethod - def __call__(key: KeyArray, + def __call__(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: raise NotImplementedError @export -def zeros(key: KeyArray, +def zeros(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of zeros. @@ -69,7 +68,7 @@ def zeros(key: KeyArray, return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) @export -def ones(key: KeyArray, +def ones(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of ones. @@ -100,7 +99,7 @@ def constant(value: ArrayLike, Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -126,7 +125,7 @@ def uniform(scale: RealNumeric = 1e-2, Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -152,7 +151,7 @@ def normal(stddev: RealNumeric = 1e-2, Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -189,7 +188,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2, [-3.836303 , -4.192359 , 0.6022964]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -230,7 +229,7 @@ def _compute_fans(shape: Sequence[int], fan_out = out_size * receptive_field_size return fan_in, fan_out -def _complex_uniform(key: KeyArray, +def _complex_uniform(key: Array, shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ @@ -244,7 +243,7 @@ def _complex_uniform(key: KeyArray, theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def _complex_truncated_normal(key: KeyArray, upper: ArrayLike, +def _complex_truncated_normal(key: Array, upper: ArrayLike, shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ @@ -314,7 +313,7 @@ def variance_scaling( dtype: the dtype of the weights. """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: shape = core.canonicalize_shape(shape) @@ -599,7 +598,7 @@ def orthogonal(scale: RealNumeric = 1.0, Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -654,7 +653,7 @@ def delta_orthogonal( .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393 """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) diff --git a/jax/_src/random.py b/jax/_src/random.py index 6c04b0620080..4313d9036eda 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -55,8 +55,6 @@ Shape = Sequence[int] PRNGImpl = prng.PRNGImpl -KeyArray = Array -KeyArrayLike = ArrayLike UINT_DTYPES = prng.UINT_DTYPES @@ -69,8 +67,8 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -def _check_prng_key(name: str, key: KeyArrayLike, *, - allow_batched: bool = False) -> tuple[KeyArray, bool]: +def _check_prng_key(name: str, key: ArrayLike, *, + allow_batched: bool = False) -> tuple[Array, bool]: if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): wrapped_key = key wrapped = False @@ -113,7 +111,7 @@ def _return_prng_keys(was_wrapped, key): return prng.random_unwrap(key) if was_wrapped else key -def _random_bits(key: KeyArray, bit_width: int, shape: Shape) -> Array: +def _random_bits(key: Array, bit_width: int, shape: Shape) -> Array: assert jnp.issubdtype(key.dtype, dtypes.prng_key) return prng.random_bits(key, bit_width=bit_width, shape=shape) @@ -188,7 +186,7 @@ def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl: def _key(ctor_name: str, seed: int | ArrayLike, - impl_spec: PRNGSpecDesc | None) -> KeyArray: + impl_spec: PRNGSpecDesc | None) -> Array: impl = resolve_prng_impl(impl_spec) if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key): raise TypeError( @@ -200,7 +198,7 @@ def _key(ctor_name: str, seed: int | ArrayLike, return prng.random_seed(seed, impl=impl) def key(seed: int | ArrayLike, *, - impl: PRNGSpecDesc | None = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> Array: """Create a pseudo-random number generator (PRNG) key given an integer seed. The result is a scalar array containing a key, whose dtype indicates @@ -220,7 +218,7 @@ def key(seed: int | ArrayLike, *, return _key('key', seed, impl) def PRNGKey(seed: int | ArrayLike, *, - impl: PRNGSpecDesc | None = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> Array: """Create a legacy PRNG key given an integer seed. This function produces old-style legacy PRNG keys, which are arrays @@ -248,7 +246,7 @@ def PRNGKey(seed: int | ArrayLike, *, return _return_prng_keys(True, _key('PRNGKey', seed, impl)) -def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: +def fold_in(key: ArrayLike, data: IntegerArray) -> Array: """Folds in data to a PRNG key to form a new PRNG key. Args: @@ -267,7 +265,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: return _return_prng_keys(wrapped, key_out) -def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: +def _split(key: Array, num: int | tuple[int, ...] = 2) -> Array: # Alternative to split() to use within random samplers. # TODO(frostig): remove and use split(); we no longer need to wait # to always enable_custom_prng @@ -278,7 +276,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: shape = tuple(num) if isinstance(num, Sequence) else (num,) return prng.random_split(key, shape=shape) -def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: +def split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array: """Splits a PRNG key into `num` new keys by adding a leading axis. Args: @@ -293,22 +291,22 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: return _return_prng_keys(wrapped, _split(typed_key, num)) -def _key_impl(keys: KeyArray) -> str | PRNGSpec: +def _key_impl(keys: Array) -> str | PRNGSpec: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) keys_dtype = typing.cast(prng.KeyTy, keys.dtype) impl = keys_dtype._impl return impl.name if impl.name in prng.prngs else PRNGSpec(impl) -def key_impl(keys: KeyArrayLike) -> str | PRNGSpec: +def key_impl(keys: ArrayLike) -> str | PRNGSpec: typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) return _key_impl(typed_keys) -def _key_data(keys: KeyArray) -> Array: +def _key_data(keys: Array) -> Array: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) return prng.random_unwrap(keys) -def key_data(keys: KeyArrayLike) -> Array: +def key_data(keys: ArrayLike) -> Array: """Recover the bits of key data underlying a PRNG key array.""" keys, _ = _check_prng_key("key_data", keys, allow_batched=True) return _key_data(keys) @@ -345,7 +343,7 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) -def bits(key: KeyArrayLike, +def bits(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeUInt | None = None) -> Array: """Sample uniform bits in the form of unsigned integers. @@ -374,7 +372,7 @@ def bits(key: KeyArrayLike, return _random_bits(key, bit_width, shape) -def uniform(key: KeyArrayLike, +def uniform(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., @@ -444,7 +442,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: lax.reshape(floats * (maxval - minval) + minval, shape)) -def randint(key: KeyArrayLike, +def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, @@ -533,7 +531,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: return lax.add(minval, lax.convert_element_type(random_offset, dtype)) -def permutation(key: KeyArrayLike, +def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, independent: bool = False) -> Array: @@ -596,7 +594,7 @@ def _shuffle(key, x, axis) -> Array: return x -def choice(key: KeyArrayLike, +def choice(key: ArrayLike, a: int | ArrayLike, shape: Shape = (), replace: bool = True, @@ -677,7 +675,7 @@ def choice(key: KeyArrayLike, arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:]) -def normal(key: KeyArrayLike, +def normal(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample standard normal random values with given shape and float dtype. @@ -730,7 +728,7 @@ def _normal_real(key, shape, dtype) -> Array: return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u)) -def multivariate_normal(key: KeyArrayLike, +def multivariate_normal(key: ArrayLike, mean: RealArray, cov: RealArray, shape: Shape | None = None, @@ -813,7 +811,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: return result -def truncated_normal(key: KeyArrayLike, +def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, @@ -879,7 +877,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))) -def bernoulli(key: KeyArrayLike, +def bernoulli(key: ArrayLike, p: RealArray = np.float32(0.5), shape: Shape | None = None) -> Array: r"""Sample Bernoulli random values with given shape and mean. @@ -924,7 +922,7 @@ def _bernoulli(key, p, shape) -> Array: return uniform(key, shape, lax.dtype(p)) < p -def beta(key: KeyArrayLike, +def beta(key: ArrayLike, a: RealArray, b: RealArray, shape: Shape | None = None, @@ -985,7 +983,7 @@ def _beta(key, a, b, shape, dtype) -> Array: return gamma_a_scaled / (gamma_a_scaled + gamma_b_scaled) -def cauchy(key: KeyArrayLike, +def cauchy(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Cauchy random values with given shape and float dtype. @@ -1024,7 +1022,7 @@ def _cauchy(key, shape, dtype) -> Array: return lax.tan(lax.mul(pi, lax.sub(u, _lax_const(u, 0.5)))) -def dirichlet(key: KeyArrayLike, +def dirichlet(key: ArrayLike, alpha: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1096,7 +1094,7 @@ def _softmax(x, axis) -> Array: return unnormalized / unnormalized.sum(axis, keepdims=True) -def exponential(key: KeyArrayLike, +def exponential(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Exponential random values with given shape and float dtype. @@ -1135,7 +1133,7 @@ def _exponential(key, shape, dtype) -> Array: return lax.neg(lax.log1p(lax.neg(u))) -def _gamma_one(key: KeyArray, alpha, log_space) -> Array: +def _gamma_one(key: Array, alpha, log_space) -> Array: # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang # The algorithm can also be founded in: # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables @@ -1263,7 +1261,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space): multiple_results=False), platform='cpu') batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule -def gamma(key: KeyArrayLike, +def gamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1310,7 +1308,7 @@ def gamma(key: KeyArrayLike, return _gamma(key, a, shape=shape, dtype=dtype) -def loggamma(key: KeyArrayLike, +def loggamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1452,7 +1450,7 @@ def _poisson(key, lam, shape, dtype) -> Array: return lax.select(lam == 0, jnp.zeros_like(result), result) -def poisson(key: KeyArrayLike, +def poisson(key: ArrayLike, lam: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: @@ -1497,7 +1495,7 @@ def poisson(key: KeyArrayLike, return _poisson(key, lam, shape, dtype) -def gumbel(key: KeyArrayLike, +def gumbel(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: """Sample Gumbel random values with given shape and float dtype. @@ -1533,7 +1531,7 @@ def _gumbel(key, shape, dtype) -> Array: uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) -def categorical(key: KeyArrayLike, +def categorical(key: ArrayLike, logits: RealArray, axis: int = -1, shape: Shape | None = None) -> Array: @@ -1575,7 +1573,7 @@ def categorical(key: KeyArrayLike, axis=axis) -def laplace(key: KeyArrayLike, +def laplace(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Laplace random values with given shape and float dtype. @@ -1612,7 +1610,7 @@ def _laplace(key, shape, dtype) -> Array: return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u)))) -def logistic(key: KeyArrayLike, +def logistic(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample logistic random values with given shape and float dtype. @@ -1648,7 +1646,7 @@ def _logistic(key, shape, dtype): return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) -def pareto(key: KeyArrayLike, +def pareto(key: ArrayLike, b: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1697,7 +1695,7 @@ def _pareto(key, b, shape, dtype) -> Array: return lax.exp(e / b) -def t(key: KeyArrayLike, +def t(key: ArrayLike, df: RealArray, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: @@ -1749,7 +1747,7 @@ def _t(key, df, shape, dtype) -> Array: return n * jnp.sqrt(half_df / g) -def chisquare(key: KeyArrayLike, +def chisquare(key: ArrayLike, df: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1801,7 +1799,7 @@ def _chisquare(key, df, shape, dtype) -> Array: return chi2 -def f(key: KeyArrayLike, +def f(key: ArrayLike, dfnum: RealArray, dfden: RealArray, shape: Shape | None = None, @@ -1865,7 +1863,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: return f -def rademacher(key: KeyArrayLike, +def rademacher(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeInt = int) -> Array: r"""Sample from a Rademacher distribution. @@ -1900,7 +1898,7 @@ def _rademacher(key, shape, dtype) -> Array: return (2 * bernoulli_samples - 1).astype(dtype) -def maxwell(key: KeyArrayLike, +def maxwell(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample from a one sided Maxwell distribution. @@ -1940,7 +1938,7 @@ def _maxwell(key, shape, dtype) -> Array: return jnp.linalg.norm(norm_rvs, axis=-1) -def double_sided_maxwell(key: KeyArrayLike, +def double_sided_maxwell(key: ArrayLike, loc: RealArray, scale: RealArray, shape: Shape = (), @@ -1992,7 +1990,7 @@ def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array: return random_sign * maxwell_rvs * scale + loc -def weibull_min(key: KeyArrayLike, +def weibull_min(key: ArrayLike, scale: RealArray, concentration: RealArray, shape: Shape = (), @@ -2038,7 +2036,7 @@ def _weibull_min(key, scale, concentration, shape, dtype) -> Array: def orthogonal( - key: KeyArrayLike, + key: ArrayLike, n: int, shape: Shape = (), dtype: DTypeLikeFloat = float @@ -2073,7 +2071,7 @@ def orthogonal( return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2])) def generalized_normal( - key: KeyArrayLike, + key: ArrayLike, p: float, shape: Shape = (), dtype: DTypeLikeFloat = float @@ -2108,7 +2106,7 @@ def generalized_normal( return r * g ** (1 / p) def ball( - key: KeyArrayLike, + key: ArrayLike, d: int, p: float = 2, shape: Shape = (), @@ -2140,7 +2138,7 @@ def ball( return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None] -def rayleigh(key: KeyArrayLike, +def rayleigh(key: ArrayLike, scale: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2193,7 +2191,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array: ray = lax.mul(scale, sqrt_u) return ray -def wald(key: KeyArrayLike, +def wald(key: ArrayLike, mean: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2251,7 +2249,7 @@ def _wald(key, mean, shape, dtype) -> Array: w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x) return w -def geometric(key: KeyArrayLike, +def geometric(key: ArrayLike, p: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: @@ -2304,7 +2302,7 @@ def _geometric(key, p, shape, dtype) -> Array: return g.astype(dtype) -def triangular(key: KeyArrayLike, +def triangular(key: ArrayLike, left: RealArray, mode: RealArray, right: RealArray, @@ -2368,7 +2366,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: return tri -def lognormal(key: KeyArrayLike, +def lognormal(key: ArrayLike, sigma: RealArray = np.float32(1), shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2573,7 +2571,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array: def binomial( - key: KeyArray, + key: Array, n: RealArray, p: RealArray, shape: Shape | None = None, From 2c9b917b9d01149f4b6b5db1523fa742af413ced Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 20 Nov 2024 10:35:16 -0800 Subject: [PATCH 079/112] Don't psum over auto mesh dims in _unmentioned2. PiperOrigin-RevId: 698440525 --- jax/experimental/shard_map.py | 9 +++++---- tests/shard_map_test.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 4ad248c17ee2..c2673b55dd9a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1547,10 +1547,11 @@ def fun(*res_and_args): return jaxpr -def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: +def _unmentioned2(mesh: Mesh, names: AxisNames, + auto: frozenset[AxisName]) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} + name_set = {n for ns in names.values() for n in ns} | auto return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set] @@ -1559,7 +1560,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero else x if rewrite or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns)))) + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) @@ -1577,7 +1578,7 @@ def fun_trans(out_cts, args): ) out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero else x if rewrite - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns))) + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) for ns, x in zip(in_names, out)] return out diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 84017bab5122..2a343f7ba784 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2046,6 +2046,29 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_grad_nested_partial_auto(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) + + def g(x): + return x * x + + def h(x): + return shard_map(g, mesh, + in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) + + @jax.jit + def f(x): + return shard_map(h, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x).sum() + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False) + def test_axis_size_1_partial_auto(self): mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) From 9584ee3bb9c3a48299635a7c0a11df1029cf9f59 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 10:41:24 -0800 Subject: [PATCH 080/112] [pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test Turns out we can mix parallel grid with `plgpu.emit_pipeline` without doing indexing at all! PiperOrigin-RevId: 698442820 --- tests/pallas/mosaic_gpu_test.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 48c047697911..a4bbc67ee14f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1218,33 +1218,33 @@ def kernel_body(x_smem, o_smem): np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16]) def test_emit_with_parallel_grid(self): - self.skipTest("Enable once we support multiple levels of indexing") - - num_steps = 4 + num_steps1 = 4 + num_steps2 = 5 def kernel(x_gmem, o_gmem): - gmem_slice = pl.ds(pl.program_id(0) * 32, 32) + pid = pl.program_id(0) plgpu.emit_pipeline( kernel_body, - in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], - out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], - grid=(num_steps,), + in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + grid=(num_steps2,), max_concurrent_steps=2, - )(x_gmem.at[gmem_slice], o_gmem.at[gmem_slice]) + )(x_gmem, o_gmem) def kernel_body(x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 - x = jnp.arange(4 * 32 * num_steps * 16) - x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) + x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) kernel_fn = pl.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - grid=(4, 1), + grid=(num_steps1,), ) - np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + y = x + 1.0 + np.testing.assert_array_equal(kernel_fn(x), y) def test_emit_with_2d_grid(self): num_steps1 = 4 From 621e39de27098a941094fa332cf03f42018f3b91 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 10:47:23 -0800 Subject: [PATCH 081/112] Set __module__ attribute of jax.numpy.linalg APIs --- jax/_src/numpy/linalg.py | 36 ++++++++++++++++++++++++++++++++- tests/package_structure_test.py | 1 + 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 76a4abff48ad..be6828c36e6a 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -35,10 +35,13 @@ from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions, ufuncs from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg +export = set_module('jax.numpy.linalg') + + class EighResult(NamedTuple): eigenvalues: jax.Array eigenvectors: jax.Array @@ -67,6 +70,7 @@ def _H(x: ArrayLike) -> Array: def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 +@export @partial(jit, static_argnames=['upper']) def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """Compute the Cholesky decomposition of a matrix. @@ -191,6 +195,7 @@ def svd( ... +@export @partial( jit, static_argnames=( @@ -311,6 +316,7 @@ def svd( ) +@export @partial(jit, static_argnames=('n',)) def matrix_power(a: ArrayLike, n: int) -> Array: """Raise a square matrix to an integer power. @@ -392,6 +398,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: return result +@export @jit def matrix_rank( M: ArrayLike, rtol: ArrayLike | None = None, *, @@ -496,6 +503,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: return sign_diag * sign_taus, log_abs_det +@export @partial(jit, static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ @@ -675,6 +683,7 @@ def _det_jvp(primals, tangents): return y, jnp.trace(z, axis1=-1, axis2=-2) +@export @jit def det(a: ArrayLike) -> Array: """ @@ -711,6 +720,7 @@ def det(a: ArrayLike) -> Array: raise ValueError(msg.format(a_shape)) +@export def eig(a: ArrayLike) -> tuple[Array, Array]: """ Compute the eigenvalues and eigenvectors of a square array. @@ -756,6 +766,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: return w, v +@export @jit def eigvals(a: ArrayLike) -> Array: """ @@ -793,6 +804,7 @@ def eigvals(a: ArrayLike) -> Array: compute_right_eigenvectors=False)[0] +@export @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: @@ -848,6 +860,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None, return EighResult(w, v) +@export @partial(jit, static_argnames=('UPLO',)) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ @@ -884,6 +897,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: # TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires. +@export def pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False, *, rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: @@ -997,6 +1011,7 @@ def _pinv_jvp(rtol, hermitian, primals, tangents): return p, p_dot +@export @jit def inv(a: ArrayLike) -> Array: """Return the inverse of a square matrix @@ -1057,6 +1072,7 @@ def inv(a: ArrayLike) -> Array: arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2])) +@export @partial(jit, static_argnames=('ord', 'axis', 'keepdims')) def norm(x: ArrayLike, ord: int | str | None = None, axis: None | tuple[int, ...] | int = None, @@ -1222,6 +1238,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ... +@export @partial(jit, static_argnames=('mode',)) def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: """Compute the QR decomposition of an array @@ -1305,6 +1322,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: return QRResult(q, r) +@export @jit def solve(a: ArrayLike, b: ArrayLike) -> Array: """Solve a linear system of equations @@ -1408,6 +1426,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, _jit_lstsq = jit(partial(_lstsq, numpy_resid=False)) +@export def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: """ @@ -1448,6 +1467,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, return _jit_lstsq(a, b, rcond) +@export def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): r"""Compute the cross-product of two 3D vectors @@ -1493,6 +1513,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): return jnp.cross(x1, x2, axis=axis) +@export def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute the outer product of two 1-dimensional arrays. @@ -1523,6 +1544,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: return x1[:, None] * x2[None, :] +@export def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array: """Compute the norm of a matrix or stack of matrices. @@ -1553,6 +1575,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1)) +@export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose a matrix or stack of matrices. @@ -1608,6 +1631,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) +@export def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: """Compute the vector norm of a vector or batch of vectors. @@ -1652,6 +1676,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa return norm(x, axis=axis, keepdims=keepdims, ord=ord) +@export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -1702,6 +1727,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, preferred_element_type=preferred_element_type) +@export def matmul(x1: ArrayLike, x2: ArrayLike, /, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -1762,6 +1788,7 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /, *, preferred_element_type=preferred_element_type) +@export def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, precision: PrecisionLike = None, @@ -1843,6 +1870,7 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, preferred_element_type=preferred_element_type) +@export def svdvals(x: ArrayLike, /) -> Array: """Compute the singular values of a matrix. @@ -1867,6 +1895,7 @@ def svdvals(x: ArrayLike, /) -> Array: return svd(x, compute_uv=False, hermitian=False) +@export def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: """Extract the diagonal of an matrix or stack of matrices. @@ -1907,6 +1936,7 @@ def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1) +@export def tensorinv(a: ArrayLike, ind: int = 2) -> Array: """Compute the tensor inverse of an array. @@ -1949,6 +1979,7 @@ def tensorinv(a: ArrayLike, ind: int = 2) -> Array: return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape) +@export def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array: """Solve the tensor equation a x = b for x. @@ -1998,6 +2029,7 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) return solve(a_arr, b_arr.ravel()).reshape(out_shape) +@export def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array: """Efficiently compute matrix products between a sequence of arrays. @@ -2090,6 +2122,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) - optimize='optimal', precision=precision) +@export @partial(jit, static_argnames=['p']) def cond(x: ArrayLike, p=None): """Compute the condition number of a matrix. @@ -2149,6 +2182,7 @@ def cond(x: ArrayLike, p=None): return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r) +@export def trace(x: ArrayLike, /, *, offset: int = 0, dtype: DTypeLike | None = None) -> Array: """Compute the trace of a matrix. diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 9bc8d0f6d71c..25468c4ba700 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -40,6 +40,7 @@ class PackageStructureTest(jtu.JaxTestCase): "number", "object_", "printoptions", "save", "savez", "set_printoptions", "shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"] ), + _mod("jax.numpy.linalg"), _mod("jax.nn.initializers"), _mod( "jax.tree_util", From dfe27a16825663ea3a90417ad452e99dc43d7f53 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 20 Nov 2024 14:53:52 -0500 Subject: [PATCH 082/112] Mention stackless in the release notes. --- CHANGELOG.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37fd68bce39d..be9aaebcd615 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.36 * Breaking Changes + * This release lands "stackless", an internal change to JAX's tracing + machinery. We made trace dispatch purely a function of context rather than a + function of both context and data. This let us delete a lot of machinery for + managing data-dependent tracing: levels, sublevels, `post_process_call`, + `new_base_main`, `custom_bind`, and so on. The change should only affect + users that use JAX internals. + + If you do use JAX internals then you may need to + update your code (see + https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f + for clues about how to do this). There might also be version skew + issues with JAX libraries that do this. If you find this change breaks your + non-JAX-internals-using code then try the + `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if + you need help updating your code then please file a bug. * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` or with `enable_xla=False` have been deprecated since July 2024, with JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` From 40fc6598f96999271a3c19cfaab6f02579c003d6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 13:06:39 -0800 Subject: [PATCH 083/112] [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs. Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too. Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`. PiperOrigin-RevId: 698493184 --- jax/BUILD | 1 + jax/_src/config.py | 5 ++++- jax/_src/core.py | 30 +++++++++++++++++++++------- jax/_src/mesh.py | 25 ++++++++++++++++++++--- jax/_src/pallas/core.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pjit.py | 32 ++++++++++++++++++++++++++---- jax/_src/state/primitives.py | 5 ++++- jax/experimental/shard_map.py | 5 ++++- 9 files changed, 88 insertions(+), 19 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 26694fec2ad3..64bfa627f42e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -451,6 +451,7 @@ pytype_strict_library( ":deprecations", ":dtypes", ":effects", + ":mesh", ":pretty_printer", ":source_info_util", ":traceback_util", diff --git a/jax/_src/config.py b/jax/_src/config.py index eff9b757b95b..2723b4f90d3b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -209,7 +209,9 @@ def trace_context(): Values included in this set should also most likely be included in the C++ JIT state, which is handled separately. """ - return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value, + return (axis_env_state.value, mesh_context_manager.value, + xla_metadata_context_manager.value, + abstract_mesh_context_manager.value, compute_on_context_manager.value, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, @@ -969,6 +971,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: trace_state = config_ext.Config(None, include_in_jit_key=True) axis_env_state = config_ext.Config((), include_in_jit_key=True) mesh_context_manager = config_ext.Config((), include_in_jit_key=True) + abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True) compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) else: diff --git a/jax/_src/core.py b/jax/_src/core.py index cbf3282fb2cc..1bd3fb4fa889 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -38,6 +38,7 @@ from jax._src import config from jax._src import effects from jax._src import compute_on +from jax._src import mesh as mesh_lib from jax._src.errors import ( ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) @@ -1596,6 +1597,23 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) + +def get_sharding(sharding, ndim): + from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore + + if sharding is not None: + assert len(sharding.spec) == ndim + return sharding + + context_mesh = mesh_lib.mesh_context.mesh + # TODO(yashkatariya): Error out and ask users to set the context mesh in their + # code. + if context_mesh is None: + return None + assert sharding is None + return NamedSharding(context_mesh, P(*[None] * ndim)) + + class ShapedArray(UnshapedArray): __slots__ = ['shape', 'sharding'] # inherits slots from parent array_abstraction_level = 2 @@ -1605,20 +1623,18 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None): self.dtype = _dtype_object(dtype) self.weak_type = weak_type if config.sharding_in_types.value: - if sharding is not None: - assert len(sharding.spec) == len(self.shape) - self.sharding = sharding + self.sharding = get_sharding(sharding, len(self.shape)) - def update(self, shape=None, dtype=None, weak_type=None, sharding=None): + def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: shape = self.shape if dtype is None: dtype = self.dtype if weak_type is None: weak_type = self.weak_type - if sharding is None: - sharding = getattr(self, 'sharding', None) - return ShapedArray(shape, dtype, weak_type, sharding=sharding) + if 'sharding' not in kwargs: + kwargs['sharding'] = getattr(self, 'sharding', None) + return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) size = property(lambda self: diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index a2ab261fa0e9..3d0e1b0cccf5 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -107,6 +107,9 @@ class AxisTypes(enum.Enum): User = enum.auto() Collective = enum.auto() + def __repr__(self): + return self.name + def axis_names_to_types(axis_types) -> dict[str, AxisTypes]: if axis_types is None: return {} @@ -452,14 +455,22 @@ def local_mesh(self): _raise_value_error("local_mesh") def __enter__(self): - raise RuntimeError("AbstractMesh is not a context manager") + mesh_context.stack.append(self) + mesh_context.mesh = self + jax_config.abstract_mesh_context_manager.set_local( + tuple(m for m in mesh_context.stack if m is not None)) + return self def __exit__(self, exc_type, exc_value, traceback): - raise RuntimeError("AbstractMesh is not a context manager") + mesh_context.stack.pop() + mesh_context.mesh = mesh_context.stack[-1] + jax_config.abstract_mesh_context_manager.set_local( + tuple(m for m in mesh_context.stack if m is not None)) + return False @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): - jax_config.mesh_context_manager.set_local(mesh) + jax_config.abstract_mesh_context_manager.set_local(mesh) return @@ -467,3 +478,11 @@ def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): # property raises an exception unconditionally. Remove this once that is fixed. def _raise_value_error(name): raise ValueError(f"AbstractMesh does not implement {name}") + + +class MeshContext(threading.local): + def __init__(self): + self.stack = [None] + self.mesh = self.stack[-1] + +mesh_context = MeshContext() diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 72ed07674f1f..cf1e0b524963 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -873,7 +873,7 @@ def get_grid_mapping( ) # The inputs for the index maps index_map_avals = ( - (index_map_grid_aval,) * len(grid_spec.grid)) + (index_map_grid_aval.update(sharding=None),) * len(grid_spec.grid)) index_map_tree = tree_util.tree_structure((index_map_avals, {})) num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index be4102dff716..c9e20843a806 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1380,7 +1380,7 @@ def _masked_swap_lowering_rule( 1 if b is pallas_core.mapped else next(mem_slice_shape_iter) for b in ref_block_shape ] - mem_aval = aval_out.update(shape=tuple(mem_slice_shape)) + mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None) mem_aval_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)) if need_stride: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f1844c7ba13b..aff956862753 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -16,6 +16,7 @@ from collections import defaultdict from collections.abc import Callable, Sequence, Iterable +import contextlib import dataclasses from functools import partial import inspect @@ -637,10 +638,13 @@ def _infer_params_impl( in_avals, in_tree, dbg, device_or_backend_set, have_kwargs) attr_token = _attr_token(flat_fun, in_type) - jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( - flat_fun, in_type, attr_token, dbg, - HashableFunction(res_paths, closure=()), - IgnoreKey(ji.inline)) + + abstract_mesh = get_abstract_mesh(in_type) + with abstract_mesh: + jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( + flat_fun, in_type, attr_token, dbg, + HashableFunction(res_paths, closure=()), + IgnoreKey(ji.inline)) _attr_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( @@ -683,6 +687,26 @@ def _infer_params_impl( attrs_tracked), args_flat +def get_abstract_mesh(in_avals): + if not config.sharding_in_types.value: + return contextlib.nullcontext() + m = None + for a in in_avals: + # TODO(yashkatariya): Remove this when mesh context can be set by the user. + if a.sharding is None: # type: ignore + continue + if m is not None and m != a.sharding.mesh: + raise ValueError( + f'Mesh for all inputs should be equal. Got one mesh: {m} and' + f' another mesh: {a.sharding.mesh}') + m = a.sharding.mesh # type: ignore + # TODO(yashkatariya): Remove this when mesh context can be set by the user. + if m is None: + return contextlib.nullcontext() + assert m is not None + return m + + class InferParamsCacheEntry: """Mutable value object for _infer_params_cached.""" __slots__ = ['pjit_params'] diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 0897e778d079..14d42ad0809c 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -214,7 +214,10 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) - out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype) + # TODO(yashkatariya): Transform the sharding too instead of setting it to + # None. + out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype, + sharding=None) else: if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c2673b55dd9a..07f631f6ec49 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -483,7 +483,8 @@ def _shard_map_staging( in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) - with core.extend_axis_env_nd(list(mesh.shape.items())): + with (core.extend_axis_env_nd(list(mesh.shape.items())), + pjit.get_abstract_mesh(in_avals_)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: @@ -547,6 +548,8 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) + # TODO(yashkatariya): Reset the mesh properly based on the input avals if the + # mesh of shard_map specifies collective axes. if config.sharding_in_types.value: spec = _names_to_pspec(names)._normalized_spec(aval.ndim) new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec) From 9d2f62f811e23d4c9b2c33d923fff70ed78a4acf Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 20 Nov 2024 14:03:12 -0800 Subject: [PATCH 084/112] [Pallas TPU] Support masked store PiperOrigin-RevId: 698514079 --- jax/_src/pallas/mosaic/lowering.py | 23 ++++++++++++++++--- tests/pallas/tpu_pallas_test.py | 36 ++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index c9e20843a806..1f0062cad0f9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -42,6 +42,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -1315,12 +1316,20 @@ def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): ref, transforms, val, mask = args_tree.unflatten(args_flat) - ref_aval, transforms_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in) + ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten( + ctx.avals_in + ) (*prev_transforms, idx) = transforms (*_, idx_aval) = transforms_avals if mask is not None: - raise NotImplementedError + if val_aval.dtype.itemsize != 4: + raise NotImplementedError("masked swap with non-32-bit data") + if val_aval.shape != mask_aval.shape: + raise ValueError( + "Expected value and mask to have the same shape, but got" + f" value shape {val_aval.shape} vs. mask shape {mask_aval.shape}." + ) ref_block_shape, *_ = ctx.block_shapes ref, ref_block_shape = _transform_ref( @@ -1351,6 +1360,8 @@ def _masked_swap_lowering_rule( need_stride = not all((s is None or s == 1) for s in strides) if is_smem_store: + if mask is not None: + raise ValueError("SMEM store does not support masks") if val_aval.shape: raise ValueError("Can only store scalars to SMEM") result = memref.load(ref, starts) @@ -1399,9 +1410,15 @@ def _masked_swap_lowering_rule( result = _maybe_cast_load_to_bool(val_aval, result) if need_stride: + if mask is not None: + raise NotImplementedError("masked swap with strided store") tpu.StridedStoreOp(val, ref, starts, strides) - else: + elif jaxlib_version <= (0, 4, 35): + if mask is not None: + raise NotImplementedError("masked swap with vector store") vector.StoreOp(val, ref, starts) + else: + tpu.VectorStoreOp(val, ref, starts, [], mask=mask) return result diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 347a06c50323..9c4788d7447f 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1723,6 +1723,42 @@ def test(x: jax.Array) -> jax.Array: y = test(x) np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1)) + def test_masked_store(self): + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("Test requires masked store support") + shape = (16, 256) + mask_shape = (10, 130) + mask_start = (4, 5) + dtype = jnp.float32 + def body(scalar_ref, x_ref, o_ref): + o_ref[...] = jnp.full(shape, -1, dtype=dtype) + b0, b1 = scalar_ref[0], scalar_ref[1] + e0, e1 = b0 + mask_shape[0], b1 + mask_shape[1] + iota0 = lax.broadcasted_iota(jnp.int32, shape, 0) + iota1 = lax.broadcasted_iota(jnp.int32, shape, 1) + mask0 = jnp.logical_and(b0 <= iota0, iota0 < e0) + mask1 = jnp.logical_and(b1 <= iota1, iota1 < e1) + pl.store( + o_ref, + (slice(None), slice(None)), + x_ref[...], + mask=jnp.logical_and(mask0, mask1), + ) + + s = jnp.array(mask_start, jnp.int32) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + out = pl.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + ), + )(s, x) + slices = tuple(slice(b, b + l) for b, l in zip(mask_start, mask_shape)) + expected = jnp.full(shape, -1, dtype=dtype) + expected = expected.at[slices].set(x[slices]) + np.testing.assert_array_equal(out, expected) + class PallasUXTest(PallasBaseTest): From 9b941808463ee614a541c5174ea98cb5c58a080b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 14:29:59 -0800 Subject: [PATCH 085/112] [sharding_in_types] Add slice_p and squeeze_p sharding rule to make flash attention work in backward pass For `slice_p`'s sharding rule, I error out if the operand dim is sharded and the output dim is not divisible by that axis size. I am working on a design to make JAX support uneven sharding at the top level after which slice_p's sharding rule can just `return operand.sharding`. Another option is to add `out_sharding` to `slice` but after uneven sharding support lands, it won't be necessary. PiperOrigin-RevId: 698522980 --- jax/_src/lax/lax.py | 14 ++++++++++++-- jax/_src/lax/slicing.py | 34 +++++++++++++++++++++++++++++++--- jax/_src/pallas/core.py | 4 ++++ tests/pjit_test.py | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 39c5bca5819c..79e48c440271 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4527,6 +4527,12 @@ def _squeeze_dtype_rule(operand, *, dimensions): def _squeeze_shape_rule(operand, *, dimensions): return _compute_squeeze_shape(np.shape(operand), dimensions) +def _squeeze_sharding_rule(operand, *, dimensions): + dims_set = set(dimensions) + new_spec = tuple(s for i, s in enumerate(operand.sharding.spec) + if i not in dims_set) + return NamedSharding(operand.sharding.mesh, P(*new_spec)) + def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) if len(dims_set) != len(dimensions): @@ -4555,7 +4561,7 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze') + 'squeeze', sharding_rule=_squeeze_sharding_rule) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -4563,7 +4569,11 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. - return [mlir.reshape(ctx, operand, ctx.avals_out[0])] + aval_out, = ctx.avals_out + out = mlir.reshape(ctx, operand, aval_out) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(squeeze_p, _squeeze_lower) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 40a04ff11d2c..117c8b655152 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -42,6 +42,7 @@ _input_dtype, standard_primitive, ) +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike, Shape @@ -1270,6 +1271,29 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): return tuple(core.stride_dim(d, window_size=1, window_stride=s) for d, s in zip(diff, strides)) +def _get_sub_spec_size(mesh, sub_spec): + if isinstance(sub_spec, tuple): + return math.prod(mesh.shape[s] for s in sub_spec) + return mesh.shape[sub_spec] + +def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _slice_shape_rule(operand, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + mesh = operand.sharding.mesh + new_spec = [] + for op_sh, out_sh, op_spec in safe_zip( + operand.shape, out_shape, operand.sharding.spec): + if (op_sh != out_sh and op_spec is not None and + out_sh % _get_sub_spec_size(mesh, op_spec) != 0): + raise NotImplementedError( + f"slicing on sharded dims where out dim ({out_sh}) is not divisble by" + f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" + f" ({op_spec}) is not implemented.") + new_spec.append(op_spec) + return NamedSharding(mesh, P(*new_spec)) + def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert ad.is_undefined_primal(operand) operand_shape = operand.aval.shape @@ -1308,7 +1332,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, out = slice(operand, new_start_indices, new_limit_indices, new_strides) return out, bdim -slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice') +slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', + sharding_rule=_slice_sharding_rule) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries @@ -1333,8 +1358,11 @@ def _slice_impl(x, start_indices, limit_indices, strides): def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): strides = strides or [1] * len(start_indices) aval_out, = ctx.avals_out - return [mlir.slice_op(ctx, x, aval_out, - start_indices=start_indices, limit_indices=limit_indices, strides=strides)] + out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(slice_p, _slice_lower) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index cf1e0b524963..acbf0d4f7ed5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -219,6 +219,10 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' + @property + def sharding(self): + return self.inner_aval.sharding + def update_weak_type(self, weak_type): return AbstractMemoryRef( self.inner_aval.update_weak_type(weak_type), self.memory_space) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e32424cfdded..e52c805ef5e6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5285,6 +5285,43 @@ def f(x, y): self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_slice(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(4, 4) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) + + @jax.jit + def f(x): + y = lax.slice(x, (0, 0), (4, 3)) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertIn('@Sharding', f.lower(arr).as_text()) + + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) + + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) + + def test_squeeze(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(4, 4, 1) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) + + @jax.jit + def f(x): + y = lax.squeeze(x, (2,)) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertIn('@Sharding', f.lower(arr).as_text()) + self.assertArraysEqual(out, np.squeeze(np_inp, axis=2)) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 6fe78042b51d066e1b886dcf8f77df627831c00e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 20 Nov 2024 14:37:36 -0800 Subject: [PATCH 086/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e763f8875b0a9bfca876be9b02c874979e55422a. PiperOrigin-RevId: 698525361 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a554cfd03687..327e4ca422ac 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fcee07f619a765db815d9ed4e2bc229275818a2b" -XLA_SHA256 = "1dd144e64e2c2dcc20a2130e10607fec7b3a810926ba912918dd5437698a3375" +XLA_COMMIT = "e763f8875b0a9bfca876be9b02c874979e55422a" +XLA_SHA256 = "7b6a33894c6510167cac6e0ab7a6331ffa84e7fcaaa1d3b1c462ec5ecacb0682" def repo(): tf_http_archive( From f749fca760cbcd019bed8f9e1a64cf525c0bcb14 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 14:50:06 -0800 Subject: [PATCH 087/112] [array api] use most recent version of array_api_tests --- .github/workflows/jax-array-api.yml | 2 +- pyproject.toml | 1 - tests/array_api_skips.txt | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 763a4c04be5d..8f2029eb9191 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'a3f3f376308e64f0ac15b307dfe27be945409e41' # Latest commit as of 2024-11-14 + ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index 73e1c51fc8af..d688f7fbbf01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ filterwarnings = [ # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", - "default:The .* method is good for exploring strategies.*", # NOTE: this is probably not where you want to add code to suppress a # warning. Only pytest tests look at this list, whereas Bazel tests also diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2ac2edcdfd99..e1d4c35eae68 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -6,6 +6,7 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] # Test suite attempts in-place mutation: array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking +array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] From 2699e9507e462047fb853c32768812467ec1c13c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 15:13:14 -0800 Subject: [PATCH 088/112] DOC: add examples for jax.lax.pad --- jax/_src/lax/lax.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 79e48c440271..1bf1ea816ca7 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1306,6 +1306,36 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, Returns: The ``operand`` array with padding value ``padding_value`` inserted in each dimension according to the ``padding_config``. + + Examples: + >>> from jax import lax + >>> import jax.numpy as jnp + + Pad a 1-dimensional array with zeros, We'll specify two zeros in front and + three at the end: + + >>> x = jnp.array([1, 2, 3, 4]) + >>> lax.pad(x, 0, [(2, 3, 0)]) + Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32) + + Pad a 1-dimensional array with *interior* zeros; i.e. insert a single zero + between each value: + + >>> lax.pad(x, 0, [(0, 0, 1)]) + Array([1, 0, 2, 0, 3, 0, 4], dtype=int32) + + Pad a 2-dimensional array with the value ``-1`` at front and end, with a pad + size of 2 in each dimension: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) + Array([[-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, 1, 2, 3, -1, -1], + [-1, -1, 4, 5, 6, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) From 17825882d2bb87b84387963bfaf53ce191cbf71b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 16:21:45 -0800 Subject: [PATCH 089/112] jax.lax.pad: improve input validation --- jax/_src/lax/lax.py | 3 ++- tests/lax_test.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 79e48c440271..934f100ffe34 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4441,7 +4441,8 @@ def _pad_dtype_rule(operand, padding_value, *, padding_config): return _input_dtype(operand, padding_value) def _pad_shape_rule(operand, padding_value, *, padding_config): - del padding_value + if np.ndim(padding_value) != 0: + raise ValueError(f"padding_value must be a scalar; got {np.shape(padding_value)=}") op_shape = np.shape(operand) if not len(padding_config) == np.ndim(operand): raise ValueError("length of padding_config must equal the number of axes " diff --git a/tests/lax_test.py b/tests/lax_test.py index 78bc5857acb7..10fa8c006184 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1597,6 +1597,8 @@ def testPadAgainstNumpy(self, shape, dtype, pads): self._CheckAgainstNumpy(numpy_op, op, args_maker) def testPadErrors(self): + with self.assertRaisesRegex(ValueError, "padding_value must be a scalar"): + lax.pad(np.zeros(2), np.zeros(2), [(0, 0, 0)]) with self.assertRaisesRegex(ValueError, "padding_config"): lax.pad(np.zeros(2), 0., [(0, 1, 0), (0, 1, 0)]) with self.assertRaisesRegex(ValueError, "interior padding in padding_config must be nonnegative"): From bf7f9aa8f27da525bc0a1e42a3b6e15c1f93b2f4 Mon Sep 17 00:00:00 2001 From: barnesjoseph Date: Wed, 20 Nov 2024 16:38:58 -0800 Subject: [PATCH 090/112] Adds Google Sans font --- docs/_static/style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/_static/style.css b/docs/_static/style.css index d801c2a412a6..36b54b8432f0 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1,4 +1,5 @@ @import url("theme.css"); +@import url('https://fonts.googleapis.com/css2?family=Google+Sans'); /* Base LP sidebar modifications */ body:has(.hero) .sidebar-toggle, From 1f6152d11e28bc94ab86906ae07d43967f8c759e Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 20 Nov 2024 17:12:01 -0800 Subject: [PATCH 091/112] [Pallas] Use Pallas cost estimator for flash attention. PiperOrigin-RevId: 698573265 --- jax/_src/pallas/cost_estimate.py | 40 ++++++++++++++++--- jax/experimental/pallas/__init__.py | 1 + .../pallas/ops/tpu/flash_attention.py | 27 +++++++------ tests/pallas/pallas_cost_estimate_test.py | 10 ++--- 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index 1bcf704b3579..b83c36159555 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -16,9 +16,12 @@ import math from typing import Any, Sequence +import jax from jax._src import core as jax_core -from jax._src.pallas import core as pallas_core +from jax._src import custom_derivatives from jax._src import linear_util as lu +from jax._src import pjit +from jax._src.pallas import core as pallas_core from jax._src.interpreters import partial_eval as pe from jax._src.util import safe_map from jax._src.util import safe_zip @@ -71,22 +74,28 @@ def cost_estimate_jaxpr( bytes_accessed=total_cost.bytes_accessed, ) -def cost_estimate(fun, *args) -> pallas_core.CostEstimate: +def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate: """Computes a cost estimate for the given function. Args: fun: The function to compute the cost estimate for. *args: The arguments to the function. Can be jax.ShapeDtypeStruct or jax.Array. + **kwargs: The keyword arguments to the function. Returns: A pallas_core.CostEstimate object containing the cost estimate. """ - wrapped_fun = lu.wrap_init(lambda *args, **kwargs: (fun(*args, **kwargs),)) - avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in args] + flattened_args, treedef = jax.tree.flatten(args) + def _partial_fun(*flat_args): + return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs) + wrapped_fun = lu.wrap_init( + lambda *args, **kwargs: (_partial_fun(*args, **kwargs),)) + avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) - input_bytes = sum(math.prod(a.shape) * a.dtype.itemsize for a in args) + input_bytes = sum( + math.prod(a.shape) * a.dtype.itemsize for a in flattened_args) output_bytes = sum( math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars) return pallas_core.CostEstimate( @@ -213,3 +222,24 @@ def dot_general_cost_rule(ctx: Context, bytes_accessed=0, ) register_cost_rule(lax.dot_general_p, dot_general_cost_rule) + +# Higher-order primitives +def _pjit_cost_rule(ctx, *, jaxpr: jax_core.ClosedJaxpr, **_): + del ctx + inner_cost = cost_estimate_jaxpr(jaxpr) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(pjit.pjit_p, _pjit_cost_rule) + +def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): + del ctx + inner_cost = cost_estimate_jaxpr(fun_jaxpr) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 34cb5328f36a..7e6527ad999a 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -30,6 +30,7 @@ from jax._src.pallas.core import no_block_spec as no_block_spec from jax._src.pallas.core import Unblocked as Unblocked from jax._src.pallas.core import unblocked as unblocked +from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost from jax._src.pallas.pallas_call import pallas_call as pallas_call from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p from jax._src.pallas.primitives import atomic_add as atomic_add diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 9b122fcc03ef..0cb3d798d09e 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -574,26 +574,23 @@ def _fwd_cost_estimate( q: jax.Array, k: jax.Array, v: jax.Array, + ab: jax.Array | None, + segment_ids: SegmentIds | None, *, + causal: bool, + sm_scale: jax.Array | None, kernel_inputs_specs, kernel_outputs_specs, ) -> pl.CostEstimate | None: - b, h, tq, dqk = q.shape - tk = k.shape[-2] - dv = v.shape[-1] - - # Simplify flop computation to include only matmul operations. - qk_flops = 2 * tq * tk * dqk - av_flops = 2 * tq * tk * dv - per_head_flops = qk_flops + av_flops - flops = b * h * per_head_flops - - transcendentals = b * tq * tk * h + body_cost = pl.estimate_cost( + mha_reference, + q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale + ) input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs)) output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs)) return pl.CostEstimate( - flops=flops, - transcendentals=transcendentals, + flops=body_cost.flops, + transcendentals=body_cost.transcendentals, bytes_accessed=input_bytes + output_bytes, ) @@ -790,6 +787,10 @@ def kv_segment_ids_index_map( q, k, v, + ab, + segment_ids, + causal=causal, + sm_scale=sm_scale, kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids), kernel_outputs_specs=out_shape, ), diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py index 74dd150fbc10..fcdeac4cab82 100644 --- a/tests/pallas/pallas_cost_estimate_test.py +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -29,7 +29,7 @@ class PallasCostEstimateTest(jtu.JaxTestCase): def test_exp_add(self): def exp_add(x, y): return jnp.exp(x + y) - cost = cost_estimate.cost_estimate(exp_add, + cost = cost_estimate.estimate_cost(exp_add, jnp.ones(10, dtype=jnp.float32), jnp.ones(10, dtype=jnp.float32)) self.assertEqual(cost.flops, 10) @@ -40,7 +40,7 @@ def test_very_large_matmul(self): def matmul(a, b): return a @ b m, k, n = 400_000, 800_000, 900_000 - cost = cost_estimate.cost_estimate( + cost = cost_estimate.estimate_cost( matmul, jax.ShapeDtypeStruct((m, k), jnp.bfloat16), jax.ShapeDtypeStruct((k, n), jnp.bfloat16)) @@ -52,7 +52,7 @@ def test_batched_matmul(self): def matmul(a, b): return jnp.matmul(a, b) b, m, k, n = 7, 37, 91, 23 - cost = cost_estimate.cost_estimate( + cost = cost_estimate.estimate_cost( matmul, jax.ShapeDtypeStruct((b, m, k), jnp.float32), jax.ShapeDtypeStruct((b, k, n), jnp.float32)) @@ -67,7 +67,7 @@ def test_attention(self): q_len = 64 def attention(q, k, v): return jax.nn.softmax(q @ k.T, axis=-1) @ v - cost = cost_estimate.cost_estimate( + cost = cost_estimate.estimate_cost( attention, jnp.zeros((q_len, qk_dim), dtype=jnp.float32), jnp.zeros((kv_len, qk_dim), dtype=jnp.float32), @@ -85,7 +85,7 @@ def attention(q, k, v): (1, 0), (7, 5), (8, 4), (9, 5) ) def test_integer_pow(self, power, expected_flops_per_element): - cost = cost_estimate.cost_estimate(lambda x: lax.integer_pow(x, power), + cost = cost_estimate.estimate_cost(lambda x: lax.integer_pow(x, power), jnp.ones(10, dtype=jnp.float32)) self.assertEqual(cost.flops, 10 * expected_flops_per_element) self.assertEqual(cost.transcendentals, 0) From 840cf3f7d20ce06861a5c48c684ac9a61009d856 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 17:12:29 -0800 Subject: [PATCH 092/112] [sharding_in_types] Add `pad_p` support to sharding_in_types to handle transpose to slice correctly. PiperOrigin-RevId: 698573396 --- jax/_src/lax/lax.py | 27 +++++++++++++++++-- tests/pjit_test.py | 66 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4a9925ce33a4..9a27460906ab 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4490,6 +4490,25 @@ def _pad_shape_rule(operand, padding_value, *, padding_config): raise ValueError(msg) return result +def _pad_sharding_rule(operand, padding_value, *, padding_config): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _pad_shape_rule(operand, padding_value, + padding_config=padding_config) + mesh = operand.sharding.mesh + new_spec = [] + for op_sh, out_sh, op_spec in safe_zip( + operand.shape, out_shape, operand.sharding.spec): + if (op_sh != out_sh and op_spec is not None and + out_sh % slicing._get_sub_spec_size(mesh, op_spec) != 0): + raise NotImplementedError( + f"padding on sharded dims where out dim ({out_sh}) is not divisble by" + f" mesh axes ({slicing._get_sub_spec_size(mesh, op_spec)}) with spec" + f" ({op_spec}) is not implemented.") + new_spec.append(op_spec) + return NamedSharding(mesh, P(*new_spec)) + + def _pad_transpose(t, operand, padding_value, *, padding_config): if type(t) is ad_util.Zero: t_operand = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None @@ -4529,14 +4548,18 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): (operand_bdim,)) return select(mask, x, broadcasted_padding), operand_bdim -pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad') +pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', + sharding_rule=_pad_sharding_rule) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule def _pad_lower(ctx, x, padding_value, *, padding_config): aval_out, = ctx.avals_out low, high, interior = util.unzip3(padding_config) - return [mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)] + out = mlir.pad(ctx, aval_out, x, padding_value, low, high, interior) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(pad_p, _pad_lower) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e52c805ef5e6..372d6c334f5d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5287,7 +5287,7 @@ def f(x, y): def test_slice(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(4, 4) + np_inp = np.arange(16.).reshape(4, 4) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @jax.jit @@ -5300,6 +5300,16 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) self.assertIn('@Sharding', f.lower(arr).as_text()) + def g(x): + out = f(x) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) @@ -5308,7 +5318,7 @@ def f(x): def test_squeeze(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(4, 4, 1) + np_inp = np.arange(16.).reshape(4, 4, 1) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) @jax.jit @@ -5322,6 +5332,58 @@ def f(x): self.assertIn('@Sharding', f.lower(arr).as_text()) self.assertArraysEqual(out, np.squeeze(np_inp, axis=2)) + def g(x): + out = f(x) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + def test_pad(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + np_inp = np.arange(8.) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + + @partial(jax.jit, static_argnums=(1, 2)) + def f(x, padding_config, spec): + y = lax.pad(x, 0., padding_config) + self.assertEqual(y.sharding.spec, spec) + return y + + out = f(arr, ((2, 2, 0),), P('x')) + self.assertArraysEqual(out, np.pad(np_inp, 2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertIn('@Sharding', f.lower(arr, ((2, 2, 0),), P('x')).as_text()) + + out = f(arr, ((0, 0, 0),), P('x')) + self.assertArraysEqual(out, np_inp) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + f(arr, ((0, 3, 1), ), P('x')) # doesn't crash + + def g(x): + out = f(x, ((2, 2, 0),), P('x')) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + f(arr, ((2, 3, 0), ), None) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + f(arr, ((0, 3, 0), ), None) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) + f(arr, ((4, 4, 1),), None) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 869a53345d1551071ce613d56f1f18cce20837e3 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 20 Nov 2024 17:27:25 -0800 Subject: [PATCH 093/112] [Mosaic TPU] Add bound check for general vector store op. PiperOrigin-RevId: 698577015 --- .../dialect/tpu/transforms/debug_assert_insertion.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc index 5478c64f9944..846e3bbb341f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc @@ -122,6 +122,14 @@ void tpu_strided_store_rule(tpu::StridedStoreOp op) { /*strides=*/op.getStrides()); } +void tpu_vector_store_rule(tpu::VectorStoreOp op) { + // TODO(b/379925823): Take strides into account. + assertIsValidSubwindow( + op, op.getIndices(), + /*window_shape=*/op.getValueToStore().getType().getShape(), + /*full_shape=*/op.getBase().getType().getShape()); +} + const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ // TODO: tpu::LoadOp, tpu::StoreOp @@ -133,6 +141,8 @@ const llvm::StringMap &rules() { as_generic_rule(tpu_strided_load_rule)}, {tpu::StridedStoreOp::getOperationName(), as_generic_rule(tpu_strided_store_rule)}, + {tpu::VectorStoreOp::getOperationName(), + as_generic_rule(tpu_vector_store_rule)}, }; return *rules; } From 6568713a046b46c8e7f484f7d1db653e20d3aded Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 20:12:01 -0800 Subject: [PATCH 094/112] [sharding_in_types] Add `concatenate_p` support PiperOrigin-RevId: 698621325 --- jax/_src/lax/lax.py | 20 +++++++++++++++++--- tests/pjit_test.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9a27460906ab..0519fa48f45a 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1700,6 +1700,8 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) + if config.sharding_in_types.value: + return broadcast(scalar_zero, aval.shape, sharding=aval.sharding) return broadcast(scalar_zero, aval.shape) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array @@ -4401,7 +4403,7 @@ def _concatenate_shape_rule(*operands, **kwargs): raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands]))) shapes = [operand.shape[:dimension] + operand.shape[dimension+1:] for operand in operands] - if not shapes[:-1] == shapes[1:]: + if shapes[:-1] != shapes[1:]: msg = ("Cannot concatenate arrays with shapes that differ in dimensions " "other than the one being concatenated: concatenating along " "dimension {} for shapes {}.") @@ -4412,6 +4414,13 @@ def _concatenate_shape_rule(*operands, **kwargs): ex_shape = operands[0].shape return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:] +def _concatenate_sharding_rule(*operands, **kwargs): + if not all(o.sharding == operands[0].sharding for o in operands): + ss = ", ".join(str(o.sharding) for o in operands) + raise TypeError( + f"All operands should have the same sharding. Got shardings {ss}") + return operands[0].sharding + def _concatenate_dtype_rule(*operands, **kwargs): check_same_dtypes('concatenate', *operands) return operands[0].dtype @@ -4452,14 +4461,19 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): raise NotImplementedError # TODO(mattjj) concatenate_p = standard_primitive( - _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate') + _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', + sharding_rule=_concatenate_sharding_rule) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): - return [hlo.concatenate(xs, mlir.i64_attr(dimension))] + aval_out, = ctx.avals_out + out = hlo.concatenate(xs, mlir.i64_attr(dimension)) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(concatenate_p, _concatenate_lower) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 372d6c334f5d..dd1415b680a4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5384,6 +5384,48 @@ def g(x): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) + def test_concatenate(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + np_inp = np.arange(16.).reshape(4, 4) + s = NamedSharding(mesh, P('x', 'y')) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np.arange(4.).reshape(4, 1), s) + + @partial(jax.jit, static_argnums=2) + def f(x, y, method='jnp'): + if method == 'jnp': + y = jnp.concatenate([x, y], axis=1) + else: + assert method == 'lax' + y = lax.concatenate([x, y], dimension=1) + self.assertEqual(y.sharding.spec, P('x', 'y')) + return y + + out = f(arr1, arr2) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) + self.assertIn('@Sharding', f.lower(arr1, arr2).as_text()) + + out = f(arr1, arr2, method='lax') + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) + + with self.assertRaisesRegex( + TypeError, "All operands should have the same sharding"): + arr3 = jax.device_put(np.arange(4.).reshape(4, 1), + NamedSharding(mesh, P('x'))) + f(arr1, arr3) + + def g(x, y): + out = f(x, y) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr1, arr2) + self.assertEqual(out.sharding, s) + + out = jax.jit(jax.grad(g))(arr1, arr2) + self.assertEqual(out.sharding, s) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From e72b449089f6af4ceb18288e36215b3c76e69245 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Wed, 20 Nov 2024 22:45:05 -0800 Subject: [PATCH 095/112] Reverts c04aec9d525dd2e767495e41b98e82dd79315f37 PiperOrigin-RevId: 698654038 --- jaxlib/mosaic/dialect/tpu/tpu.td | 5 +-- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 36 ++++----------- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 44 ++++++------------- 3 files changed, 24 insertions(+), 61 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index de5e3514fc1d..8a4f573bce24 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -654,15 +654,14 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { I32:$amount, Optional:$device_id, // For remote DMAs Optional:$core_id, // For megacore - Optional:$subcore_id, // For the SC vector subcore OptionalAttr:$core_type ); let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`subcore_id` $subcore_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; let builders = [ - // A backward-compatible builder that sets `subcore_id` and `core_type` to nullptr. + // A backward-compatible builder that sets `core_type` to nullptr. OpBuilder<(ins "Value":$semaphore, "Value":$amount, "Value":$device_id, "Value":$core_id)>, ]; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 8586e2a16c8a..3271c0874572 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -844,7 +844,7 @@ void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state, Value semaphore, Value amount, Value device_id, Value core_id) { build(builder, state, semaphore, amount, device_id, core_id, - /*subcore_id=*/nullptr, /*core_type=*/nullptr); + /*core_type=*/nullptr); } LogicalResult SemaphoreSignalOp::verify() { @@ -861,39 +861,21 @@ LogicalResult SemaphoreSignalOp::verify() { CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc); CoreType target_core_type = getCoreType().value_or(issuing_core_type); - if (getCoreId() == nullptr && getDeviceId() == nullptr && - getSubcoreId() == nullptr) { + if (getCoreId() == nullptr && getDeviceId() == nullptr) { if (target_core_type != issuing_core_type) { - return emitOpError(absl::StrFormat( - "Target core type (%s) must match source core type " - "(%s) when device_id, core_id and subcore_id are not specified", - stringifyCoreType(target_core_type), - stringifyCoreType(issuing_core_type))); + return emitOpError( + absl::StrFormat("Target core type (%s) must match source core type " + "(%s) when device_id and core_id are not specified", + stringifyCoreType(target_core_type), + stringifyCoreType(issuing_core_type))); } } - if (target_core_type == CoreType::kScVectorSubcore && - issuing_core_type != CoreType::kScVectorSubcore && - getSubcoreId() == nullptr) { - return emitOpError( - "Subcore ID must be specified for the SC vector subcore"); - } - if (target_core_type != CoreType::kScVectorSubcore && - getSubcoreId() != nullptr) { - return emitOpError( - "Subcore ID must be specified only for the SC vector subcore"); - } if ((issuing_core_type == CoreType::kTc && - (target_core_type == CoreType::kScScalarSubcore || - target_core_type == CoreType::kScVectorSubcore)) || - ((issuing_core_type == CoreType::kScScalarSubcore || - issuing_core_type == CoreType::kScVectorSubcore) && + target_core_type == CoreType::kScScalarSubcore) || + (issuing_core_type == CoreType::kScScalarSubcore && target_core_type == CoreType::kTc)) { return emitOpError("Signalling between TC and SC is not implemented"); } - if (target_core_type == CoreType::kScVectorSubcore && - (getCoreId() != nullptr || getDeviceId() != nullptr)) { - return emitOpError("Signalling remote SC vector subcores is not supported"); - } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 27a886ebeb7e..fd68c9e6c95e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -15,21 +15,19 @@ limitations under the License. // We need to keep some extra headers for the code in tpu_passes.h.inc. -#include #include // IWYU pragma: keep #include #include #include -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" -#include "absl/strings/str_format.h" #include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" @@ -45,7 +43,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 4; +constexpr int kVersion = 3; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -88,37 +86,21 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) { LogicalResult semaphore_signal_rule(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. - // Added subcore_id in version 4. if (version < 2) { if (op->getNumOperands() == 2) { // Local signal. - op->setAttr( - OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0, 0})); + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); } else if (op->getNumOperands() == 3) { // Remote signal. - op->setAttr( - OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0, 0})); - } - return op->emitError("Unexpected operand count in tpu.semaphore_signal"); - } else if (version < 4) { - ArrayRef operand_segment_sizes = - op->getAttrOfType( - OpTrait::AttrSizedOperandSegments< - SemaphoreSignalOp>::getOperandSegmentSizeAttr()); - if (operand_segment_sizes.size() != 4) { - return op->emitError(absl::StrFormat( - "Expected operand count to be 4 in tpu.semaphore_signal. Got %d", - operand_segment_sizes.size())); + // Hardcoding that one optional value is device_id, not core_id. This + // could misinterpret sem_signals where core_id is specified, but + // device_id isn't. + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); + } else { + return op->emitError("Unexpected operand count in tpu.semaphore_signal"); } - SmallVector new_operand_segment_sizes( - operand_segment_sizes.begin(), operand_segment_sizes.end()); - new_operand_segment_sizes.push_back(0); - op->setAttr(OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), - new_operand_segment_sizes)); } return success(); } From f18df8f39cfa9471449e6c66a5b765e17f10c90d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 21 Nov 2024 03:12:51 -0800 Subject: [PATCH 096/112] [pallas:mosaic_gpu] Pulled `delay_release` into `emit_pipeline` The implementation exactly matches the one we have in the lowering. PiperOrigin-RevId: 698713343 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 33 +++++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 90c00765e8b1..feb7f1af6301 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -125,20 +125,41 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore def emit_pipeline( - body, + body: Callable[..., None], *, grid: pallas_core.StaticGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, + delay_release: int = 0, ): - """Creates a function to emit a manual pipeline within a Pallas kernel.""" + """Creates a function to emit a manual pipeline within a Pallas kernel. + + Args: + body: The pipeline body. + grid: The grid to use for the pipeline. + in_specs: The block specs for the inputs. + out_specs: The block specs for the outputs. + max_concurrent_steps: The maximum number of sequential stages that are + active concurrently. Defaults to 1. + delay_release: The number of steps to wait before reusing the input/output + references. Defaults to 0, and must be strictly smaller than + ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you + don't await the WGMMA in the body. + """ num_steps = math.prod(grid) + if max_concurrent_steps <= delay_release: + raise ValueError( + "max_concurrent_steps must be greater than delay_release, but" + f" {max_concurrent_steps=}, {delay_release=}" + ) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to - # reduce the size of the allocated buffers below. + # reduce the size of the refs allocated in SMEM. if max_concurrent_steps > num_steps: max_concurrent_steps = num_steps + delay_release = 0 # No need to delay anything. def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)): @@ -208,7 +229,7 @@ def loop_body(step, carry): gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. gpu_primitives.wait_smem_to_gmem( - max_concurrent_steps - 1, wait_read_only=True + max_concurrent_steps - (1 + delay_release), wait_read_only=True ) with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): @@ -245,10 +266,10 @@ def loop_body(step, carry): predicate=lax.bitwise_or(slices_changed, is_last_step), ) - fetch_step = step + max_concurrent_steps + fetch_step = step + (max_concurrent_steps - delay_release) fetch_slot = slot # (x + y) % y == x % y jax.lax.cond( - fetch_step < num_steps, + lax.bitwise_and(fetch_step >= delay_release, fetch_step < num_steps), lambda: map( lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref), in_brefs, From 1bc9df429d87920bdbbf874e84a63fbe3111e27d Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Thu, 21 Nov 2024 05:24:38 -0800 Subject: [PATCH 097/112] Integrate LLVM at llvm/llvm-project@33fcd6acc755 Updates LLVM usage to match [33fcd6acc755](https://github.com/llvm/llvm-project/commit/33fcd6acc755) PiperOrigin-RevId: 698742870 --- jax/experimental/mosaic/gpu/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index b716456eceb3..0ce1140cfa07 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -296,6 +296,12 @@ def globaltimer(kind: Literal["low", "high"] | None = None): def bytewidth(ty: ir.Type): + # The actual width of TF32 is 19 bits. However, sinc we need to treat it as + # 32 bits for compatibility reasons. TF32 used to be 32 bits wide in upstream + # MLIR, but it changed in + # https://github.com/llvm/llvm-project/commit/67a1fdb014790a38a205d28e1748634de34471dd. + if ir.FloatTF32Type.isinstance(ty): + return 4 if ir.IntegerType.isinstance(ty): return ir.IntegerType(ty).width // 8 if ir.FloatType.isinstance(ty): From 0831e2e3401dfde3b12e407cb4c366b420b16348 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 20 Nov 2024 20:50:37 -0800 Subject: [PATCH 098/112] [shape_poly] Adding shape polymorphism support for the state primitives. --- benchmarks/shape_poly_benchmark.py | 3 +- jax/_src/core.py | 64 ++++++++++++++++++++++++++++++ jax/_src/numpy/lax_numpy.py | 61 +--------------------------- jax/_src/state/indexing.py | 8 ++-- tests/shape_poly_test.py | 30 +++++++++++++- tests/state_test.py | 2 +- 6 files changed, 100 insertions(+), 68 deletions(-) diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index d26801d8dfe5..d365a6facd90 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -17,7 +17,6 @@ import jax from jax import core -from jax._src.numpy import lax_numpy from jax import export jax.config.parse_flags_with_absl() @@ -76,7 +75,7 @@ def inequalities_slice(state): while state: for _ in range(30): a.scope._clear_caches() - start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b) + start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b) _ = 0 <= slice_size <= b _ = start >= 0 _ = start + slice_size <= b diff --git a/jax/_src/core.py b/jax/_src/core.py index cbf3282fb2cc..faf33f00bbf9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2047,6 +2047,70 @@ def dimension_as_value(d: DimSize): if hasattr(d, "dimension_as_value"): return d.dimension_as_value() return operator.index(d) +def canonicalize_slice( + s: slice, + axis_size: DimSize + ) -> tuple[DimSize, DimSize, DimSize]: + """Computes the start index, step, and size of the slice `x[s]`. + + This is similar to `s.indices(axis_size)`, except that it returns + `(start, step, size)`, and it works when the slice and/or the + `axis_size` are symbolic. + + See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding + """ + def convert_to_index(d: DimSize) -> DimSize: + # Convert np.array and jax.Array to int, leave symbolic dimensions alone + try: + return operator.index(d) + except: + return d + + # Must resolve statically if step is {<0, ==0, >0} + step = convert_to_index(s.step) if s.step is not None else 1 + try: + if step == 0: + raise ValueError("slice step cannot be zero") + step_gt_0 = (step > 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the step ({step}) must " + + f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") + + def clamp_index(i: DimSize, which: str): + try: + i_ge_0 = (i >= 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the {which} ({i}) must " + + f"be resolved statically if it is >= 0.\nDetails: {e}") + if i_ge_0: + if step_gt_0: + return min_dim(axis_size, i) + else: + return min_dim(axis_size - 1, i) + else: + if step_gt_0: + return max_dim(0, axis_size + i) + else: + return max_dim(-1, axis_size + i) + + if s.start is None: + start = 0 if step_gt_0 else axis_size - 1 + else: + start = clamp_index(convert_to_index(s.start), "start") + + if s.stop is None: + stop = axis_size if step_gt_0 else -1 + else: + stop = clamp_index(convert_to_index(s.stop), "stop") + + gap = step if step_gt_0 else - step + distance = (stop - start) if step_gt_0 else (start - stop) + slice_size = max_dim(0, distance + gap - 1) // gap + return start, step, slice_size + + class SomeTracer: __slots__ = () def __repr__(self): return "[dynamic]" diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 898e4255dd8e..5f380fad902c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -12116,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], "arrays within JIT compiled functions).") raise IndexError(msg) - start, step, slice_size = _preprocess_slice(i, x_shape[x_axis]) + start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis]) slice_shape.append(slice_size) if core.definitely_equal(step, 1): @@ -12319,65 +12319,6 @@ def _canonicalize_tuple_index(arr_ndim, idx): idx = tuple(idx) + colons return idx -def _preprocess_slice( - s: slice, - axis_size: core.DimSize - ) -> tuple[core.DimSize, core.DimSize, core.DimSize]: - """Computes the start index, step, and size of the slice `x[s]`.""" - # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding - # "this is harder to get right than you may think" - # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275) - def convert_to_index(d: DimSize) -> DimSize: - # Convert np.array and jax.Array to int, leave symbolic dimensions alone - try: - return operator.index(d) - except: - return d - - # Must resolve statically if step is {<0, ==0, >0} - step = convert_to_index(s.step) if s.step is not None else 1 - try: - if step == 0: - raise ValueError("slice step cannot be zero") - step_gt_0 = (step > 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the step ({step}) must " + - f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") - - def clamp_index(i: DimSize, which: str): - try: - i_ge_0 = (i >= 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the {which} ({i}) must " + - f"be resolved statically if it is >= 0.\nDetails: {e}") - if i_ge_0: - if step_gt_0: - return core.min_dim(axis_size, i) - else: - return core.min_dim(axis_size - 1, i) - else: - if step_gt_0: - return core.max_dim(0, axis_size + i) - else: - return core.max_dim(-1, axis_size + i) - - if s.start is None: - start = 0 if step_gt_0 else axis_size - 1 - else: - start = clamp_index(convert_to_index(s.start), "start") - - if s.stop is None: - stop = axis_size if step_gt_0 else -1 - else: - stop = clamp_index(convert_to_index(s.stop), "stop") - - gap = step if step_gt_0 else - step - distance = (stop - start) if step_gt_0 else (start - stop) - slice_size = core.max_dim(0, distance + gap - 1) // gap - return start, step, slice_size - @export def blackman(M: int) -> Array: diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 538f3f8e4888..2da93e3d8e80 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -46,11 +46,11 @@ def __post_init__(self): @property def is_dynamic_start(self): - return not isinstance(self.start, int) + return not core.is_dim(self.start) @property def is_dynamic_size(self): - return not isinstance(self.size, int) + return not core.is_dim(self.size) def tree_flatten(self): # If `start` is statically known, we treat it as static information @@ -72,10 +72,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice: @classmethod def from_slice(cls, slc: slice, size: int) -> Slice: - start, stop, step = slc.indices(size) + start, step, size = core.canonicalize_slice(slc, size) if step < 1: raise ValueError(f"slice must have a step >= 1 (found: {step})") - return cls(start, max((stop - start + step - 1) // step, 0), step) + return cls(start, size, step) def dslice( diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index eda4c4309960..668907ffee27 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -48,6 +48,9 @@ from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.state import discharge +from jax._src.state import primitives as ref_primitives + import numpy as np config.parse_flags_with_absl() @@ -2062,6 +2065,31 @@ def test_vmap_error(self): polymorphic_shapes=["b, ...", "c, ...", None]) + @jtu.parameterized_filterable( + kwargs=[ + dict(slc=slc) + for slc in [ + slice(None, None, None), + slice(2, 5), + ] + ]) + def test_stateful(self, slc: slice): + w, = export.symbolic_shape("w", constraints=["w >= 3"]) + def f(x_ref): + ones = jnp.ones_like(x_ref)[slc] + ref_primitives.ref_addupdate(x_ref, slc, ones) + x1 = ref_primitives.ref_get(x_ref, slc) + x2 = x1 + ones + ref_primitives.ref_set(x_ref, slc, x2) + + exp = export.export(jax.jit(discharge.run_state(f)))( + jax.ShapeDtypeStruct((w,), dtype=_f32)) + x = np.ones((32,), dtype=_f32) + expected = np.copy(x) + expected[slc] = 3. + self.assertAllClose(exp.call(x), expected) + + # List containing either harnesses, or lists of harnesses _POLY_SHAPE_TEST_HARNESSES = [ PolyHarness("add", "", @@ -3603,7 +3631,7 @@ def test_harness(self, harness: PolyHarness): not harness.polymorphic_shapes[0].endswith("...") and jtu.test_device_matches(["tpu"])): raise unittest.SkipTest( - "Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.") + "Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.") config_flags = harness.override_jax_config_flags # Update this here rather than in harness object because vmap_random_gamma is derived diff --git a/tests/state_test.py b/tests/state_test.py index c8458742619d..44caded0ca64 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -752,7 +752,7 @@ def f(a_ref, b_ref): lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) - prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) + prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns) self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr)) self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr)) From 7d7a0fa249c7d42dfa11d492fb62b4d1909fa628 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 21 Nov 2024 07:25:15 -0800 Subject: [PATCH 099/112] Run the TPU workflow on new self-hosted runners We are not able to run the TPU workflows because of no active runners (https://github.com/jax-ml/jax/actions/runs/11879479226/job/33101456081). So this adds the new self-hosted runners to the TPU workflow to fix this issue. The v3 type is disabled as we do not have that available yet. PiperOrigin-RevId: 698772505 --- .github/workflows/cloud-tpu-ci-nightly.yml | 49 ++++++++++++---------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index a5fac5ebdbc3..16c0751f40f8 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -13,7 +13,7 @@ name: CI - Cloud TPU (nightly) on: schedule: - - cron: "0 14 * * *" # daily at 7am PST + - cron: "* */2 * * *" # Run every 2 hours workflow_dispatch: # allows triggering the workflow run manually # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. @@ -26,15 +26,18 @@ jobs: matrix: jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - {type: "v3-8", cores: "4"}, - {type: "v4-8", cores: "4"}, - {type: "v5e-8", cores: "8"} + # {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available + # {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] + python-version: ["3.10"] name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20240722 ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} - runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"] + PYTHON: python${{ matrix.python-version }} + runs-on: ${{ matrix.tpu.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" timeout-minutes: 120 defaults: run: @@ -46,37 +49,37 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install JAX test requirements run: | - pip install -U -r build/test-requirements.txt - pip install -U -r build/collect-profile-requirements.txt + $PYTHON -m pip install -U -r build/test-requirements.txt + $PYTHON -m pip install -U -r build/collect-profile-requirements.txt - name: Install JAX run: | - pip uninstall -y jax jaxlib libtpu + $PYTHON -m pip uninstall -y jax jaxlib libtpu if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then - pip install .[tpu] \ + $PYTHON -m pip install .[tpu] \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - pip install --pre libtpu \ + $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $PYTHON -m pip install --pre libtpu \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $PYTHON -m pip install requests elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. - pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $PYTHON -m pip install requests else echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" exit 1 fi - python3 -c 'import sys; print("python version:", sys.version)' - python3 -c 'import jax; print("jax version:", jax.__version__)' - python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' - strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on' - python3 -c 'import jax; print("libtpu version:", + $PYTHON -c 'import sys; print("python version:", sys.version)' + $PYTHON -c 'import jax; print("jax version:", jax.__version__)' + $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' + strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' + $PYTHON -c 'import jax; print("libtpu version:", jax.lib.xla_bridge.get_backend().platform_version)' - name: Run tests env: @@ -84,14 +87,14 @@ jobs: PY_COLORS: 1 run: | # Run single-accelerator tests in parallel - JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ + JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ --maxfail=20 -m "not multiaccelerator" tests examples # Run Pallas printing tests, which need to run with I/O capturing disabled. - TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \ + TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \ tests/pallas/tpu_pallas_test.py::PallasCallPrintTest # Run multi-accelerator across all chips - python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests + $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests - name: Send chat on failure # Don't notify when testing the workflow from a branch. if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }} From bf0150bb22b2ed7986adf3762cb1bc555ed3fee8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 21 Nov 2024 08:20:41 -0800 Subject: [PATCH 100/112] [JAX] Ignore xla_gpu_experimental_autotune_cache_mode when calculating module hash. PiperOrigin-RevId: 698789020 --- jax/_src/cache_key.py | 3 +++ tests/cache_key_test.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 6e025653b81d..324fa85f81ed 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -21,6 +21,7 @@ from typing import cast as type_cast from jax._src import config +from jax._src.lib import version as jaxlib_version from jax._src.lib import version_str as jaxlib_version_str from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -225,6 +226,8 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_dump_hlo_as_long_text = False debug_options.xla_dump_disable_metadata = False debug_options.xla_dump_hlo_pipeline_re = "" + if jaxlib_version > (0, 4, 35): + debug_options.xla_gpu_experimental_autotune_cache_mode = 0 # Optional way to specify the cuda install path to be used by the compiler. # This could possibly affect the cuda version compiled with, but this should # already be included in the platform information (and might not be reflected diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 00925c5f7dfc..8f9c5d0e8b82 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -31,6 +31,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.mesh import Mesh from jax._src.partition_spec import PartitionSpec as P @@ -68,6 +69,8 @@ def test_serialized_compile_options(self): debug_options.xla_dump_hlo_as_long_text = True debug_options.xla_dump_disable_metadata = True debug_options.xla_dump_hlo_pipeline_re = "xyzzy" + if jaxlib_version > (0, 4, 35): + debug_options.xla_gpu_experimental_autotune_cache_mode = 2 hash2 = self.get_hashed_value( cache_key._hash_serialized_compile_options, compile_options ) From 1e6654a0314cc067a6f257dd4f5c5a5a5d409f39 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 21 Nov 2024 09:08:23 -0800 Subject: [PATCH 101/112] Fix cron schedule to run past minute 0 every 2nd hour In the previous schedule, we were running at every minute at every 2nd hour. PiperOrigin-RevId: 698804124 --- .github/workflows/cloud-tpu-ci-nightly.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 16c0751f40f8..4ac167bd37c1 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -13,7 +13,7 @@ name: CI - Cloud TPU (nightly) on: schedule: - - cron: "* */2 * * *" # Run every 2 hours + - cron: "0 */2 * * *" # Run every 2 hours workflow_dispatch: # allows triggering the workflow run manually # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. From 1d2dc17e5f226db7de2a8996c4a2d3bef4c8a0f6 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 21 Nov 2024 09:49:35 -0800 Subject: [PATCH 102/112] [mgpu] Pointwise op can handle LHS splats. PiperOrigin-RevId: 698818035 --- .../mosaic/gpu/fragmented_array.py | 34 ++++++++++++++++++- tests/mosaic/gpu_test.py | 23 +++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 2b985ff5c9b8..e1ee37f3d24d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -623,6 +623,38 @@ def to_layout(self, new_layout: FragmentedLayout): ) def _pointwise(self, op, *other, output_is_signed: bool | None = None): + if isinstance(self.layout, WGSplatFragLayout): + # Find either the largest operand or an operand that has a + # concrete layout base the layout computation of that. + widest_idx = None + for i, o in enumerate(other): + if not isinstance(o, FragmentedArray): + continue + elif not isinstance(o.layout, WGSplatFragLayout): + widest_idx = i + break + elif not o.layout.can_broadcast_to(self.layout.shape): + # Note: equal shapes can be broadcast to each other. Using + # the negation we make sure to only consider strictly larger + # shapes so that we don't end up ping ponging between equal + # shapes. + widest_idx = i + + if widest_idx is not None: + # We need to retain the order of arguments that the op + # expects. + def _op(wide_o, self_o, *args): + pre_wide = args[:widest_idx - 1] + post_wide = args[widest_idx - 1:] + return op(self_o, *pre_wide, wide_o, *post_wide) + return other[widest_idx]._pointwise( + _op, + self, + *other[:widest_idx], + *other[widest_idx + 1:], + output_is_signed=output_is_signed, + ) + other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): @@ -642,7 +674,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): o.registers.flat[0], shape=self.shape, layout=self.layout, - is_signed=self.is_signed, + is_signed=o.is_signed, ) else: if self.layout != o.layout: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ab2a00c730d6..87dc2c452041 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1489,6 +1489,29 @@ def kernel(ctx, dst, _): )() np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) + + def test_splat_binary_ops(self): + def kernel(ctx, src, dst, _): + f32 = ir.F32Type.get() + pi_arr = mgpu.FragmentedArray.load_strided(src) + assert isinstance(pi_arr.layout, mgpu.WGStridedFragLayout) + pi_scalar = arith.constant(f32, ir.FloatAttr.get(f32, 3.14)) + pi_splat = mgpu.FragmentedArray.splat(pi_scalar, ()) + assert isinstance(pi_splat.layout, mgpu.WGSplatFragLayout) + pi_arr_sq = pi_arr * pi_splat.broadcast(pi_arr.shape) + assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) + pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq + assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) + (pi_arr_sq + pi_arr_cube).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) + inp = jnp.ones_like(out_shape) * 3.14 + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, () + )(inp) + np.testing.assert_allclose(result, np.full((128, 32), 3.14 ** 2 + 3.14 ** 3, np.float32)) + + @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128))) def test_strided_load_store(self, in_shape): def kernel(ctx, *args): From 2178ed2fa42eeb7f609369d56d90950af60d25ca Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Thu, 21 Nov 2024 09:49:48 -0800 Subject: [PATCH 103/112] [pallas] Add more test cases for Triton bitcast_convert_type lowering rule. PiperOrigin-RevId: 698818103 --- tests/pallas/ops_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 9f0b9aef5af3..d7c1bac5dc61 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1941,9 +1941,13 @@ def kernel(x_ref, out_ref): @parameterized.parameters( (jnp.float16, jnp.float16), # Noop - (jnp.int16, jnp.float16), (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), ) def test_bitcast_convert_type(self, in_dtype, out_dtype): if jtu.test_device_matches(["tpu"]): From 96c012990de86ccb0eb815a11ae4e2c337802794 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 21 Nov 2024 10:32:37 -0800 Subject: [PATCH 104/112] Fix false positive `debug_nans` error caused by NaNs that are properly handled in `jax.scipy.stats.gamma` As reported in https://github.com/jax-ml/jax/issues/24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs. Fixes https://github.com/jax-ml/jax/issues/24939 PiperOrigin-RevId: 698833589 --- jax/_src/scipy/stats/gamma.py | 5 +++-- tests/scipy_stats_test.py | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index f410d08e4f3d..4343c080251c 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -51,12 +51,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) - :func:`jax.scipy.stats.gamma.logsf` """ x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale) + ok = lax.ge(x, loc) one = _lax_const(x, 1) - y = lax.div(lax.sub(x, loc), scale) + y = jnp.where(ok, lax.div(lax.sub(x, loc), scale), one) log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y) shape_terms = lax.add(gammaln(a), lax.log(scale)) log_probs = lax.sub(log_linear_term, shape_terms) - return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) + return jnp.where(ok, log_probs, -jnp.inf) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index f02ed0fc04bb..88a126c284a7 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -543,6 +543,13 @@ def testGammaLogPdfZero(self): self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) + def testGammaDebugNans(self): + # Regression test for https://github.com/jax-ml/jax/issues/24939 + with jax.debug_nans(True): + self.assertAllClose( + osp_stats.gamma.pdf(0.0, 1.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0, 1.0) + ) + @genNamedParametersNArgs(4) def testGammaLogCdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) From 1efef6bf6b3af91b91fb601e6302b4f17db739e0 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 21 Nov 2024 11:38:04 -0800 Subject: [PATCH 105/112] [pallas:mosaic_gpu] `emit_pipeline` now correctly supports `BlockSpec`s in GMEM This is necessary to replace the pipelining logic in the lowering with `emit_pipeline`. PiperOrigin-RevId: 698858380 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 49 ++++++++++++++++++-------- tests/pallas/mosaic_gpu_test.py | 33 +++++++++++++++++ 2 files changed, 67 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index feb7f1af6301..9fcca6acdacc 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -46,7 +46,16 @@ class BufferedRef: spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True}) is_index_invariant: bool = dataclasses.field(metadata={"static": True}) gmem_ref: pallas_core.AbstractMemoryRef - smem_ref: pallas_core.AbstractMemoryRef # [num_slots, *spec.block_shape] + # ``None`` if the ref is pinned to GMEM; otherwise, has shape + # [num_slots, *spec.block_shape]. + smem_ref: pallas_core.AbstractMemoryRef | None + + def get_ref_for_slot( + self, slot: int | jax.Array + ) -> pallas_core.AbstractMemoryRef: + if self.smem_ref is None: + return self.gmem_ref + return self.smem_ref.at[slot] def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: index_map = self.spec.index_map @@ -59,6 +68,9 @@ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: ) def copy_in(self, slot, grid_indices, barrier_ref): + if not _in_smem(self.spec): + return + assert self.smem_ref is not None gmem_slices = self.compute_gmem_slice(grid_indices) gpu_primitives.copy_gmem_to_smem( self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands @@ -67,6 +79,9 @@ def copy_in(self, slot, grid_indices, barrier_ref): ) def copy_out(self, slot, grid_indices, predicate=None): + if not _in_smem(self.spec): + return + assert self.smem_ref is not None gmem_slices = self.compute_gmem_slice(grid_indices) gpu_primitives.copy_smem_to_gmem( self.smem_ref.at[slot], @@ -88,8 +103,8 @@ def _uses_arguments( def _is_index_invariant( spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid ) -> bool: - index_map = spec.index_map - assert index_map is not None + if (index_map := spec.index_map) is None: + return True return not any(_uses_arguments(index_map, len(grid))) @@ -105,6 +120,10 @@ def _inc_grid_by_1( return tuple(reversed(next_indices)) +def _in_smem(spec: pallas_core.BlockSpec) -> bool: + return spec.memory_space in (None, gpu_core.SMEM) + + # ``pl.Slice`` uses a different pytree encoding, depending on whether the # start/size are static or dynamic. This leads to pytree structure mismatch # in the pipeline body. So, we define a different ``Slice`` class below. @@ -166,6 +185,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): if any( spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore for idx in range(1, len(grid) + 1) + if spec.block_shape is not None ): raise NotImplementedError( f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block" @@ -174,14 +194,12 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( - map( - lambda spec, ref: gpu_core.SMEM( - (max_concurrent_steps, *spec.block_shape), # type: ignore - ref.dtype, - ), - it.chain(in_specs, out_specs), - gmem_refs, - ), + [ + gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype) # type: ignore + if _in_smem(spec) + else None + for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs) + ], [len(in_specs)], ) return pl.run_scoped( @@ -194,7 +212,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): out_smem_refs=out_smem_refs, barrier_ref=gpu_core.Barrier( # TODO(slebedev): Change this to arrive only once. - len(in_specs), + sum(map(_in_smem, in_specs)), num_barriers=max_concurrent_steps, ), ) @@ -233,9 +251,10 @@ def loop_body(step, carry): ) with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body( - *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs)) - ) + body(*( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + )) if not all(bref.is_index_invariant for bref in out_brefs): gpu_primitives.commit_smem() diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index a4bbc67ee14f..110d83bd992b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1186,6 +1186,39 @@ def kernel_body(x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_nested_emit(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + grid=(), + )(x_gmem, o_gmem) + + def nested_kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def nested_kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit_with_grid_invariant_output(self): num_steps = 4 From f3e7e6829adae587e60a536a45852a9014389ab6 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 21 Nov 2024 12:17:28 -0800 Subject: [PATCH 106/112] Remove unneeded dependency from rocm_plugin_extension. PiperOrigin-RevId: 698872849 --- jaxlib/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 8c402cfcefe8..987fe24a8008 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -243,7 +243,6 @@ pybind_extension( "@local_config_rocm//rocm:rocm_headers", "@nanobind", "@xla//third_party/python_runtime:headers", - "@xla//xla:status", "@xla//xla:util", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", From f899d515354d19801f631b6c096e9db075ac820d Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 21 Nov 2024 13:28:30 -0800 Subject: [PATCH 107/112] [Mosaic TPU] Fold sublane offset to indices when storing to untiled ref. This optimization avoids unnecessary retiling when storing to untiled ref but adds at most one extra store op for sublane offset (since sublane offset is limieted to < VregSlice[0]). PiperOrigin-RevId: 698896373 --- .../dialect/tpu/transforms/infer_vector_layout.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index d84e4b883172..c0b2c6c96e7e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1640,14 +1640,14 @@ class VectorLayoutInferer { // Since it is untiled, we can store to any arbitrary address which // means the sublane offset can be any value and we can fold it to // 2nd minor index. - // TODO(jevinjiang): We can fold the sublane offset into the 2nd minor - // index. But we need to handle negative index in lower-to-llo. For - // now, we just force the sublane offset to be 0. + auto prev_store_layout = getLayout(op.getValueToStore()); + TPU_CHECK_OP(prev_store_layout.has_value(), "missing vector layout"); + offsets[0] = prev_store_layout->offsets()[0].value_or(0); if (offsets[1].value_or(0) >= tiling[1]) { offsets[1] = 0; } - store_layout = VectorLayout(bitwidth, {0, offsets[1]}, - nativeTiling(bitwidth), ImplicitDim::kNone); + store_layout = VectorLayout(bitwidth, offsets, nativeTiling(bitwidth), + ImplicitDim::kNone); } else { store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]}, ImplicitDim::kNone); From 26443bbd6696ab296408b808fdc9f3974c4cfa3b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 21 Nov 2024 14:25:39 -0800 Subject: [PATCH 108/112] Update XLA dependency to use revision http://github.com/openxla/xla/commit/85360d67ffc0a6d6923605b848de12ec204ca336. PiperOrigin-RevId: 698915433 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 327e4ca422ac..46f71523be05 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e763f8875b0a9bfca876be9b02c874979e55422a" -XLA_SHA256 = "7b6a33894c6510167cac6e0ab7a6331ffa84e7fcaaa1d3b1c462ec5ecacb0682" +XLA_COMMIT = "85360d67ffc0a6d6923605b848de12ec204ca336" +XLA_SHA256 = "7afa7e599adf7b1a636ea9e55419c253a115ef27217ec862ca8a03cef1abd11a" def repo(): tf_http_archive( From 344d0d998d682ffae5e35ad0ed4a3de50e51940c Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 21 Nov 2024 15:42:34 -0800 Subject: [PATCH 109/112] [Pallas] Add readme page for debugging tips. PiperOrigin-RevId: 698939951 --- jax/experimental/pallas/g3doc/debugging.md | 207 +++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 jax/experimental/pallas/g3doc/debugging.md diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md new file mode 100644 index 000000000000..40b109d102d5 --- /dev/null +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -0,0 +1,207 @@ +# Debugging Pallas + + + + + +[TOC] + +This document contains a collection of tips and tricks for debugging Pallas +programs. For any specific requests or ideas for improvement, please create +a ticket on https://github.com/jax-ml/jax/issues. + +## Debugging Tools + +### Interpret (HLO) Mode + +Passing in `interpret=True` into `pl.pallas_call` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. + +Note that interpret mode will not be able to fully replicate the behavior or programs that use communication (DMAs) between devices. This is because low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations. + +### debug_print + +The `pl.debug_print` function can be used to print runtime values inside of a kernel. The implementation is currently limited to scalar values, but we are working on lifting this limitation. + +For TPUs only, the kernel must be compiled with the 'xla_tpu_enable_log_recorder' option. + + +```python +kernel = pl.pallas_call(...) +compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) +result = compiled_kernel(x) +``` + +### Runtime Asserts + +Checkify can be used to insert runtime asserts, nan checks, out of bounds errors, etc. inside of a kernel. +Pallas implements two options for assertions: a *hard assert* which will crash the TPU if failed, and a *functionalized assertion* which will simulate a runtime assertion that can be thrown +as a Python error after the kernel has successfully executed. + +#### Hard assertion + +Hard assertions can be inserted with `checkify.check` +and running your program with the `--jax_pallas_enable_runtime_assert` flag. + +Your code will look like the following: + +```python +from jax.experimental import checkify + +def kernel(...): + checkify.check(x > y, "Check x > y failed") # Will halt if x <= y +``` + +This will print a relatively lengthy dump which resembles the following: + +``` +E1001 15:22:33.275768 4353 real_program_continuator.cc:1350] 0x0x0_TC0: [Physical location: dldgr4:pe1:1] generic::internal: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x169 (from TensorCoreSequencer:1:0x213): Check x > y failed HLO: main; HLO computation: main.3 +``` + +The benefit of a hard assertion is that it is guaranteed to either pass or +halt the TPU. The kernel will never proceed past the assertion if it fails. +However, the downside is that if the assertion fails you will +likely have to restart the program in order to run any other TPU operations, +and there is no Python error thrown that can be caught. + +#### Functionalized assertion +Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op like so: + +```python +from jax.experimental import checkify + +def kernel(...): + checkify.check(x > y, "Check x > y failed") # Will throw an error if x <= y + +kernel = pl.pallas_call(...) +checkified_kernel = checkify.checkify(kernel, + errors=checkify.all_checks) +error, result = checkified_kernel(x) +error.throw() +``` + +This will throw a Python error if any checks failed, such as if a NaN occurred +or if an out-of-bounds index was accessed. + +The benefit of a functionalized assert is that it will throw Python errors +that can be caught, and it will not interfere with downstream TPU operations. +However, it requires the kernel to successfully complete, meaning if your +error would have caused a TPU crash, the crash would still happen and +the error would not be thrown. + + +### Dumping Jaxprs + +Passing in `debug=True` into `pl.pallas_call` will print out the Jaxpr of the kernel as well as the lowered Mosaic code. + +```python +def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +x = jnp.ones((8, 128), dtype=jnp.float32) +pl.pallas_call( + kernel, + out_shape=jax.ShapeDTypeStruct((8, 128), jnp.float32) + debug=True, + name="my_call", +)(x, x) +``` + +This will output: + +``` +The kernel jaxpr for the pallas_call my_call for kernel function kernel at ...:1000: +{ lambda ; a:MemRef{float32[8,128]} b:MemRef{float32[8,128]} c:MemRef{float32[8,128]}. let + d:f32[8,128] <- a[:,:] + e:f32[8,128] <- b[:,:] + f:f32[8,128] = add d e + c[:,:] <- f + in () } + +The Mosaic module for the pallas_call my_call for kernel function kernel at ...:1000: +module { + func.func @main(%arg0: memref<8x128xf32, #tpu.memory_space>, %arg1: memref<8x128xf32, #tpu.memory_space>, %arg2: memref<8x128xf32, #tpu.memory_space>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} { + %c0 = arith.constant 0 : index + %c0_0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0_0] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + %c0_1 = arith.constant 0 : index + %c0_2 = arith.constant 0 : index + %1 = vector.load %arg1[%c0_1, %c0_2] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + %2 = arith.addf %0, %1 : vector<8x128xf32> + %c0_3 = arith.constant 0 : index + %c0_4 = arith.constant 0 : index + %3 = vector.load %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + vector.store %2, %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + return + } +} +``` + +### Dumping Mosaic Passes + +Mosaic is the underlying TPU compiler for Pallas. It can be useful to dump Mosaic if you are running into errors that are originating from the Mosaic compiler to see what code is actually being generated. + +Passing the `--xla_mosaic_dump_to=` argument will dump the output of all intermediate Mosaic passes. The names of the files contain either the parameter `name` passed to the `pallas_call`, or the name of the kernel function. A useful option is to dump to Sponge with `--test_arg=--xla_mosaic_dump_to=sponge` after which you will see all passes under the “Artifacts” tab in sponge. + +### Static Verification + +The static verification tool can be used to automatically detect race conditions in distributed kernels. +Because this tool uses formal verification, it is best used for small kernels (<=2 devices). + +Verification can be performed by running your kernel with the `--jax_pallas_dump_promela_to=`, +which will output a Promela dump file. Afterwards, the dump file can be +analyzed using the [`spin`](https://spinroot.com) tool. For example, with a dump named `dump.pml`, run: + +``` +spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan +``` + + + +## Useful Command line flags + +* OOB Checks: `--xla_mosaic_on_device_checks=bounds` +* Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true` + +* Dump Mosaic: `--xla_mosaic_dump_to=` +* Enable trace markers in XProf: `--xla_enable_transpose_trace` + +## Common Errors + +### INTERNAL Mosaic failed to compile TPU Kernel + +`INTERNAL Mosaic failed to compile TPU Kernel: Not implemented X` + +This error means that you hit an unimplemented case in the underlying Mosaic compiler. +Our recommended course of action here is to file a ticket if one does not already +exist for your specific error. + +In some cases, your error may be due to an operation which cannot be implemented +efficiently in the compiler, in which your best course of action is to find a workaround. This +is most commonly seen in `layout` and `shape_cast` errors. The important tip +to remember regarding layouts is that the last 2 dimensions of arrays in Pallas +are physically tiled into registers, so any reshapes, slicing, transposes, etc. +on the last 2 dimensions may trigger a relayout. + + +### VerificationError + +A verification error indicates that Pallas produced invalid code for Mosaic. + +This is a bug in Pallas, so please file a bug under https://github.com/jax-ml/jax/issues. + +### LoweringError + +This is a catch-all error type during Pallas to Mosaic lowering and can have many causes. +In most cases the error message should hint at what is wrong. + +For specific errors: + +* `Mixed dtype operands in cmp` when using `jnp.mod`: Use lax.rem instead of jnp.mod + + From 170718c8d476e6727baf070f66c2ddbd8829f95a Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 21 Nov 2024 17:46:21 -0800 Subject: [PATCH 110/112] Change signature of linearization rules. Give the rule the nonzero tangent pattern up-front. This is needed to make a linearization rule for pjit_p. Also make the rules return the nonzero tangents out, an explicit residual, and a closed tangent function. Add a rule for sin_p to test it out. We still need to figure out how to avoid having to precompute `cos(x)`. I think we need to update our backward pass code. --- jax/_src/interpreters/ad.py | 31 ++++++++++++++++++------------- jax/_src/lax/lax.py | 6 +++++- tests/api_test.py | 2 +- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 91f061fd2210..9fa2fdb9ffbf 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -483,39 +483,44 @@ def to_primal_tangent_pair(self, val): def process_primitive(self, primitive, args, params): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) + tangent_nonzeros = [type(t) is not Zero for t in tangents_in] if all(type(t) is Zero for t in tangents_in): return primitive.bind_with_trace(self.parent_trace, primals_in, params) lin = primitive_linearizations.get(primitive) if lin is None: lin = partial(fallback_linearize_rule, primitive) with core.set_current_trace(self.parent_trace): - primal_out, linearized = lin(*primals_in, **params) + primal_out, tangent_nonzeros_out, residuals, linearized = lin( + tangent_nonzeros, *primals_in, **params) with core.set_current_trace(self.tangent_trace): - tangent_out = linearized(*tangents_in) + tangent_out = linearized(residuals, *tangents_in) if primitive.multiple_results: - return [maybe_linearize_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)] + return [maybe_linearize_tracer(self, x, nz, t) + for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)] else: - return maybe_linearize_tracer(self, primal_out, tangent_out) + return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out) -def maybe_linearize_tracer(trace, primal, tangent): - if type(tangent) is Zero: - return primal - else: +def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): + if is_nonzero: + assert not type(tangent) is Zero return LinearizeTracer(trace, primal, tangent) + else: + assert type(tangent) is Zero + return primal -def fallback_linearize_rule(prim, *args, **kwargs): +def fallback_linearize_rule(prim, _, *args, **kwargs): def call_prim(*args_): return prim.bind(*args_, **kwargs) with config.use_direct_linearize(False): out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize( lu.wrap_init(call_prim), *args, **kwargs) - def linearized(*tangents): - tangents_out = iter(core.eval_jaxpr(jaxpr, consts, *tangents)) + def linearized(residuals, *tangents): + tangents_out = iter(core.eval_jaxpr(jaxpr, residuals, *tangents)) full_out = [pval.get_known() if pval.is_known() else next(tangents_out) for pval in out_tangents_pvals] assert next(tangents_out, None) is None return full_out - return out_primals, linearized + return out_primals, [True for _ in out_primals], consts, linearized class LinearizeTracer(Tracer): __slots__ = ['primal', 'tangent'] @@ -547,7 +552,7 @@ def to_concrete_value(self): primitive_transposes: dict[core.Primitive, Callable] = {} # transpose rules that internally perform reductions over the given named axes reducing_transposes: dict[core.Primitive, Callable] = {} -primitive_linearizations: dict[core.Primitive, Callable] = {} +primitive_linearizations : dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0519fa48f45a..1099919a6474 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2400,12 +2400,16 @@ def _sin_lowering(ctx, x): return sine(ctx, x) return _nary_lower_hlo(hlo.sine, ctx, x) +def _sin_p_lin(_, x): + cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) + return (sin_p.bind(x), True, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule - def _cos_complex(x): # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x))) # see also _sin_complex diff --git a/tests/api_test.py b/tests/api_test.py index ff7855b68991..a27938eed392 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4818,7 +4818,7 @@ def check_invariant_to_use_direct_linearize(f): self.assertEqual(ans1, ans2) def sin_of_sin(x): - return jnp.sin(jnp.sin(x)) + return lax.sin(lax.sin(x)) check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) From 355589f32b29ab1a2c59b58cef2ede80d4d3f642 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 21 Nov 2024 20:12:21 -0800 Subject: [PATCH 111/112] [sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here * Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path) * Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager. * Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them. * scan only allows `xs` where the 0th dim is full replicated i.e. None. PiperOrigin-RevId: 699014167 --- jax/_src/core.py | 8 +++- jax/_src/lax/control_flow/loops.py | 31 ++++++++++----- jax/_src/lax/lax.py | 14 +------ jax/_src/lax/slicing.py | 63 ++++++++++++++++++++++-------- jax/_src/pjit.py | 41 +++++++++++-------- jax/_src/sharding_impls.py | 5 +++ jax/_src/stages.py | 11 ++++-- tests/pjit_test.py | 42 +++++++++++++++++++- 8 files changed, 153 insertions(+), 62 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 2ad4264e9edd..86646faa980b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2263,16 +2263,20 @@ def _map_shaped_array( assert axis is None or aval.shape[axis] == size # TODO: Extend the named shape if axis is None: return aval + sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) + if config.sharding_in_types.value else None) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type) + weak_type=aval.weak_type, sharding=sharding) def _unmap_shaped_array( size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray ) -> ShapedArray: if axis is None: return aval elif type(axis) is int: + sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name)) + if config.sharding_in_types.value else None) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type) + weak_type=aval.weak_type, sharding=sharding) else: raise TypeError(axis) def _map_dshaped_array( diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d15917b8b1da..76132ccdc99a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -227,6 +227,11 @@ def scan(f, init, xs, length=None): msg.format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err + if (config.sharding_in_types.value and + not all(x.sharding.spec[0] is None for x in xs_flat)): + raise ValueError('0th dimension of all xs should be replicated. Got ' + f'{", ".join(str(x.sharding.spec) for x in xs_flat)}') + if length is not None: try: length = int(length) @@ -250,7 +255,8 @@ def scan(f, init, xs, length=None): if config.disable_jit.value: if length == 0: - raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.") + raise ValueError("zero-length scan is not supported in disable_jit() " + "mode because the output type is unknown.") carry = init ys = [] maybe_reversed = reversed if reverse else lambda x: x @@ -424,7 +430,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, num_trips, remainder = 0, length if unroll == 1: xss = xs_ - yss = _map(partial(_empty_array, (length,)), y_avals) + yss = _map(partial(_empty_array, (length,), None), y_avals) else: if remainder: if not reverse: @@ -432,7 +438,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, else: xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_)) xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] - yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals) + yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals) def cond_fun(while_carry): i, _, _ = while_carry @@ -477,8 +483,11 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) -def _empty_array(prefix, aval): - return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape)) +def _empty_array(prefix, length_spec, aval): + sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec)) + if config.sharding_in_types.value else None) + return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), + sharding=sharding) eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True @@ -486,11 +495,13 @@ def _stage_jaxpr(trace, *tracers, jaxpr): params = dict(call_jaxpr=jaxpr) return trace.default_process_primitive(core.closed_call_p, tracers, params) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr + @eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf -def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects +def _stage_jaxpr_abstract_eval(*_, jaxpr): + return jaxpr.out_avals, jaxpr.effects def _prepend_dim_to_aval(sz, aval): - return core.unmapped_aval(sz, core.no_axis_name, 0, aval) + return core.unmapped_aval(sz, None, 0, aval) def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): @@ -674,7 +685,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, extensive_res = _map(trace.new_instantiated_const, extensive_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) - ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval) + ys_avals = [core.unmapped_aval(length, None, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in itertools.chain(carry_avals, ys_avals)] @@ -1041,7 +1052,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Create residual variables. intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals) - ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a) + ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a) for a in ext_avals_mapped] newvar = core.gensym() intensive_res = _map(newvar, intensive_avals) @@ -1119,7 +1130,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, jaxpr.in_avals, [num_consts, num_carry]) carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry]) x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) - y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a) + y_avals = [core.unmapped_aval(length, None, 0, a) for a in y_avals_mapped] if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1099919a6474..1b84797d630e 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4513,18 +4513,8 @@ def _pad_sharding_rule(operand, padding_value, *, padding_config): # change this logic to `return operand.sharding` directly. out_shape = _pad_shape_rule(operand, padding_value, padding_config=padding_config) - mesh = operand.sharding.mesh - new_spec = [] - for op_sh, out_sh, op_spec in safe_zip( - operand.shape, out_shape, operand.sharding.spec): - if (op_sh != out_sh and op_spec is not None and - out_sh % slicing._get_sub_spec_size(mesh, op_spec) != 0): - raise NotImplementedError( - f"padding on sharded dims where out dim ({out_sh}) is not divisble by" - f" mesh axes ({slicing._get_sub_spec_size(mesh, op_spec)}) with spec" - f" ({op_spec}) is not implemented.") - new_spec.append(op_spec) - return NamedSharding(mesh, P(*new_spec)) + return slicing._get_sharding_for_varying_out_shape( + out_shape, operand, 'padding') def _pad_transpose(t, operand, padding_value, *, padding_config): diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 117c8b655152..c6c85ce4f6a3 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -42,7 +42,6 @@ _input_dtype, standard_primitive, ) -from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike, Shape @@ -1276,23 +1275,33 @@ def _get_sub_spec_size(mesh, sub_spec): return math.prod(mesh.shape[s] for s in sub_spec) return mesh.shape[sub_spec] -def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): - # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, - # change this logic to `return operand.sharding` directly. - out_shape = _slice_shape_rule(operand, start_indices=start_indices, - limit_indices=limit_indices, strides=strides) +def _get_sharding_for_varying_out_shape(out_shape, operand, name): + """Returns a sharding when out_shape may not be the same as operand shape""" mesh = operand.sharding.mesh - new_spec = [] for op_sh, out_sh, op_spec in safe_zip( operand.shape, out_shape, operand.sharding.spec): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): raise NotImplementedError( - f"slicing on sharded dims where out dim ({out_sh}) is not divisble by" + f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" f" ({op_spec}) is not implemented.") - new_spec.append(op_spec) - return NamedSharding(mesh, P(*new_spec)) + # TODO(yashkatariya): Returning operand.sharding as is may or may not move + # data. So think about how to avoid it which might include creating a new + # mesh? For example: + # mesh = {'x': 4} + # x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))` + # ys = lax.split(x, [4, 4]) # This will create outputs of shape (4,) + # According to the current logic, ys[0].sharding.spec == P('x') + # which involves data movement. + return operand.sharding + +def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _slice_shape_rule(operand, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + return _get_sharding_for_varying_out_shape(out_shape, operand, 'slicing') def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert ad.is_undefined_primal(operand) @@ -1367,8 +1376,7 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): mlir.register_lowering(slice_p, _slice_lower) -def _dynamic_slice_shape_rule( - operand, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if operand.ndim != len(start_indices): msg = ("dynamic_slice start_indices must have length equal to the number " @@ -1391,6 +1399,12 @@ def _dynamic_slice_shape_rule( f" got indices {start_indices}") return tuple(lax._merge_dyn_shape(slice_sizes, dyn)) +def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes): + out_shape = _dynamic_slice_shape_rule( + operand, *starts_and_dyn_sizes, slice_sizes=slice_sizes) + return _get_sharding_for_varying_out_shape(out_shape, operand, 'dynamic_slice') + + def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if any(i.dtype != start_indices[0].dtype or @@ -1494,7 +1508,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + sharding_rule=_dynamic_slice_sharding_rule) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule @@ -1508,7 +1523,10 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes): aval_out, = ctx.avals_out if dyn: aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn)) - return [mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)] + out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower) @@ -1539,6 +1557,14 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): f"scalars, got indices {start_indices}") return operand.shape +def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): + if operand.sharding != update.sharding: + raise TypeError( + "dynamic_update_slice update sharding must be equal to operand" + f" sharding, got update sharding {update.sharding} for operand sharding" + f" {operand.sharding}.") + return operand.sharding + def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): lax.check_same_dtypes("dynamic_update_slice", operand, update) if any(i.dtype != start_indices[0].dtype or @@ -1604,7 +1630,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice') + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -1613,8 +1639,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): def _dynamic_update_slice_lower(ctx, x, update, *start_indices): aval_out, = ctx.avals_out - return [mlir.dynamic_update_slice(ctx, aval_out, x, update, - start_indices=start_indices)] + out = mlir.dynamic_update_slice(ctx, aval_out, x, update, + start_indices=start_indices) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index aff956862753..4f16e0013f25 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -185,16 +185,19 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - if (core.trace_state_clean() and - not config.debug_key_reuse.value and - not config.data_dependent_tracing_fallback.value): - args_flat = map(core.full_lower, args_flat) - core.check_eval_args(args_flat) - out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) - else: - out_flat = pjit_p.bind(*args_flat, **p.params) - compiled = None - profiler = None + # TODO(yashkatariya): Maybe thread this into pjit params like resource_env + # and set the context manager down the stack? + with p.abstract_mesh: + if (core.trace_state_clean() and + not config.debug_key_reuse.value and + not config.data_dependent_tracing_fallback.value): + args_flat = map(core.full_lower, args_flat) + core.check_eval_args(args_flat) + out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) + else: + out_flat = pjit_p.bind(*args_flat, **p.params) + compiled = None + profiler = None except pxla.DeviceAssignmentMismatchError as e: fails, = e.args api_name = 'jit' if p.params['resource_env'] is None else 'pjit' @@ -330,9 +333,10 @@ def cache_miss(*args, **kwargs): if config.no_tracing.value: raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " "`jit`, but 'no_tracing' is set") - outs, out_flat, out_tree, args_flat, jaxpr, \ - attrs_tracked, executable, pgle_profiler = _python_pjit_helper( - fun, jit_info, *args, **kwargs) + + (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable, + pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) + maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, @@ -495,10 +499,10 @@ def trace(*args, **kwargs) -> stages.Traced: donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) lower_callable = partial(_resolve_and_lower, args_flat, **p.params, - pgle_profiler=None) + pgle_profiler=None) return stages.Traced( p.params['jaxpr'], args_info, p.params["name"], p.out_tree, - lower_callable, args_flat, p.arg_names, p.num_consts) + lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts) wrapped = _cpp_pjit(fun, jit_info) wrapped.lower = lower @@ -534,6 +538,7 @@ class PjitParams(NamedTuple): arg_names: tuple[str, ...] | None num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + abstract_mesh: AbstractMesh def _infer_params_impl( @@ -639,7 +644,9 @@ def _infer_params_impl( attr_token = _attr_token(flat_fun, in_type) - abstract_mesh = get_abstract_mesh(in_type) + abstract_mesh = ( + get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None + else mesh_lib.mesh_context.mesh) with abstract_mesh: jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, dbg, @@ -684,7 +691,7 @@ def _infer_params_impl( ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), donated_invars, dbg.arg_names if dbg else None, len(consts), - attrs_tracked), args_flat + attrs_tracked, abstract_mesh), args_flat def get_abstract_mesh(in_avals): diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index dc4171eec146..8abe58e52a74 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -363,6 +363,11 @@ def is_fully_replicated(self) -> bool: def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) + def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: + if not isinstance(spec, PartitionSpec): + spec = PartitionSpec(*spec) + return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind) + def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 92c680009c93..b6f3b63d3de4 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,6 +30,7 @@ """ from __future__ import annotations +import contextlib import functools from collections.abc import Sequence from dataclasses import dataclass @@ -716,13 +717,14 @@ class Traced(Stage): "_args_flat", "_arg_names", "_num_consts"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, args_flat=None, arg_names=None, - num_consts: int = 0): + lower_callable, abstract_mesh=contextlib.nullcontext(), + args_flat=None, arg_names=None, num_consts: int = 0): self.jaxpr = jaxpr self.args_info = args_info self.fun_name = fun_name self._out_tree = out_tree self._lower_callable = lower_callable + self._abstract_mesh = abstract_mesh self._args_flat = args_flat self._arg_names = arg_names self._num_consts = num_consts @@ -743,7 +745,10 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, self._lower_callable, lowering_platforms=lowering_platforms, lowering_parameters=_private_parameters) try: - lowering = new_callable() + # TODO(yashkatariya): Maybe thread this into pjit params like resource_env + # and set the context manager down the stack? + with self._abstract_mesh: + lowering = new_callable() except pxla.DeviceAssignmentMismatchError as e: fails, = e.args msg = pjit._device_assignment_mismatch_error( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index dd1415b680a4..293026b2b612 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4674,7 +4674,7 @@ def f(x): if config.use_shardy_partitioner.value: self.assertIn('sdy.sharding_constraint', lowered_text) else: - self.assertEqual(lowered_text.count('@Sharding'), 2) + self.assertEqual(lowered_text.count('@Sharding'), 3) @jax.jit def g(x): @@ -5244,6 +5244,7 @@ def test_shard_map_full_manual(self): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective) return x * y @jax.jit @@ -5268,6 +5269,7 @@ def test_shard_map_dot(self): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y return jax.lax.psum(z, axis_name='y') @@ -5426,6 +5428,44 @@ def g(x, y): out = jax.jit(jax.grad(g))(arr1, arr2) self.assertEqual(out.sharding, s) + def test_scan(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + carry = jax.device_put(np.arange(16.).reshape(2, 8), + NamedSharding(mesh, P(None, 'x'))) + arr = jax.device_put(np.arange(128.).reshape(8, 8, 2), + NamedSharding(mesh, P(None, 'x', 'y'))) + + @jax.jit + def f(carry, xs): + def g(carry, x): + self.assertEqual(carry.sharding.spec, P(None, 'x')) + self.assertEqual(x.sharding.spec, P('x', 'y')) + y = carry @ x + self.assertEqual(y.sharding.spec, P(None, 'y')) + z = jax.nn.relu(y) + self.assertEqual(z.sharding.spec, P(None, 'y')) + a = z @ x.T + self.assertEqual(a.sharding.spec, P(None, 'x')) + return a, y + return jax.lax.scan(g, carry, xs) + + activation, mean = f(carry, arr) + self.assertEqual(activation.sharding, NamedSharding(mesh, P(None, 'x'))) + self.assertEqual(mean.sharding, NamedSharding(mesh, P(None, None, 'y'))) + + f.lower(carry, arr).compile()(carry, arr) # doesn't crash + + def g(carry, arr): + out = f(carry, arr) + return jnp.sum(out[0]) + out = jax.jit(jax.grad(g, argnums=(0, 1)))(carry, arr) + self.assertEqual(out[0].sharding, carry.sharding) + self.assertEqual(out[1].sharding, arr.sharding) + + with self.assertRaisesRegex( + ValueError, "0th dimension of all xs should be replicated"): + f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 846697f761a5e6857ecea7fcadf02cb7dd5ff18e Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Fri, 22 Nov 2024 10:36:01 -0600 Subject: [PATCH 112/112] Longer timeout for doc render --- .github/workflows/ci-build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 0fd188098ee9..b3f683f89f78 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -139,7 +139,7 @@ jobs: documentation_render: name: Documentation - render documentation runs-on: ubuntu-latest - timeout-minutes: 10 + timeout-minutes: 20 strategy: matrix: python-version: ['3.10']