diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml
new file mode 100644
index 000000000000..a7cf645b50b3
--- /dev/null
+++ b/.github/workflows/bazel_gpu_rbe.yml
@@ -0,0 +1,39 @@
+name: CI - Bazel GPU tests (RBE)
+
+on:
+ workflow_dispatch:
+ inputs:
+ halt-for-connection:
+ description: 'Should this workflow run wait for a remote connection?'
+ type: choice
+ required: true
+ default: 'no'
+ options:
+ - 'yes'
+ - 'no'
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ run_tests:
+ if: github.event.repository.fork == false
+ strategy:
+ matrix:
+ runner: ["linux-x86-n2-16"]
+
+ runs-on: ${{ matrix.runner }}
+ container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
+
+ env:
+ JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
+
+ steps:
+ - uses: actions/checkout@v3
+ - name: Wait For Connection
+ uses: google-ml-infra/actions/ci_connection@main
+ with:
+ halt-dispatch-input: ${{ inputs.halt-for-connection }}
+ - name: Run Bazel GPU Tests with RBE
+      run: ./ci/run_bazel_test_gpu_rbe.sh
diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml
index 0fd188098ee9..b3f683f89f78 100644
--- a/.github/workflows/ci-build.yaml
+++ b/.github/workflows/ci-build.yaml
@@ -139,7 +139,7 @@ jobs:
documentation_render:
name: Documentation - render documentation
runs-on: ubuntu-latest
- timeout-minutes: 10
+ timeout-minutes: 20
strategy:
matrix:
python-version: ['3.10']
diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml
index a5fac5ebdbc3..4ac167bd37c1 100644
--- a/.github/workflows/cloud-tpu-ci-nightly.yml
+++ b/.github/workflows/cloud-tpu-ci-nightly.yml
@@ -13,7 +13,7 @@
name: CI - Cloud TPU (nightly)
on:
schedule:
- - cron: "0 14 * * *" # daily at 7am PST
+ - cron: "0 */2 * * *" # Run every 2 hours
workflow_dispatch: # allows triggering the workflow run manually
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
@@ -26,15 +26,18 @@ jobs:
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
- {type: "v3-8", cores: "4"},
- {type: "v4-8", cores: "4"},
- {type: "v5e-8", cores: "8"}
+ # {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
+ # {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
+ {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
+ python-version: ["3.10"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240722
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
- runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
+ PYTHON: python${{ matrix.python-version }}
+ runs-on: ${{ matrix.tpu.runner }}
+ container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
timeout-minutes: 120
defaults:
run:
@@ -46,37 +49,37 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install JAX test requirements
run: |
- pip install -U -r build/test-requirements.txt
- pip install -U -r build/collect-profile-requirements.txt
+ $PYTHON -m pip install -U -r build/test-requirements.txt
+ $PYTHON -m pip install -U -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
- pip uninstall -y jax jaxlib libtpu
+ $PYTHON -m pip uninstall -y jax jaxlib libtpu
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
- pip install .[tpu] \
+ $PYTHON -m pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
- pip install --pre libtpu \
+ $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+ $PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
- pip install requests
+ $PYTHON -m pip install requests
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+ $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
- pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
+ $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
- pip install requests
+ $PYTHON -m pip install requests
else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1
fi
- python3 -c 'import sys; print("python version:", sys.version)'
- python3 -c 'import jax; print("jax version:", jax.__version__)'
- python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
- strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
- python3 -c 'import jax; print("libtpu version:",
+ $PYTHON -c 'import sys; print("python version:", sys.version)'
+ $PYTHON -c 'import jax; print("jax version:", jax.__version__)'
+ $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
+ strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
+ $PYTHON -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
@@ -84,14 +87,14 @@ jobs:
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
- JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
+ JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run Pallas printing tests, which need to run with I/O capturing disabled.
- TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
+ TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
- python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
+ $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml
index 942034169e09..8f2029eb9191 100644
--- a/.github/workflows/jax-array-api.yml
+++ b/.github/workflows/jax-array-api.yml
@@ -28,7 +28,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
- ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻
+ ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10b1fc808970..be9aaebcd615 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.36
* Breaking Changes
+ * This release lands "stackless", an internal change to JAX's tracing
+ machinery. We made trace dispatch purely a function of context rather than a
+ function of both context and data. This let us delete a lot of machinery for
+ managing data-dependent tracing: levels, sublevels, `post_process_call`,
+ `new_base_main`, `custom_bind`, and so on. The change should only affect
+ users that use JAX internals.
+
+ If you do use JAX internals then you may need to
+ update your code (see
+ https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
+ for clues about how to do this). There might also be version skew
+ issues with JAX libraries that do this. If you find this change breaks your
+ non-JAX-internals-using code then try the
+ `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
+ you need help updating your code then please file a bug.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
@@ -43,6 +58,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
+ * {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
+ return NaN for negative integer inputs, to match the behavior of SciPy from
+ https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
* New Features
@@ -52,12 +70,22 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
+ * Added {func}`jax.numpy.put_along_axis`.
+ * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
+ ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
+ supported on GPU. See {jax-issue}`#24663` for more details.
* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}`#24843` for more details.
+* Deprecations
+ * `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
+ use `jax.Array` instead.
+ * `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
+ instead.
+
## jax 0.4.35 (Oct 22, 2024)
* Breaking Changes
diff --git a/README.md b/README.md
index 89fe51212638..b001a8ceeb15 100644
--- a/README.md
+++ b/README.md
@@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like.
Using `jit` puts constraints on the kind of Python control flow
the function can use; see
-the [Gotchas
-Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
+the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
for more.
### Auto-vectorization with `vmap`
@@ -349,7 +348,7 @@ Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are
- different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
+ different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
they're in the `jax.lax` package.
@@ -369,7 +368,7 @@ Some standouts:
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
np.float32)).dtype` is `float64` rather than `float32`.
1. Some transformations, like `jit`, [constrain how you can use Python control
- flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+ flow](https://jax.readthedocs.io/en/latest/control-flow.html).
You'll always get loud errors if something goes wrong. You might have to use
[`jit`'s `static_argnums`
parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
@@ -390,6 +389,7 @@ Some standouts:
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | yes | no | experimental | n/a | no | no |
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
+| Intel GPU | experimental | n/a | n/a | n/a | no | no |
### Instructions
@@ -401,6 +401,7 @@ Some standouts:
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
+| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |
See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
for information on alternative installation strategies. These include compiling
diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py
index d26801d8dfe5..d365a6facd90 100644
--- a/benchmarks/shape_poly_benchmark.py
+++ b/benchmarks/shape_poly_benchmark.py
@@ -17,7 +17,6 @@
import jax
from jax import core
-from jax._src.numpy import lax_numpy
from jax import export
jax.config.parse_flags_with_absl()
@@ -76,7 +75,7 @@ def inequalities_slice(state):
while state:
for _ in range(30):
a.scope._clear_caches()
- start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
+ start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
_ = 0 <= slice_size <= b
_ = start >= 0
_ = start + slice_size <= b
diff --git a/ci/run_bazel_test_gpu_rbe.sh b/ci/run_bazel_test_gpu_rbe.sh
new file mode 100755
index 000000000000..0c004c584300
--- /dev/null
+++ b/ci/run_bazel_test_gpu_rbe.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one
+# GPU apiece on RBE.
+#
+# -e: abort script if one command fails
+# -u: error if undefined variable used
+# -x: log all commands
+# -o history: record shell history
+# -o allexport: export all functions and variables to be available to subscripts
+set -exu -o history -o allexport
+
+# Source default JAXCI environment variables.
+source ci/envs/default.env
+
+# Clone XLA at HEAD if path to local XLA is not provided
+if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
+ export JAXCI_CLONE_MAIN_XLA=1
+fi
+
+# Set up the build environment.
+source "ci/utilities/setup_build_environment.sh"
+
+# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece).
+echo "Running RBE GPU tests..."
+
+bazel test --config=rbe_linux_x86_64_cuda \
+ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
+ --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
+ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
+ --test_output=errors \
+ --test_env=TF_CPP_MIN_LOG_LEVEL=0 \
+ --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
+ --test_tag_filters=-multiaccelerator \
+ --test_env=JAX_SKIP_SLOW_TESTS=true \
+ --action_env=JAX_ENABLE_X64=0 \
+ --color=yes \
+  //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md
index f4b61cbcf7dc..2163272e2542 100644
--- a/docs/Custom_Operation_for_GPUs.md
+++ b/docs/Custom_Operation_for_GPUs.md
@@ -623,16 +623,16 @@ be used with the custom_partitioning registration and for the
gradient. (And if you implement the interface to support vmat, it will
also be on the outer primitive).
-JAX custom_partitioning implementation are callbacks from XLA to Python during XLA sharding logic.
+JAX custom_partitioning implementations are callbacks from XLA to Python during XLA sharding logic.
XLA sharding goes in two phases: a sharding propagation phase and a partition phase.
+The propagation phase is when XLA plans the sharding to be created. It is the partition phase that creates the sharded graph.
+The propagation phase is when XLA plan the sharding to be created. It is the partition phase that creates the sharded graph.
For XLA to be able to shard our custom operations, it needs us to define 2 extra functions:
infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively.
The infer_sharding_from_operands() function must do what its name say: infer the output sharding from the input sharding.
The partition() function will do a few things:
-- tell which input sharding will be expected. XLA will reshad if needed.
+- tell which input sharding will be expected. XLA will reshard if needed.
- tell the final version of the output sharding.
- give a function that will create the new instruction from the sharded inputs.
diff --git a/docs/_static/jax-hero.svg b/docs/_static/jax-hero.svg
new file mode 100644
index 000000000000..04626f43eacd
--- /dev/null
+++ b/docs/_static/jax-hero.svg
@@ -0,0 +1,118 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/_static/style.css b/docs/_static/style.css
index 2c1dfcbcbf08..36b54b8432f0 100644
--- a/docs/_static/style.css
+++ b/docs/_static/style.css
@@ -1,34 +1,280 @@
@import url("theme.css");
+@import url('https://fonts.googleapis.com/css2?family=Google+Sans');
+
+/* Base LP sidebar modifications */
+body:has(.hero) .sidebar-toggle,
+body:has(.hero) .bd-sidebar-secondary {
+ display: none !important;
+}
+
+body:has(.hero) .search-button {
+ display: flex !important;
+}
+
+body:has(.hero) .primary-toggle {
+ display: inline-block !important;
+}
+
+body:has(.hero) .prev-next-footer {
+ display: none;
+}
+
+body:has(.hero) .bd-article-container {
+ max-width: unset !important;
+}
+
+body:has(.hero) .bd-page-width {
+ max-width: unset !important;
+}
+
+body:has(.hero) .bd-article {
+ display: flex;
+ flex-direction: column;
+ padding: 0;
+}
+
+body:has(.hero) .bd-container {
+ flex-direction: column;
+}
+
+@media (min-width: 960px) {
+ body:has(.hero) .bd-header-article {
+ justify-content: center;
+ }
+
+ body:has(.hero) .header-article-items,
+ body:has(.hero) .bd-article > section {
+ max-width: 65rem !important;
+ align-self: center;
+ width: -moz-available;
+ width: -webkit-fill-available;
+ width: fill-available;
+ }
+}
+
+/* Custom CSS */
:root {
--block-bg-opacity: .5;
}
+.bd-main .bd-content .bd-article-container .bd-article:has(.hero) {
+ padding: 0;
+}
+
+.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * {
+ padding-inline: 2rem !important;
+}
+
+@media (max-width: 768px) {
+ .bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * {
+ padding-inline: 1rem !important;
+ }
+}
+
+.bd-main .bd-content .bd-article-container .bd-article:has(.hero) h1 {
+ display: none;
+}
+
.wy-side-nav-search {
background-color: #fff;
}
+.getting-started,
+.user-guides,
.installation {
- background-color: rgba(78, 150, 253, var(--block-bg-opacity));
+ background: #3C4043;
+ color: white;
+ height: 170px;
+ border: none !important;
+ border-radius: 12px;
+}
+
+.getting-started:hover,
+.user-guides:hover,
+.installation:hover {
+ background: #AECBFA;
+ color: #202124;
+ transform: unset !important;
+}
+
+.getting-started .sd-card-body,
+.user-guides .sd-card-body,
+.installation .sd-card-body {
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ font: 500 24px 'Roboto', sans-serif;
+}
+
+.getting-started .sd-card-title,
+.user-guides .sd-card-title,
+.installation .sd-card-title {
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ gap: 12px;
+}
+
+.getting-started svg,
+.user-guides svg,
+.installation svg {
+ color: #8AB4F8;
+}
+
+.getting-started:hover svg,
+.user-guides:hover svg,
+.installation:hover svg {
+ color: #3C4043;
+}
+
+.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > .hero {
+ padding-inline: 2rem 0 !important;
}
-.getting-started {
- background-color: rgba(0, 169, 154, var(--block-bg-opacity));
+.hero {
+ display: grid;
+ grid: auto-flow / 1fr .6fr;
+ align-items: center;
+ background: rgb(32,33,36);
+ background: linear-gradient(90deg, rgba(32,33,36,1) 0%, rgba(39,45,56,1) 100%);
+ position: relative;
+ overflow: hidden;
+ border-radius: 24px;
}
-.user-guides {
- background-color: rgba(171, 0, 182, var(--block-bg-opacity));
+.hero > img {
+ position: absolute;
+ top: 0;
+ right: 0;
+ height: 100%;
+ background: transparent !important;
+}
+
+.hero-left {
+ padding-block: 24px;
+ display: flex;
+ flex-direction: column;
+}
+
+.hero-left img {
+ width: 100px;
+ height: auto;
+ position: relative;
+ margin-bottom: 16px;
+ background: transparent !important;
+}
+
+.hero-left h2 {
+ font: 500 32px 'Google Sans', 'Roboto', sans-serif;
+ color: white;
+ margin-top: 0;
+}
+
+.hero-left p {
+ font: 400 16px 'Roboto', sans-serif;
+ color: white;
+}
+
+@media (max-width: 1295px) {
+ .hero > img {
+ right: -75px;
+ }
+}
+
+@media (max-width: 750px) {
+ .hero {
+ grid: auto-flow / 1fr;
+ }
+
+ .hero-left {
+ padding-right: 2rem;
+ }
+
+ .hero > img {
+ display: none;
+ }
+}
+
+.product-offerings {
+ margin-block: 32px !important;
+}
+
+.product-offerings .sd-card-title {
+ font: 400 24px 'Google Sans', 'Roboto', sans-serif;
+}
+
+.color-cards {
+ background: #E8EAED;
+ color: #222832;
+ padding: 48px 12px 0 12px;
+ margin-bottom: 0 !important;
+ border-radius: 24px 24px 0 0;
+}
+
+.color-cards > div {
+ gap: 24px 0;
+}
+
+.color-cards + p {
+ background: #E8EAED;
+ padding: 24px 12px 48px 12px;
+ font-weight: 600;
+ color: #222832;
+ border-radius: 0 0 24px 24px;
+}
+
+.color-cards + p > a {
+ color: #222832;
+}
+
+.color-cards + p > a:hover,
+html[data-theme="dark"] .color-cards + p > a:hover {
+ color: #e89217;
+}
+
+html[data-theme="dark"] .color-cards,
+html[data-theme="dark"] .hero,
+html[data-theme="dark"] .color-cards + p,
+html[data-theme="dark"] .color-cards + p > a {
+ background: #202124;
+ color: white;
}
.ecosystem-grid {
font-size: smaller;
}
+.ecosystem-grid > div {
+ gap: 20px;
+}
+
+.ecosystem-grid .sd-col {
+ border: 1px solid #dadce0;
+ border-radius: 8px;
+ width: calc(50% - 10px);
+ padding: 16px;
+}
+
+.ecosystem-grid .sd-col > p {
+ display: flex;
+ flex-direction: column;
+ gap: 10px;
+}
+
+.ecosystem-grid .sd-col > p > svg {
+ color: #00897B;
+}
+
.ecosystem-grid ul {
list-style-type: none;
padding-inline-start: 0.5em;
}
+.ecosystem-grid a {
+ text-decoration: none;
+}
+
div.red-background pre {
background-color: rgba(244, 204, 204, var(--block-bg-opacity));
}
diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md
index 023dc8040954..c56e82c77450 100644
--- a/docs/advanced-autodiff.md
+++ b/docs/advanced-autodiff.md
@@ -350,7 +350,7 @@ This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \math
and so on.
-To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out.
+To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function with a wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out.
## How it's made: Two foundational autodiff functions
@@ -475,7 +475,7 @@ where we use `CT a` to denote the type for the cotangent space for `a`. In words
This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
-There's a cost, though: though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).
+There's a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).
For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/).
@@ -1762,7 +1762,6 @@ print(grad(app, 1)(lambda x: x ** 2, 4.))
Refer to `fixed_point` above for another usage example.
**You don't need to use** `nondiff_argnums` **with array-valued arguments**, such as, for example, ones with the integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments.
-s
## Next steps
diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb
index 8b418b16f878..e620967de4b7 100644
--- a/docs/autodidax.ipynb
+++ b/docs/autodidax.ipynb
@@ -2797,7 +2797,7 @@
"representing unknown outputs, we need avals, which we get from the abstract\n",
"eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n",
"`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n",
- "weakrefs.)\n",
+ "`weakref`s.)\n",
"\n",
"That `process_primitive` logic applies to most primitives, but `xla_call_p`\n",
"requires recursive treatment. So we special-case its rule in a\n",
diff --git a/docs/autodidax.md b/docs/autodidax.md
index 9e726e5ed82e..1c16db80f608 100644
--- a/docs/autodidax.md
+++ b/docs/autodidax.md
@@ -2195,7 +2195,7 @@ output. If instead any input is unknown then we instead stage out into a
representing unknown outputs, we need avals, which we get from the abstract
eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
-weakrefs.)
+`weakref`s.)
That `process_primitive` logic applies to most primitives, but `xla_call_p`
requires recursive treatment. So we special-case its rule in a
diff --git a/docs/autodidax.py b/docs/autodidax.py
index f57af2cd96f2..f74617f31416 100644
--- a/docs/autodidax.py
+++ b/docs/autodidax.py
@@ -2187,7 +2187,7 @@ def full_lower(self):
# representing unknown outputs, we need avals, which we get from the abstract
# eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
# `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
-# weakrefs.)
+# `weakref`s.)
#
# That `process_primitive` logic applies to most primitives, but `xla_call_p`
# requires recursive treatment. So we special-case its rule in a
diff --git a/docs/control-flow.md b/docs/control-flow.md
new file mode 100644
index 000000000000..7cb959f3e434
--- /dev/null
+++ b/docs/control-flow.md
@@ -0,0 +1,394 @@
+---
+jupytext:
+ formats: md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
++++ {"id": "rg4CpMZ8c3ri"}
+
+(control-flow)=
+# Control flow and logical operators with JIT
+
+
+
+When executing eagerly (outside of `jit`), JAX code works with Python control flow and logical operators just like Numpy code. Using control flow and logical operators with `jit` is more complicated.
+
+In a nutshell, Python control flow and logical operators are evaluated at JIT compile time, such that the compiled function represents a single path through the [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph) (logical operators affect the path via short-circuiting). If the path depends on the values of the inputs, the function (by default) cannot be JIT compiled. The path may depend on the shape or dtype of the inputs, and the function is re-compiled every time it is called on an input with a new shape or dtype.
+
+```{code-cell}
+from jax import grad, jit
+import jax.numpy as jnp
+```
+
+For example, this works:
+
+```{code-cell}
+:id: OZ_BJX0CplNC
+:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c
+
+@jit
+def f(x):
+ for i in range(3):
+ x = 2 * x
+ return x
+
+print(f(3))
+```
+
++++ {"id": "22RzeJ4QqAuX"}
+
+So does this:
+
+```{code-cell}
+:id: pinVnmRWp6w6
+:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422
+
+@jit
+def g(x):
+ y = 0.
+ for i in range(x.shape[0]):
+ y = y + x[i]
+ return y
+
+print(g(jnp.array([1., 2., 3.])))
+```
+
++++ {"id": "TStltU2dqf8A"}
+
+But this doesn't, at least by default:
+
+```{code-cell}
+:id: 9z38AIKclRNM
+:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac
+:tags: [raises-exception]
+
+@jit
+def f(x):
+ if x < 3:
+ return 3. * x ** 2
+ else:
+ return -4 * x
+
+# This will fail!
+f(2)
+```
+
+Neither does this:
+
+```{code-cell}
+:tags: [raises-exception]
+
+@jit
+def g(x):
+ return (x > 0) and (x < 3)
+
+# This will fail!
+g(2)
+```
+
++++ {"id": "pIbr4TVPqtDN"}
+
+__What gives!?__
+
+When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.
+
+For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
+
+To get a view of your Python code that is valid for many different argument values, JAX traces it with the `ShapedArray` abstraction as input, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
+
+But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
+
+The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnames` (or `static_argnums`) argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:
+
+```{code-cell}
+:id: -Tzp0H7Bt1Sn
+:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a
+
+def f(x):
+ if x < 3:
+ return 3. * x ** 2
+ else:
+ return -4 * x
+
+f = jit(f, static_argnames='x')
+
+print(f(2.))
+```
+
++++ {"id": "MHm1hIQAvBVs"}
+
+Here's another example, this time involving a loop:
+
+```{code-cell}
+:id: iwY86_JKvD6b
+:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937
+
+def f(x, n):
+ y = 0.
+ for i in range(n):
+ y = y + x[i]
+ return y
+
+f = jit(f, static_argnames='n')
+
+f(jnp.array([2., 3., 4.]), 2)
+```
+
++++ {"id": "nSPTOX8DvOeO"}
+
+In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation.
+
++++ {"id": "wWdg8LTYwCW3"}
+
+⚠️ **functions with argument-__value__ dependent shapes**
+
+These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.
+
+```{code-cell}
+:id: Tqe9uLmUI_Gv
+:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883
+
+def example_fun(length, val):
+ return jnp.ones((length,)) * val
+# un-jit'd works fine
+print(example_fun(5, 4))
+```
+
+```{code-cell}
+:id: fOlR54XRgHpd
+:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71
+:tags: [raises-exception]
+
+bad_example_jit = jit(example_fun)
+# this will fail:
+bad_example_jit(10, 4)
+```
+
+```{code-cell}
+:id: kH0lOD4GgFyI
+:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade
+
+# static_argnames tells JAX to recompile on changes at these argument positions:
+good_example_jit = jit(example_fun, static_argnames='length')
+# first compile
+print(good_example_jit(10, 4))
+# recompiles
+print(good_example_jit(5, 4))
+```
+
++++ {"id": "MStx_r2oKxpp"}
+
+`static_argnames` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!
+
+Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
+
+```{code-cell}
+:id: m2ABpRd8K094
+:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3
+
+@jit
+def f(x):
+ print(x)
+ y = 2 * x
+ print(y)
+ return y
+f(2)
+```
+
++++ {"id": "uCDcWG4MnVn-"}
+
+## Structured control flow primitives
+
+There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:
+
+ - `lax.cond` _differentiable_
+ - `lax.while_loop` __fwd-mode-differentiable__
+ - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.
+ - `lax.scan` _differentiable_
+
++++ {"id": "Sd9xrLMXeK3A"}
+
+### `cond`
+python equivalent:
+
+```python
+def cond(pred, true_fun, false_fun, operand):
+ if pred:
+ return true_fun(operand)
+ else:
+ return false_fun(operand)
+```
+
+```{code-cell}
+:id: SGxz9JOWeiyH
+:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3
+
+from jax import lax
+
+operand = jnp.array([0.])
+lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
+# --> array([1.], dtype=float32)
+lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
+# --> array([-1.], dtype=float32)
+```
+
++++ {"id": "lIYdn1woOS1n"}
+
+`jax.lax` provides two other functions that allow branching on dynamic predicates:
+
+- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is
+ like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays
+ rather than as functions.
+- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is
+ like `lax.cond`, but allows switching between any number of callable choices.
+
+In addition, `jax.numpy` provides several numpy-style interfaces to these functions:
+
+- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with
+ three arguments is the numpy-style wrapper of `lax.select`.
+- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)
+ is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.
+- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has
+ an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather
+ than as functions. It is implemented in terms of multiple calls to `lax.select`.
+
++++ {"id": "xkOFAw24eOMg"}
+
+### `while_loop`
+
+python equivalent:
+```
+def while_loop(cond_fun, body_fun, init_val):
+ val = init_val
+ while cond_fun(val):
+ val = body_fun(val)
+ return val
+```
+
+```{code-cell}
+:id: jM-D39a-c436
+:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
+
+init_val = 0
+cond_fun = lambda x: x < 10
+body_fun = lambda x: x+1
+lax.while_loop(cond_fun, body_fun, init_val)
+# --> array(10, dtype=int32)
+```
+
++++ {"id": "apo3n3HAeQY_"}
+
+### `fori_loop`
+python equivalent:
+```
+def fori_loop(start, stop, body_fun, init_val):
+ val = init_val
+ for i in range(start, stop):
+ val = body_fun(i, val)
+ return val
+```
+
+```{code-cell}
+:id: dt3tUpOmeR8u
+:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380
+
+init_val = 0
+start = 0
+stop = 10
+body_fun = lambda i,x: x+i
+lax.fori_loop(start, stop, body_fun, init_val)
+# --> array(45, dtype=int32)
+```
+
++++ {"id": "SipXS5qiqk8e"}
+
+### Summary
+
+$$
+\begin{array} {r|rr}
+\hline \
+\textrm{construct}
+& \textrm{jit}
+& \textrm{grad} \\
+\hline \
+\textrm{if} & ❌ & ✔ \\
+\textrm{for} & ✔* & ✔\\
+\textrm{while} & ✔* & ✔\\
+\textrm{lax.cond} & ✔ & ✔\\
+\textrm{lax.while\_loop} & ✔ & \textrm{fwd}\\
+\textrm{lax.fori\_loop} & ✔ & \textrm{fwd}\\
+\textrm{lax.scan} & ✔ & ✔\\
+\hline
+\end{array}
+$$
+
+
+
+$\ast$ = argument-value-independent loop condition — unrolls the loop
+
+
+
+## Logical operators
+
+`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their NumPy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`.
+
+For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar.
+
+```{code-cell}
+def python_check_positive_even(x):
+ is_even = x % 2 == 0
+ # `and` short-circuits, so when `is_even` is `False`, `x > 0` is not evaluated.
+ return is_even and (x > 0)
+
+@jit
+def jax_check_positive_even(x):
+ is_even = x % 2 == 0
+ # `logical_and` does not short circuit, so `x > 0` is always evaluated.
+ return jnp.logical_and(is_even, x > 0)
+
+print(python_check_positive_even(24))
+print(jax_check_positive_even(24))
+```
+
+When the JAX version with `logical_and` is applied to an array, it returns elementwise values.
+
+```{code-cell}
+x = jnp.array([-1, 2, 5])
+print(jax_check_positive_even(x))
+```
+
+Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior.
+
+```{code-cell}
+:tags: [raises-exception]
+
+print(python_check_positive_even(x))
+```
+
++++ {"id": "izLTvT24dAq0"}
+
+## Python control flow + autodiff
+
+Remember that the above constraints on control flow and logical operators are relevant only with `jit`. If you just want to apply `grad` to your python functions, without `jit`, you can use regular Python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager).
+
+```{code-cell}
+:id: aAx0T3F8lLtu
+:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52
+
+def f(x):
+ if x < 3:
+ return 3. * x ** 2
+ else:
+ return -4 * x
+
+print(grad(f)(2.)) # ok!
+print(grad(f)(4.)) # ok!
+```
diff --git a/docs/deprecation.md b/docs/deprecation.md
index 385d31271421..603a027f5efc 100644
--- a/docs/deprecation.md
+++ b/docs/deprecation.md
@@ -18,6 +18,7 @@ This means we support at least:
* **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**.
* **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**.
* **Python 3.12** was released October 2023, and will be supported in new JAX releases at least until **July 2027**.
+ * **Python 3.13** was released October 2024, and will be supported in new JAX releases at least until **July 2028**.
* All NumPy feature releases in the 24 months prior to each JAX release. For example:
@@ -25,6 +26,7 @@ This means we support at least:
* **NumPy 1.25** was released June 2023, and will be supported in new JAX releases at least until **June 2025**
* **NumPy 1.26** was released September 2023, and will be supported in new JAX releases at least until **September 2025**
* **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026**
+ * **NumPy 2.1** was released August 2024, and will be supported in new JAX releases at least until **August 2026**
* All SciPy feature releases in the 24 months prior to each JAX release. For example:
@@ -32,6 +34,7 @@ This means we support at least:
* **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**.
* **Scipy 1.12** was released January 2024, and will be supported in new JAX releases at least until **January 2026**.
* **Scipy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**.
+ * **Scipy 1.14** was released June 2024, and will be supported in new JAX releases at least until **June 2026**.
JAX releases may support older versions of Python, NumPy, and SciPy than strictly required
by this policy, but support for older versions may be dropped at any time beyond the listed
diff --git a/docs/faq.rst b/docs/faq.rst
index 1d2bb204f24c..44267f6f5f7d 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of
Python control flow such as ``for`` loops. For a handful of loop iterations,
Python is OK, but if you need *many* loop iterations, you should rewrite your
code to make use of JAX's
-`structured control flow primitives `_
+`structured control flow primitives `_
(such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can
still use ``jit`` decorated functions *inside* the loop).
diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb
index f1a699b5c56c..72a2a6914fc0 100644
--- a/docs/ffi.ipynb
+++ b/docs/ffi.ipynb
@@ -139,8 +139,8 @@
"}\n",
"\n",
"// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare\n",
- "// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`\n",
- "// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.\n",
+ "// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`\n",
+ "// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.\n",
"XLA_FFI_DEFINE_HANDLER_SYMBOL(\n",
" RmsNorm, RmsNormImpl,\n",
" ffi::Ffi::Bind()\n",
diff --git a/docs/ffi.md b/docs/ffi.md
index dbe901237ed4..96b627675004 100644
--- a/docs/ffi.md
+++ b/docs/ffi.md
@@ -134,8 +134,8 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x,
}
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
-// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`
-// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.
+// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
+// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
diff --git a/docs/hero.html b/docs/hero.html
new file mode 100644
index 000000000000..a2ee3b8e206f
--- /dev/null
+++ b/docs/hero.html
@@ -0,0 +1,8 @@
+
+
+
+
High performance array computing
+
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
+
+
+
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 5f3bce5cf7da..ba8ebcbdd128 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,10 +1,22 @@
JAX: High performance array computing
=====================================
-JAX is a Python library for accelerator-oriented array computation and program transformation,
-designed for high-performance numerical computing and large-scale machine learning.
+.. raw:: html
+
+
+
+
+.. raw:: html
+ :file: hero.html
.. grid:: 3
+ :class-container: product-offerings
:margin: 0
:padding: 0
:gutter: 0
@@ -31,6 +43,7 @@ designed for high-performance numerical computing and large-scale machine learni
The same code executes on multiple backends, including CPU, GPU, & TPU
.. grid:: 3
+ :class-container: color-cards
.. grid-item-card:: :material-regular:`laptop_chromebook;2em` Installation
:columns: 12 6 6 4
@@ -59,7 +72,7 @@ JAX itself is narrowly-scoped and focuses on efficient array operations & progra
transformations. Built around JAX is an evolving ecosystem of machine learning and
numerical computing tools; the following is just a small sample of what is out there:
-.. grid:: 4
+.. grid:: 2
:class-container: ecosystem-grid
.. grid-item:: :material-outlined:`hub;2em` **Neural networks**
diff --git a/docs/installation.md b/docs/installation.md
index b7a56c48ec1f..78a6a5a5a444 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -35,6 +35,7 @@ The table below shows all supported platforms and installation options. Check if
| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | {ref}`experimental ` | no | {ref}`experimental ` | n/a | no | no |
| Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a |
+| Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no |
(install-cpu)=
@@ -230,6 +231,17 @@ JAX has experimental ROCm support. There are two ways to install JAX:
* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or
* Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_).
+(install-intel-gpu)=
+## Intel GPU
+
+Intel provides an experimental OneAPI plugin: intel-extension-for-openxla for Intel GPU hardware. For more details and installation instructions, refer to one of the following two methods:
+1. Pip installation: [JAX acceleration on Intel GPU](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md).
+2. Using [Intel's XLA Docker container](https://hub.docker.com/r/intel/intel-optimized-xla).
+
+Please report any issues related to:
+* JAX: [JAX issue tracker](https://github.com/jax-ml/jax/issues).
+* Intel's OpenXLA plugin: [Intel-extension-for-openxla issue tracker](https://github.com/intel/intel-extension-for-openxla/issues).
+
## Conda (community-supported)
### Conda installation
diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst
index 3922c92d98de..30553a360155 100644
--- a/docs/jax.numpy.rst
+++ b/docs/jax.numpy.rst
@@ -337,6 +337,7 @@ namespace; they are listed below.
promote_types
ptp
put
+ put_along_axis
quantile
r_
rad2deg
diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md
index 51322fda9476..5e5be308068a 100644
--- a/docs/jit-compilation.md
+++ b/docs/jit-compilation.md
@@ -170,7 +170,7 @@ jax.jit(g)(10, 20) # Raises an error
The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values.
Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as
`shape` or `dtype`, and not via their values.
-For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+For more detail on the interaction between Python control flow and JAX, see {ref}`control-flow`.
One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical.
In that case, you can consider JIT-compiling only part of the function.
diff --git a/docs/key-concepts.md b/docs/key-concepts.md
index daab2c9fdde4..91f0c953462e 100644
--- a/docs/key-concepts.md
+++ b/docs/key-concepts.md
@@ -189,3 +189,43 @@ tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the le
in a tree.
You can learn more in the {ref}`working-with-pytrees` tutorial.
+
+(key-concepts-prngs)=
+## Pseudorandom numbers
+
+Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`:
+
+```{code-cell}
+from jax import random
+
+key = random.key(43)
+print(key)
+```
+
+The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {func}`jax.random` functions.
+Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.
+
+```{code-cell}
+print(random.normal(key))
+print(random.normal(key))
+```
+
+**The rule of thumb is: never reuse keys (unless you want identical outputs).**
+
+In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function:
+
+```{code-cell}
+for i in range(3):
+ new_key, subkey = random.split(key)
+ del key # The old key is consumed by split() -- we must never use it again.
+
+ val = random.normal(subkey)
+ del subkey # The subkey is consumed by normal().
+
+ print(f"draw {i}: {val}")
+ key = new_key # new_key is safe to use in the next iteration.
+```
+
+Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys.
+
+For more on pseudo random numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial.
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
index 71bd4527644a..02077d2a6b00 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -34,7 +34,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
- "from jax import grad, jit\n",
+ "from jax import jit\n",
"from jax import lax\n",
"from jax import random\n",
"import jax\n",
@@ -865,920 +865,21 @@
"id": "MUycRNh6e50W"
},
"source": [
- "## 🔪 Random numbers"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "O8vvaVt3MRG2"
- },
- "source": [
- "> _If all scientific papers whose results are in doubt because of bad\n",
- "> `rand()`s were to disappear from library shelves, there would be a\n",
- "> gap on each shelf about as big as your fist._ - Numerical Recipes"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Qikt9pPW9L5K"
- },
- "source": [
- "### RNGs and state\n",
- "You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {
- "id": "rr9FeP41fynt",
- "outputId": "df0ceb15-96ec-4a78-e327-c77f7ea3a745"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0.2726690048900553\n",
- "0.6304191979771206\n",
- "0.6933648856441533\n"
- ]
- }
- ],
- "source": [
- "print(np.random.random())\n",
- "print(np.random.random())\n",
- "print(np.random.random())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ORMVVGZJgSVi"
- },
- "source": [
- "Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {
- "id": "7Pyp2ajzfPO2"
- },
- "outputs": [],
- "source": [
- "np.random.seed(0)\n",
- "rng_state = np.random.get_state()\n",
- "# print(rng_state)\n",
- "# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
- "# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n",
- "# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "aJIxHVXCiM6m"
- },
- "source": [
- "This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "id": "GAHaDCYafpAF"
- },
- "outputs": [],
- "source": [
- "_ = np.random.uniform()\n",
- "rng_state = np.random.get_state()\n",
- "#print(rng_state)\n",
- "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
- "# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
- "\n",
- "# Let's exhaust the entropy in this PRNG statevector\n",
- "for i in range(311):\n",
- " _ = np.random.uniform()\n",
- "rng_state = np.random.get_state()\n",
- "#print(rng_state)\n",
- "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
- "# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
- "\n",
- "# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
- "_ = np.random.uniform()\n",
- "rng_state = np.random.get_state()\n",
- "# print(rng_state)\n",
- "# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n",
- "# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "N_mWnleNogps"
- },
- "source": [
- "The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n",
- "\n",
- "The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Uvq7nV-j4vKK"
- },
- "source": [
- "### JAX PRNG"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "COjzGBpO4tzL"
- },
- "source": [
- "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
- "\n",
- "The random state is described by a special array element that we call a __key__:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {
- "id": "yPHE7KTWgAWs",
- "outputId": "ae8af0ee-f19e-474e-81b6-45e894eb2fc3"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([0, 0], dtype=uint32)"
- ]
- },
- "execution_count": 23,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "key = random.key(0)\n",
- "key"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XjYyWYNfq0hW"
- },
- "source": [
- "JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n",
+ "## 🔪 Random numbers\n",
"\n",
- "Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "id": "7zUdQMynoE5e",
- "outputId": "23a86b72-dfb9-410a-8e68-22b48dc10805"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[-0.20584226]\n",
- "[0 0]\n",
- "[-0.20584226]\n",
- "[0 0]\n"
- ]
- }
- ],
- "source": [
- "print(random.normal(key, shape=(1,)))\n",
- "print(key)\n",
- "# No no no!\n",
- "print(random.normal(key, shape=(1,)))\n",
- "print(key)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hQN9van8rJgd"
- },
- "source": [
- "Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "id": "ASj0_rSzqgGh",
- "outputId": "2f13f249-85d1-47bb-d503-823eca6961aa"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "old key [0 0]\n",
- " \\---SPLIT --> new key [4146024105 967050713]\n",
- " \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n"
- ]
- }
- ],
- "source": [
- "print(\"old key\", key)\n",
- "key, subkey = random.split(key)\n",
- "normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
- "print(r\" \\---SPLIT --> new key \", key)\n",
- "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tqtFVE4MthO3"
- },
- "source": [
- "We propagate the __key__ and make new __subkeys__ whenever we need a new random number:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "id": "jbC34XLor2Ek",
- "outputId": "4059a2e2-0205-40bc-ad55-17709d538871"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "old key [4146024105 967050713]\n",
- " \\---SPLIT --> new key [2384771982 3928867769]\n",
- " \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n"
- ]
- }
- ],
- "source": [
- "print(\"old key\", key)\n",
- "key, subkey = random.split(key)\n",
- "normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
- "print(r\" \\---SPLIT --> new key \", key)\n",
- "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0KLYUluz3lN3"
- },
- "source": [
- "We can generate more than one __subkey__ at a time:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {
- "id": "lEi08PJ4tfkX",
- "outputId": "1f280560-155d-4c04-98e8-c41d72ee5b01"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[-0.37533438]\n",
- "[0.98645043]\n",
- "[0.14553197]\n"
- ]
- }
- ],
- "source": [
- "key, *subkeys = random.split(key, 4)\n",
- "for subkey in subkeys:\n",
- " print(random.normal(subkey, shape=(1,)))"
+ "JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial."
]
},
{
"cell_type": "markdown",
+ "id": "1dc0e6b2",
"metadata": {
"id": "rg4CpMZ8c3ri"
},
"source": [
- "## 🔪 Control flow"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "izLTvT24dAq0"
- },
- "source": [
- "### ✔ Python control_flow + autodiff ✔\n",
- "\n",
- "If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {
- "id": "aAx0T3F8lLtu",
- "outputId": "383b7bfa-1634-4d23-8497-49cb9452ca52"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "12.0\n",
- "-4.0\n"
- ]
- }
- ],
- "source": [
- "def f(x):\n",
- " if x < 3:\n",
- " return 3. * x ** 2\n",
- " else:\n",
- " return -4 * x\n",
- "\n",
- "print(grad(f)(2.)) # ok!\n",
- "print(grad(f)(4.)) # ok!"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hIfPT7WMmZ2H"
- },
- "source": [
- "### Python control flow + JIT\n",
- "\n",
- "Using control flow with `jit` is more complicated, and by default it has more constraints.\n",
- "\n",
- "This works:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {
- "id": "OZ_BJX0CplNC",
- "outputId": "60c902a2-eba1-49d7-c8c8-2f68616d660c"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "24\n"
- ]
- }
- ],
- "source": [
- "@jit\n",
- "def f(x):\n",
- " for i in range(3):\n",
- " x = 2 * x\n",
- " return x\n",
- "\n",
- "print(f(3))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "22RzeJ4QqAuX"
- },
- "source": [
- "So does this:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {
- "id": "pinVnmRWp6w6",
- "outputId": "25e06cf2-474f-4782-af7c-4f5514b64422"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "6.0\n"
- ]
- }
- ],
- "source": [
- "@jit\n",
- "def g(x):\n",
- " y = 0.\n",
- " for i in range(x.shape[0]):\n",
- " y = y + x[i]\n",
- " return y\n",
- "\n",
- "print(g(jnp.array([1., 2., 3.])))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TStltU2dqf8A"
- },
- "source": [
- "But this doesn't, at least by default:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {
- "id": "9z38AIKclRNM",
- "outputId": "38dd2075-92fc-4b81-fee0-b9dff8da1fac",
- "tags": [
- "raises-exception"
- ]
- },
- "outputs": [
- {
- "ename": "ConcretizationTypeError",
- "evalue": "ignored",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
- ]
- }
- ],
- "source": [
- "@jit\n",
- "def f(x):\n",
- " if x < 3:\n",
- " return 3. * x ** 2\n",
- " else:\n",
- " return -4 * x\n",
- "\n",
- "# This will fail!\n",
- "f(2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pIbr4TVPqtDN"
- },
- "source": [
- "__What gives!?__\n",
- "\n",
- "When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n",
- "\n",
- "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
- "\n",
- "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
- "\n",
- "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
- "\n",
- "But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
- "\n",
- "The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {
- "id": "-Tzp0H7Bt1Sn",
- "outputId": "f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "12.0\n"
- ]
- }
- ],
- "source": [
- "def f(x):\n",
- " if x < 3:\n",
- " return 3. * x ** 2\n",
- " else:\n",
- " return -4 * x\n",
- "\n",
- "f = jit(f, static_argnums=(0,))\n",
- "\n",
- "print(f(2.))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "MHm1hIQAvBVs"
- },
- "source": [
- "Here's another example, this time involving a loop:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {
- "id": "iwY86_JKvD6b",
- "outputId": "48f9b51f-bd32-466f-eac1-cd23444ce937"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array(5., dtype=float32)"
- ]
- },
- "execution_count": 33,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "def f(x, n):\n",
- " y = 0.\n",
- " for i in range(n):\n",
- " y = y + x[i]\n",
- " return y\n",
- "\n",
- "f = jit(f, static_argnums=(1,))\n",
- "\n",
- "f(jnp.array([2., 3., 4.]), 2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nSPTOX8DvOeO"
- },
- "source": [
- "In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wWdg8LTYwCW3"
- },
- "source": [
- "️⚠️ **functions with argument-__value__ dependent shapes**\n",
- "\n",
- "These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {
- "id": "Tqe9uLmUI_Gv",
- "outputId": "989be121-dfce-4bb3-c78e-a10829c5f883"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[4. 4. 4. 4. 4.]\n"
- ]
- }
- ],
- "source": [
- "def example_fun(length, val):\n",
- " return jnp.ones((length,)) * val\n",
- "# un-jit'd works fine\n",
- "print(example_fun(5, 4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {
- "id": "fOlR54XRgHpd",
- "outputId": "cf31d798-a4ce-4069-8e3e-8f9631ff4b71",
- "tags": [
- "raises-exception"
- ]
- },
- "outputs": [
- {
- "ename": "TypeError",
- "evalue": "ignored",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Shapes must be 1D sequences of concrete values of integer type, got (Tracedwith,).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n"
- ]
- }
- ],
- "source": [
- "bad_example_jit = jit(example_fun)\n",
- "# this will fail:\n",
- "bad_example_jit(10, 4)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 36,
- "metadata": {
- "id": "kH0lOD4GgFyI",
- "outputId": "d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n",
- "[4. 4. 4. 4. 4.]\n"
- ]
- }
- ],
- "source": [
- "# static_argnums tells JAX to recompile on changes at these argument positions:\n",
- "good_example_jit = jit(example_fun, static_argnums=(0,))\n",
- "# first compile\n",
- "print(good_example_jit(10, 4))\n",
- "# recompiles\n",
- "print(good_example_jit(5, 4))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "MStx_r2oKxpp"
- },
- "source": [
- "`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!\n",
- "\n",
- "Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {
- "id": "m2ABpRd8K094",
- "outputId": "4f7ebe17-ade4-4e18-bd8c-4b24087c33c3"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Tracedwith\n",
- "Tracedwith\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "Array(4, dtype=int32, weak_type=True)"
- ]
- },
- "execution_count": 37,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "@jit\n",
- "def f(x):\n",
- " print(x)\n",
- " y = 2 * x\n",
- " print(y)\n",
- " return y\n",
- "f(2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "uCDcWG4MnVn-"
- },
- "source": [
- "### Structured control flow primitives\n",
- "\n",
- "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
- "\n",
- " - `lax.cond` _differentiable_\n",
- " - `lax.while_loop` __fwd-mode-differentiable__\n",
- " - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.\n",
- " - `lax.scan` _differentiable_"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Sd9xrLMXeK3A"
- },
- "source": [
- "#### `cond`\n",
- "python equivalent:\n",
- "\n",
- "```python\n",
- "def cond(pred, true_fun, false_fun, operand):\n",
- " if pred:\n",
- " return true_fun(operand)\n",
- " else:\n",
- " return false_fun(operand)\n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 38,
- "metadata": {
- "id": "SGxz9JOWeiyH",
- "outputId": "942a8d0e-5ff6-4702-c499-b3941f529ca3"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([-1.], dtype=float32)"
- ]
- },
- "execution_count": 38,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from jax import lax\n",
- "\n",
- "operand = jnp.array([0.])\n",
- "lax.cond(True, lambda x: x+1, lambda x: x-1, operand)\n",
- "# --> array([1.], dtype=float32)\n",
- "lax.cond(False, lambda x: x+1, lambda x: x-1, operand)\n",
- "# --> array([-1.], dtype=float32)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e6622244",
- "metadata": {
- "id": "lIYdn1woOS1n"
- },
- "source": [
- "`jax.lax` provides two other functions that allow branching on dynamic predicates:\n",
- "\n",
- "- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is\n",
- " like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays\n",
- " rather than as functions.\n",
- "- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is\n",
- " like `lax.cond`, but allows switching between any number of callable choices.\n",
- "\n",
- "In addition, `jax.numpy` provides several numpy-style interfaces to these functions:\n",
- "\n",
- "- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with\n",
- " three arguments is the numpy-style wrapper of `lax.select`.\n",
- "- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)\n",
- " is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.\n",
- "- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has\n",
- " an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather\n",
- " than as functions. It is implemented in terms of multiple calls to `lax.select`."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "xkOFAw24eOMg"
- },
- "source": [
- "#### `while_loop`\n",
- "\n",
- "python equivalent:\n",
- "```\n",
- "def while_loop(cond_fun, body_fun, init_val):\n",
- " val = init_val\n",
- " while cond_fun(val):\n",
- " val = body_fun(val)\n",
- " return val\n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 39,
- "metadata": {
- "id": "jM-D39a-c436",
- "outputId": "552fe42f-4d32-4e25-c8c2-b951160a3f4e"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array(10, dtype=int32, weak_type=True)"
- ]
- },
- "execution_count": 39,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "init_val = 0\n",
- "cond_fun = lambda x: x < 10\n",
- "body_fun = lambda x: x+1\n",
- "lax.while_loop(cond_fun, body_fun, init_val)\n",
- "# --> array(10, dtype=int32)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "apo3n3HAeQY_"
- },
- "source": [
- "#### `fori_loop`\n",
- "python equivalent:\n",
- "```\n",
- "def fori_loop(start, stop, body_fun, init_val):\n",
- " val = init_val\n",
- " for i in range(start, stop):\n",
- " val = body_fun(i, val)\n",
- " return val\n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 40,
- "metadata": {
- "id": "dt3tUpOmeR8u",
- "outputId": "7819ca7c-1433-4d85-b542-f6159b0e8380"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array(45, dtype=int32, weak_type=True)"
- ]
- },
- "execution_count": 40,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "init_val = 0\n",
- "start = 0\n",
- "stop = 10\n",
- "body_fun = lambda i,x: x+i\n",
- "lax.fori_loop(start, stop, body_fun, init_val)\n",
- "# --> array(45, dtype=int32)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "SipXS5qiqk8e"
- },
- "source": [
- "#### Summary\n",
- "\n",
- "$$\n",
- "\\begin{array} {r|rr}\n",
- "\\hline \\\n",
- "\\textrm{construct}\n",
- "& \\textrm{jit}\n",
- "& \\textrm{grad} \\\\\n",
- "\\hline \\\n",
- "\\textrm{if} & ❌ & ✔ \\\\\n",
- "\\textrm{for} & ✔* & ✔\\\\\n",
- "\\textrm{while} & ✔* & ✔\\\\\n",
- "\\textrm{lax.cond} & ✔ & ✔\\\\\n",
- "\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n",
- "\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n",
- "\\textrm{lax.scan} & ✔ & ✔\\\\\n",
- "\\hline\n",
- "\\end{array}\n",
- "$$\n",
- "\n",
- "\n",
- "\n",
- "$\\ast$ = argument-value -independent loop condition - unrolls the loop\n",
+ "## 🔪 Control flow\n",
"\n",
- " "
+ "Moved to {ref}`control-flow`."
]
},
{
@@ -2209,6 +1310,9 @@
" ```\n",
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
"\n",
+ "## 🔪 Sharp bits covered in tutorials\n",
+ "- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.\n",
+ "- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions.\n",
"\n",
"## Fin.\n",
"\n",
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md
index 741fa3af063c..f35c5ead13b7 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.md
+++ b/docs/notebooks/Common_Gotchas_in_JAX.md
@@ -31,7 +31,7 @@ JAX works great for many numerical and scientific programs, but __only if they a
:id: GoK_PCxPeYcy
import numpy as np
-from jax import grad, jit
+from jax import jit
from jax import lax
from jax import random
import jax
@@ -384,480 +384,13 @@ jnp.sum(jnp.array(x))
## 🔪 Random numbers
-+++ {"id": "O8vvaVt3MRG2"}
-
-> _If all scientific papers whose results are in doubt because of bad
-> `rand()`s were to disappear from library shelves, there would be a
-> gap on each shelf about as big as your fist._ - Numerical Recipes
-
-+++ {"id": "Qikt9pPW9L5K"}
-
-### RNGs and state
-You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
-
-```{code-cell} ipython3
-:id: rr9FeP41fynt
-:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745
-
-print(np.random.random())
-print(np.random.random())
-print(np.random.random())
-```
-
-+++ {"id": "ORMVVGZJgSVi"}
-
-Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up.
-
-```{code-cell} ipython3
-:id: 7Pyp2ajzfPO2
-
-np.random.seed(0)
-rng_state = np.random.get_state()
-# print(rng_state)
-# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
-# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
-# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
-```
-
-+++ {"id": "aJIxHVXCiM6m"}
-
-This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne twister state vector:
-
-```{code-cell} ipython3
-:id: GAHaDCYafpAF
-
-_ = np.random.uniform()
-rng_state = np.random.get_state()
-#print(rng_state)
-# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
-# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
-
-# Let's exhaust the entropy in this PRNG statevector
-for i in range(311):
- _ = np.random.uniform()
-rng_state = np.random.get_state()
-#print(rng_state)
-# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
-# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
-
-# Next call iterates the RNG state for a new batch of fake "entropy".
-_ = np.random.uniform()
-rng_state = np.random.get_state()
-# print(rng_state)
-# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
-# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
-```
-
-+++ {"id": "N_mWnleNogps"}
-
-The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.
-
-The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow.
-
-+++ {"id": "Uvq7nV-j4vKK"}
-
-### JAX PRNG
-
-+++ {"id": "COjzGBpO4tzL"}
-
-JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
-
-The random state is described by a special array element that we call a __key__:
-
-```{code-cell} ipython3
-:id: yPHE7KTWgAWs
-:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
-
-key = random.key(0)
-key
-```
-
-+++ {"id": "XjYyWYNfq0hW"}
-
-JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!
-
-Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:
-
-```{code-cell} ipython3
-:id: 7zUdQMynoE5e
-:outputId: 23a86b72-dfb9-410a-8e68-22b48dc10805
-
-print(random.normal(key, shape=(1,)))
-print(key)
-# No no no!
-print(random.normal(key, shape=(1,)))
-print(key)
-```
-
-+++ {"id": "hQN9van8rJgd"}
-
-Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:
-
-```{code-cell} ipython3
-:id: ASj0_rSzqgGh
-:outputId: 2f13f249-85d1-47bb-d503-823eca6961aa
-
-print("old key", key)
-key, subkey = random.split(key)
-normal_pseudorandom = random.normal(subkey, shape=(1,))
-print(r" \---SPLIT --> new key ", key)
-print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
-```
-
-+++ {"id": "tqtFVE4MthO3"}
-
-We propagate the __key__ and make new __subkeys__ whenever we need a new random number:
-
-```{code-cell} ipython3
-:id: jbC34XLor2Ek
-:outputId: 4059a2e2-0205-40bc-ad55-17709d538871
-
-print("old key", key)
-key, subkey = random.split(key)
-normal_pseudorandom = random.normal(subkey, shape=(1,))
-print(r" \---SPLIT --> new key ", key)
-print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
-```
-
-+++ {"id": "0KLYUluz3lN3"}
-
-We can generate more than one __subkey__ at a time:
-
-```{code-cell} ipython3
-:id: lEi08PJ4tfkX
-:outputId: 1f280560-155d-4c04-98e8-c41d72ee5b01
-
-key, *subkeys = random.split(key, 4)
-for subkey in subkeys:
- print(random.normal(subkey, shape=(1,)))
-```
+JAX's pseudo-random number generation differs from NumPy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial.
+++ {"id": "rg4CpMZ8c3ri"}
## 🔪 Control flow
-+++ {"id": "izLTvT24dAq0"}
-
-### ✔ Python control_flow + autodiff ✔
-
-If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager).
-
-```{code-cell} ipython3
-:id: aAx0T3F8lLtu
-:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52
-
-def f(x):
- if x < 3:
- return 3. * x ** 2
- else:
- return -4 * x
-
-print(grad(f)(2.)) # ok!
-print(grad(f)(4.)) # ok!
-```
-
-+++ {"id": "hIfPT7WMmZ2H"}
-
-### Python control flow + JIT
-
-Using control flow with `jit` is more complicated, and by default it has more constraints.
-
-This works:
-
-```{code-cell} ipython3
-:id: OZ_BJX0CplNC
-:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c
-
-@jit
-def f(x):
- for i in range(3):
- x = 2 * x
- return x
-
-print(f(3))
-```
-
-+++ {"id": "22RzeJ4QqAuX"}
-
-So does this:
-
-```{code-cell} ipython3
-:id: pinVnmRWp6w6
-:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422
-
-@jit
-def g(x):
- y = 0.
- for i in range(x.shape[0]):
- y = y + x[i]
- return y
-
-print(g(jnp.array([1., 2., 3.])))
-```
-
-+++ {"id": "TStltU2dqf8A"}
-
-But this doesn't, at least by default:
-
-```{code-cell} ipython3
-:id: 9z38AIKclRNM
-:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac
-:tags: [raises-exception]
-
-@jit
-def f(x):
- if x < 3:
- return 3. * x ** 2
- else:
- return -4 * x
-
-# This will fail!
-f(2)
-```
-
-+++ {"id": "pIbr4TVPqtDN"}
-
-__What gives!?__
-
-When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.
-
-For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
-
-To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
-
-By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
-
-But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
-
-The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:
-
-```{code-cell} ipython3
-:id: -Tzp0H7Bt1Sn
-:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a
-
-def f(x):
- if x < 3:
- return 3. * x ** 2
- else:
- return -4 * x
-
-f = jit(f, static_argnums=(0,))
-
-print(f(2.))
-```
-
-+++ {"id": "MHm1hIQAvBVs"}
-
-Here's another example, this time involving a loop:
-
-```{code-cell} ipython3
-:id: iwY86_JKvD6b
-:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937
-
-def f(x, n):
- y = 0.
- for i in range(n):
- y = y + x[i]
- return y
-
-f = jit(f, static_argnums=(1,))
-
-f(jnp.array([2., 3., 4.]), 2)
-```
-
-+++ {"id": "nSPTOX8DvOeO"}
-
-In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation
-
-+++ {"id": "wWdg8LTYwCW3"}
-
-️⚠️ **functions with argument-__value__ dependent shapes**
-
-These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.
-
-```{code-cell} ipython3
-:id: Tqe9uLmUI_Gv
-:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883
-
-def example_fun(length, val):
- return jnp.ones((length,)) * val
-# un-jit'd works fine
-print(example_fun(5, 4))
-```
-
-```{code-cell} ipython3
-:id: fOlR54XRgHpd
-:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71
-:tags: [raises-exception]
-
-bad_example_jit = jit(example_fun)
-# this will fail:
-bad_example_jit(10, 4)
-```
-
-```{code-cell} ipython3
-:id: kH0lOD4GgFyI
-:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade
-
-# static_argnums tells JAX to recompile on changes at these argument positions:
-good_example_jit = jit(example_fun, static_argnums=(0,))
-# first compile
-print(good_example_jit(10, 4))
-# recompiles
-print(good_example_jit(5, 4))
-```
-
-+++ {"id": "MStx_r2oKxpp"}
-
-`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!
-
-Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
-
-```{code-cell} ipython3
-:id: m2ABpRd8K094
-:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3
-
-@jit
-def f(x):
- print(x)
- y = 2 * x
- print(y)
- return y
-f(2)
-```
-
-+++ {"id": "uCDcWG4MnVn-"}
-
-### Structured control flow primitives
-
-There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:
-
- - `lax.cond` _differentiable_
- - `lax.while_loop` __fwd-mode-differentiable__
- - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.
- - `lax.scan` _differentiable_
-
-+++ {"id": "Sd9xrLMXeK3A"}
-
-#### `cond`
-python equivalent:
-
-```python
-def cond(pred, true_fun, false_fun, operand):
- if pred:
- return true_fun(operand)
- else:
- return false_fun(operand)
-```
-
-```{code-cell} ipython3
-:id: SGxz9JOWeiyH
-:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3
-
-from jax import lax
-
-operand = jnp.array([0.])
-lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
-# --> array([1.], dtype=float32)
-lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
-# --> array([-1.], dtype=float32)
-```
-
-+++ {"id": "lIYdn1woOS1n"}
-
-`jax.lax` provides two other functions that allow branching on dynamic predicates:
-
-- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is
- like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays
- rather than as functions.
-- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is
- like `lax.cond`, but allows switching between any number of callable choices.
-
-In addition, `jax.numpy` provides several numpy-style interfaces to these functions:
-
-- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with
- three arguments is the numpy-style wrapper of `lax.select`.
-- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)
- is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.
-- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has
- an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather
- than as functions. It is implemented in terms of multiple calls to `lax.select`.
-
-+++ {"id": "xkOFAw24eOMg"}
-
-#### `while_loop`
-
-python equivalent:
-```
-def while_loop(cond_fun, body_fun, init_val):
- val = init_val
- while cond_fun(val):
- val = body_fun(val)
- return val
-```
-
-```{code-cell} ipython3
-:id: jM-D39a-c436
-:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
-
-init_val = 0
-cond_fun = lambda x: x < 10
-body_fun = lambda x: x+1
-lax.while_loop(cond_fun, body_fun, init_val)
-# --> array(10, dtype=int32)
-```
-
-+++ {"id": "apo3n3HAeQY_"}
-
-#### `fori_loop`
-python equivalent:
-```
-def fori_loop(start, stop, body_fun, init_val):
- val = init_val
- for i in range(start, stop):
- val = body_fun(i, val)
- return val
-```
-
-```{code-cell} ipython3
-:id: dt3tUpOmeR8u
-:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380
-
-init_val = 0
-start = 0
-stop = 10
-body_fun = lambda i,x: x+i
-lax.fori_loop(start, stop, body_fun, init_val)
-# --> array(45, dtype=int32)
-```
-
-+++ {"id": "SipXS5qiqk8e"}
-
-#### Summary
-
-$$
-\begin{array} {r|rr}
-\hline \
-\textrm{construct}
-& \textrm{jit}
-& \textrm{grad} \\
-\hline \
-\textrm{if} & ❌ & ✔ \\
-\textrm{for} & ✔* & ✔\\
-\textrm{while} & ✔* & ✔\\
-\textrm{lax.cond} & ✔ & ✔\\
-\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\
-\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\
-\textrm{lax.scan} & ✔ & ✔\\
-\hline
-\end{array}
-$$
-
-
-
-$\ast$ = argument-value -independent loop condition - unrolls the loop
-
-
+Moved to {ref}`control-flow`.
+++ {"id": "OxLsZUyRt_kF"}
@@ -1145,6 +678,9 @@ Many such cases are discussed in detail in the sections above; here we list seve
```
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
+## 🔪 Sharp bits covered in tutorials
+- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.
+- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions.
## Fin.
diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb
index 37c27ce2728a..d73b0d4c0f3e 100644
--- a/docs/notebooks/shard_map.ipynb
+++ b/docs/notebooks/shard_map.ipynb
@@ -864,7 +864,7 @@
"Indeed, this implementation is often used on both TPU and GPU!\n",
"\n",
"The reason `psum_scatter` can require about half the communication as a full\n",
- "`psum` is illustrated the `ppermute` section.\n",
+ "`psum` is illustrated in the `ppermute` section.\n",
"\n",
"Another intuition is that we can use `psum_scatter` to implement a distributed\n",
"matrix multiplication with inputs and outputs sharded over the same axis. In\n",
diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md
index 47b11079e27d..c52cf0e6d22b 100644
--- a/docs/notebooks/shard_map.md
+++ b/docs/notebooks/shard_map.md
@@ -627,7 +627,7 @@ def psum(x, axis_name):
Indeed, this implementation is often used on both TPU and GPU!
The reason `psum_scatter` can require about half the communication as a full
-`psum` is illustrated the `ppermute` section.
+`psum` is illustrated in the `ppermute` section.
Another intuition is that we can use `psum_scatter` to implement a distributed
matrix multiplication with inputs and outputs sharded over the same axis. In
diff --git a/docs/random-numbers.md b/docs/random-numbers.md
index 2ad1eadb0968..00f77e3473bb 100644
--- a/docs/random-numbers.md
+++ b/docs/random-numbers.md
@@ -17,6 +17,10 @@ kernelspec:
+> _If all scientific papers whose results are in doubt because of bad
+> `rand()`s were to disappear from library shelves, there would be a
+> gap on each shelf about as big as your fist._ - Numerical Recipes
+
In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.
PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.
@@ -35,6 +39,19 @@ import numpy as np
np.random.seed(0)
```
+Repeated calls to NumPy's stateful pseudorandom number generators (PRNGs) mutate the global state and give a stream of pseudorandom numbers:
+
+```{code-cell}
+:id: rr9FeP41fynt
+:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745
+
+print(np.random.random())
+print(np.random.random())
+print(np.random.random())
+```
+
+Underneath the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up.
+
You can inspect the content of the state using the following command.
```{code-cell}
@@ -109,7 +126,7 @@ Further, when executing in multi-device environments, execution efficiency would
### Explicit random state
-To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`:
+To avoid these issues, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`:
```{code-cell}
from jax import random
@@ -137,6 +154,7 @@ Re-using the same key, even with different {mod}`~jax.random` APIs, can result i
**The rule of thumb is: never reuse keys (unless you want identical outputs).**
+JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.
In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function:
```{code-cell}
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 41d8aa6d9ee7..bfbb4e271d42 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,7 +1,8 @@
absl-py
ipython>=8.8.0 # 8.7.0 has ipython3 lexer error
+pydata-sphinx-theme==0.14.4 # v0.15 breaks sidebar toggling
sphinx>=7.3.2,<8.0 # 7.3.0 breaks sphinx-book-theme; 8.0 breaks myst-nb 1.1
-sphinx-book-theme>=1.0.1 # Older versions fail to pin pydata-sphinx-theme
+sphinx-book-theme==1.1.1 # v1.1.2 requires pydata-sphinx-theme v0.15
sphinx-copybutton>=0.5.0
sphinx-remove-toctrees
sphinx-design
diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md
index 2ff82e0431e2..fe84fc0d7f0a 100644
--- a/docs/stateful-computations.md
+++ b/docs/stateful-computations.md
@@ -12,6 +12,7 @@ kernelspec:
name: python3
---
+(stateful-computations)=
# Stateful computations
diff --git a/docs/tutorials.rst b/docs/tutorials.rst
index a31517155e1a..c9c2fdb1dcc7 100644
--- a/docs/tutorials.rst
+++ b/docs/tutorials.rst
@@ -16,6 +16,7 @@ Tutorials
working-with-pytrees
sharded-computation
stateful-computations
+ control-flow
.. toctree::
:maxdepth: 1
diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD
index 6e4647b5e491..b3cb995aae21 100644
--- a/examples/jax_cpp/BUILD
+++ b/examples/jax_cpp/BUILD
@@ -26,8 +26,13 @@ cc_binary(
"@tsl//tsl/platform:platform_port",
"@xla//xla:literal",
"@xla//xla:literal_util",
+ "@xla//xla/hlo/builder:xla_computation",
+ "@xla//xla/hlo/ir:hlo",
"@xla//xla/pjrt:pjrt_client",
- "@xla//xla/pjrt/cpu:cpu_client",
+ "@xla//xla/pjrt:pjrt_executable",
+ "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
+ "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
+ "@xla//xla/service:hlo_module_config",
"@xla//xla/tools:hlo_module_loader",
],
)
diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc
index 2a8f8d4debba..ceac2cd2d7c9 100644
--- a/examples/jax_cpp/main.cc
+++ b/examples/jax_cpp/main.cc
@@ -36,15 +36,21 @@ limitations under the License.
// }
// )
+#include
#include
#include
#include
#include "third_party/absl/status/statusor.h"
+#include "xla/hlo/builder/xla_computation.h"
+#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
-#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
+#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
+#include "xla/service/hlo_module_config.h"
#include "xla/tools/hlo_module_loader.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
@@ -66,8 +72,10 @@ int main(int argc, char** argv) {
// Run it using JAX C++ Runtime (PJRT).
// Get a CPU client.
+ xla::CpuClientOptions options;
+ options.asynchronous = true;
std::unique_ptr client =
- xla::GetTfrtCpuClient(/*asynchronous=*/true).value();
+ xla::GetXlaPjrtCpuClient(options).value();
// Compile XlaComputation to PjRtExecutable.
xla::XlaComputation xla_computation(test_module_proto);
diff --git a/jax/BUILD b/jax/BUILD
index 0da99677dc7b..64bfa627f42e 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -451,6 +451,7 @@ pytype_strict_library(
":deprecations",
":dtypes",
":effects",
+ ":mesh",
":pretty_printer",
":source_info_util",
":traceback_util",
@@ -1047,7 +1048,7 @@ pytype_library(
"experimental/array_api/*.py",
],
),
- visibility = [":internal"] + jax_visibility("array_api"),
+ visibility = [":internal"],
deps = [
":jax",
],
diff --git a/jax/_src/array.py b/jax/_src/array.py
index cf346067ea31..d8182976254e 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -1035,7 +1035,7 @@ def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
return self.aval.update(sharding=NamedSharding(
self.sharding.mesh.abstract_mesh,
- self.sharding._normalized_spec(self.ndim)))
+ self.sharding.spec._normalized_spec(self.ndim)))
else:
return self.aval
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py
index 16da61d75b3f..3bc592d88246 100644
--- a/jax/_src/blocked_sampler.py
+++ b/jax/_src/blocked_sampler.py
@@ -23,7 +23,7 @@
Shape = random.Shape
class SampleFn(Protocol):
- def __call__(self, key: random.KeyArrayLike, *args, shape: Shape,
+ def __call__(self, key: ArrayLike, *args, shape: Shape,
**kwargs) -> Array:
...
@@ -43,7 +43,7 @@ def _compute_scalar_index(iteration_index: Sequence[int],
def blocked_fold_in(
- global_key: random.KeyArrayLike,
+ global_key: ArrayLike,
total_size: Shape,
block_size: Shape,
tile_size: Shape,
diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py
index 6e025653b81d..324fa85f81ed 100644
--- a/jax/_src/cache_key.py
+++ b/jax/_src/cache_key.py
@@ -21,6 +21,7 @@
from typing import cast as type_cast
from jax._src import config
+from jax._src.lib import version as jaxlib_version
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
@@ -225,6 +226,8 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
debug_options.xla_dump_hlo_as_long_text = False
debug_options.xla_dump_disable_metadata = False
debug_options.xla_dump_hlo_pipeline_re = ""
+ if jaxlib_version > (0, 4, 35):
+ debug_options.xla_gpu_experimental_autotune_cache_mode = 0
# Optional way to specify the cuda install path to be used by the compiler.
# This could possibly affect the cuda version compiled with, but this should
# already be included in the platform information (and might not be reflected
diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py
index 8ff52bd2f559..a2f137686dae 100644
--- a/jax/_src/cloud_tpu_init.py
+++ b/jax/_src/cloud_tpu_init.py
@@ -80,7 +80,7 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
- if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']:
+ if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''):
os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true'
# this makes tensorstore serialization work better on TPU
diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py
index 2fb13fde72cf..69ef77a6421d 100644
--- a/jax/_src/clusters/cluster.py
+++ b/jax/_src/clusters/cluster.py
@@ -49,12 +49,6 @@ def auto_detect_unset_distributed_params(cls,
initialization_timeout: int | None,
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:
-
- if all(p is not None for p in (coordinator_address, num_processes,
- process_id, local_device_ids)):
- return (coordinator_address, num_processes, process_id,
- local_device_ids)
-
# First, we check the spec detection method because it will ignore submitted values
# If if succeeds.
if cluster_detection_method is not None:
diff --git a/jax/_src/config.py b/jax/_src/config.py
index 72f394dba76f..2723b4f90d3b 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -209,7 +209,9 @@ def trace_context():
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately.
"""
- return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value,
+ return (axis_env_state.value, mesh_context_manager.value,
+ xla_metadata_context_manager.value,
+ abstract_mesh_context_manager.value,
compute_on_context_manager.value, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value,
@@ -219,6 +221,7 @@ def trace_context():
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
sharding_in_types.value,
+ use_direct_linearize.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
@@ -263,6 +266,7 @@ def trace_context():
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
sharding_in_types.value,
+ use_direct_linearize.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
@@ -967,6 +971,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
trace_state = config_ext.Config(None, include_in_jit_key=True)
axis_env_state = config_ext.Config((), include_in_jit_key=True)
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
+ abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
else:
@@ -983,6 +988,7 @@ class _GlobalExtraJitContext(NamedTuple):
threefry_partitionable: bool = False
threefry_gpu_kernel_lowering: bool = False
sharding_in_types: bool = False
+ use_direct_linearize: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0
pgle_profiling_runs: int = 0
@@ -1025,6 +1031,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
threefry_partitionable: bool | None = None
threefry_gpu_kernel_lowering: bool | None = None
sharding_in_types: bool | None = None
+ use_direct_linearize: bool | None = None
softmax_custom_jvp: bool | None = None
xla_profile_version: int | None = None
pgle_profiling_runs: int | None = None
@@ -1318,6 +1325,12 @@ def _update_jax_memories_thread_local(val):
'avals have sharding on them.'),
include_in_jit_key=True)
+use_direct_linearize = bool_state(
+ name='jax_use_direct_linearize',
+ default=False,
+    help=('Use direct linearization instead of JVP followed by partial eval'),
+ include_in_jit_key=True)
+
data_dependent_tracing_fallback = bool_state(
name='jax_data_dependent_tracing_fallback',
default=False,
@@ -1963,3 +1976,14 @@ def _update_garbage_collection_guard(state, key, val):
),
include_in_jit_key=True,
)
+
+gpu_use_magma = enum_state(
+ name='jax_use_magma',
+ enum_values=['off', 'on', 'auto'],
+ default='auto',
+ help=(
+ 'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. '
+ 'See the documentation for lax.linalg.eig for more details about how '
+ 'to use this feature.'
+ ),
+)
diff --git a/jax/_src/core.py b/jax/_src/core.py
index a1fcdac65df0..86646faa980b 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -38,6 +38,7 @@
from jax._src import config
from jax._src import effects
from jax._src import compute_on
+from jax._src import mesh as mesh_lib
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@@ -1596,6 +1597,23 @@ def _invalid_shape_error(shape: Shape, context: str=""):
return TypeError(msg)
+
+def get_sharding(sharding, ndim):
+ from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore
+
+ if sharding is not None:
+ assert len(sharding.spec) == ndim
+ return sharding
+
+ context_mesh = mesh_lib.mesh_context.mesh
+ # TODO(yashkatariya): Error out and ask users to set the context mesh in their
+ # code.
+ if context_mesh is None:
+ return None
+ assert sharding is None
+ return NamedSharding(context_mesh, P(*[None] * ndim))
+
+
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'sharding'] # inherits slots from parent
array_abstraction_level = 2
@@ -1605,20 +1623,18 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None):
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
if config.sharding_in_types.value:
- if sharding is not None:
- assert len(sharding.spec) == len(self.shape)
- self.sharding = sharding
+ self.sharding = get_sharding(sharding, len(self.shape))
- def update(self, shape=None, dtype=None, weak_type=None, sharding=None):
+ def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
- if sharding is None:
- sharding = getattr(self, 'sharding', None)
- return ShapedArray(shape, dtype, weak_type, sharding=sharding)
+ if 'sharding' not in kwargs:
+ kwargs['sharding'] = getattr(self, 'sharding', None)
+ return ShapedArray(shape, dtype, weak_type, **kwargs)
ndim = property(lambda self: len(self.shape))
size = property(lambda self:
@@ -1704,7 +1720,7 @@ def _get_abstract_sharding(val):
if (config.sharding_in_types.value and hasattr(val, 'sharding') and
isinstance(val.sharding, NamedSharding)):
return NamedSharding(val.sharding.mesh.abstract_mesh,
- val.sharding._normalized_spec(val.ndim))
+ val.sharding.spec._normalized_spec(val.ndim))
return None
def primal_dtype_to_tangent_dtype(primal_dtype):
@@ -2047,6 +2063,70 @@ def dimension_as_value(d: DimSize):
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
return operator.index(d)
+def canonicalize_slice(
+ s: slice,
+ axis_size: DimSize
+ ) -> tuple[DimSize, DimSize, DimSize]:
+ """Computes the start index, step, and size of the slice `x[s]`.
+
+ This is similar to `s.indices(axis_size)`, except that it returns
+ `(start, step, size)`, and it works when the slice and/or the
+ `axis_size` are symbolic.
+
+ See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
+ """
+ def convert_to_index(d: DimSize) -> DimSize:
+ # Convert np.array and jax.Array to int, leave symbolic dimensions alone
+ try:
+ return operator.index(d)
+ except:
+ return d
+
+ # Must resolve statically if step is {<0, ==0, >0}
+ step = convert_to_index(s.step) if s.step is not None else 1
+ try:
+ if step == 0:
+ raise ValueError("slice step cannot be zero")
+ step_gt_0 = (step > 0)
+ except InconclusiveDimensionOperation as e:
+ raise InconclusiveDimensionOperation(
+ f"In slice with non-constant elements the step ({step}) must " +
+ f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
+
+ def clamp_index(i: DimSize, which: str):
+ try:
+ i_ge_0 = (i >= 0)
+ except InconclusiveDimensionOperation as e:
+ raise InconclusiveDimensionOperation(
+ f"In slice with non-constant elements the {which} ({i}) must " +
+ f"be resolved statically if it is >= 0.\nDetails: {e}")
+ if i_ge_0:
+ if step_gt_0:
+ return min_dim(axis_size, i)
+ else:
+ return min_dim(axis_size - 1, i)
+ else:
+ if step_gt_0:
+ return max_dim(0, axis_size + i)
+ else:
+ return max_dim(-1, axis_size + i)
+
+ if s.start is None:
+ start = 0 if step_gt_0 else axis_size - 1
+ else:
+ start = clamp_index(convert_to_index(s.start), "start")
+
+ if s.stop is None:
+ stop = axis_size if step_gt_0 else -1
+ else:
+ stop = clamp_index(convert_to_index(s.stop), "stop")
+
+ gap = step if step_gt_0 else - step
+ distance = (stop - start) if step_gt_0 else (start - stop)
+ slice_size = max_dim(0, distance + gap - 1) // gap
+ return start, step, slice_size
+
+
class SomeTracer:
__slots__ = ()
def __repr__(self): return "[dynamic]"
@@ -2183,16 +2263,20 @@ def _map_shaped_array(
assert axis is None or aval.shape[axis] == size
# TODO: Extend the named shape
if axis is None: return aval
+ sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis))
+ if config.sharding_in_types.value else None)
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
- weak_type=aval.weak_type)
+ weak_type=aval.weak_type, sharding=sharding)
def _unmap_shaped_array(
size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
if axis is None: return aval
elif type(axis) is int:
+ sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name))
+ if config.sharding_in_types.value else None)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
- weak_type=aval.weak_type)
+ weak_type=aval.weak_type, sharding=sharding)
else: raise TypeError(axis)
def _map_dshaped_array(
diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index 5b9130fc0455..f80f90bde186 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -27,6 +27,13 @@
logger = logging.getLogger(__name__)
+_CHECK_PROXY_ENVS = config.bool_flag(
+ name="jax_check_proxy_envs",
+ default=True,
+    help="Checks proxy vars in user envs and emits warnings.",
+)
+
+
class State:
process_id: int = 0
num_processes: int = 1
@@ -55,16 +62,17 @@ def initialize(self,
if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
local_device_ids = list(map(int, env_ids.split(",")))
- (coordinator_address, num_processes, process_id, local_device_ids) = (
- clusters.ClusterEnv.auto_detect_unset_distributed_params(
- coordinator_address,
- num_processes,
- process_id,
- local_device_ids,
- cluster_detection_method,
- initialization_timeout,
- )
- )
+ if None in (coordinator_address, num_processes, process_id, local_device_ids):
+ (coordinator_address, num_processes, process_id, local_device_ids) = (
+ clusters.ClusterEnv.auto_detect_unset_distributed_params(
+ coordinator_address,
+ num_processes,
+ process_id,
+ local_device_ids,
+ cluster_detection_method,
+ initialization_timeout,
+ )
+ )
if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')
@@ -92,8 +100,10 @@ def initialize(self,
self.process_id = process_id
- # Emit a warning about PROXY variables if they are in the user's env:
- proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()]
+ proxy_vars = []
+ if _CHECK_PROXY_ENVS.value:
+ proxy_vars = [key for key in os.environ.keys()
+ if '_proxy' in key.lower()]
if len(proxy_vars) > 0:
vars = " ".join(proxy_vars) + ". "
@@ -179,7 +189,9 @@ def initialize(coordinator_address: str | None = None,
``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.
Otherwise, you must provide the ``coordinator_address``,
- ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
+ ``num_processes``, ``process_id``, and ``local_device_ids`` arguments
+ to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster
+ environment auto detection will be skipped.
Please note: on some systems, particularly HPC clusters that only access external networks
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py
index f5b0c3fd68b1..b53e1777f6a9 100644
--- a/jax/_src/dtypes.py
+++ b/jax/_src/dtypes.py
@@ -90,12 +90,17 @@ def type(self) -> type: ...
# fp8 support
+# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
+float8_e3m4: type[np.generic] | None = None
+float8_e4m3: type[np.generic] | None = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
+_float8_e3m4_dtype: np.dtype | None = None
+_float8_e4m3_dtype: np.dtype | None = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
@@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool:
_float8_e5m2fnuz_dtype,
]
+# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
+if hasattr(ml_dtypes, "float8_e4m3"):
+ float8_e4m3 = ml_dtypes.float8_e4m3
+ _float8_e4m3_dtype = np.dtype(float8_e4m3)
+ _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type]
+ _custom_float_dtypes.insert(0, _float8_e4m3_dtype)
+ _float8_dtypes.insert(0, _float8_e4m3_dtype)
+if hasattr(ml_dtypes, "float8_e3m4"):
+ float8_e3m4 = ml_dtypes.float8_e3m4
+ _float8_e3m4_dtype = np.dtype(float8_e3m4)
+ _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
+ _custom_float_dtypes.insert(0, _float8_e3m4_dtype)
+ _float8_dtypes.insert(0, _float8_e3m4_dtype)
+
# 2-bit integer support
int2: type[np.generic] | None = None
uint2: type[np.generic] | None = None
@@ -343,6 +362,7 @@ def _issubclass(a: Any, b: Any) -> bool:
# TODO(jakevdp): consider whether to disallow None here. We allow it
# because np.issubdtype allows it (and treats it as equivalent to float64).
+@set_module('jax.numpy')
def issubdtype(a: DTypeLike | ExtendedDType | None,
b: DTypeLike | ExtendedDType | None) -> bool:
"""Returns True if first argument is a typecode lower/equal in type hierarchy.
@@ -458,6 +478,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
}
+@set_module('jax.numpy')
def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool:
"""Returns a boolean indicating whether a provided dtype is of a specified kind.
@@ -650,6 +671,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
"JAX's internal logic; please report it to the JAX maintainers."
)
+@set_module('jax.numpy')
def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
"""Returns the type to which a binary operation should cast its arguments.
diff --git a/jax/_src/errors.py b/jax/_src/errors.py
index 590f68ac0b3b..6540fd1f5d41 100644
--- a/jax/_src/errors.py
+++ b/jax/_src/errors.py
@@ -677,7 +677,7 @@ class KeyReuseError(JAXTypeError):
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
This sort of key reuse is problematic because the JAX PRNG is stateless, and keys
- must be manually split; For more information on this see `Sharp Bits: Random Numbers
- `_.
+ must be manually split; For more information on this see `the Pseudorandom Numbers
+ tutorial `_.
"""
pass
diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs
index 3198f83aa120..b71b377d8999 100644
--- a/jax/_src/export/serialization.fbs
+++ b/jax/_src/export/serialization.fbs
@@ -67,6 +67,8 @@ enum DType: byte {
i4 = 15,
ui4 = 16,
+ f8_e3m4 = 24,
+ f8_e4m3 = 23,
f8_e4m3b11fnuz = 17,
f8_e4m3fn = 18,
f8_e4m3fnuz = 19,
diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py
index e392289da64d..0d9ce961b556 100644
--- a/jax/_src/export/serialization.py
+++ b/jax/_src/export/serialization.py
@@ -359,6 +359,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
}
+if dtypes._float8_e3m4_dtype is not None:
+ _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
+if dtypes._float8_e4m3_dtype is not None:
+ _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py
index 18dd2c3cbab1..70d298020961 100644
--- a/jax/_src/export/serialization_generated.py
+++ b/jax/_src/export/serialization_generated.py
@@ -53,6 +53,8 @@ class DType(object):
bf16 = 14
i4 = 15
ui4 = 16
+ f8_e3m4 = 24
+ f8_e4m3 = 23
f8_e4m3b11fnuz = 17
f8_e4m3fn = 18
f8_e4m3fnuz = 19
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 99340e728545..9fa2fdb9ffbf 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -39,7 +39,6 @@
as_hashable_function, weakref_lru_cache,
partition_list)
-
zip = safe_zip
map = safe_map
def identity(x): return x
@@ -106,7 +105,29 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
store.store(aux_primals)
return out_primals, out_tangents
+def direct_linearize(traceable, *primals, **kwargs):
+ has_aux = kwargs.pop('has_aux', False)
+ assert not has_aux
+ with core.take_current_trace() as parent_trace:
+ frame = pe.JaxprStackFrame()
+ tangent_trace = pe.DynamicJaxprTrace(frame)
+ tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
+ tag = core.TraceTag()
+ linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag)
+ tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
+ with core.set_current_trace(linearize_trace):
+ ans = traceable.call_wrapped(*tracers)
+
+ out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
+ out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
+ jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents)
+ out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
+ del attrs_tracked # TODO: attrs
+ return out_primals, out_tangents_pvals, jaxpr, consts
+
def linearize(traceable, *primals, **kwargs):
+ if config.use_direct_linearize.value:
+ return direct_linearize(traceable, *primals, **kwargs)
has_aux = kwargs.pop('has_aux', False)
if not has_aux:
jvpfun = jvp(traceable)
@@ -444,15 +465,94 @@ def _primal_tangent_shapes_match(primal, tangent):
call_param_updaters: dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
+# -------------------- Linearize trace --------------------
+
+class LinearizeTrace(Trace):
+
+ def __init__(self, parent_trace, tangent_trace, tag):
+ self.tag = tag
+ self.parent_trace = parent_trace
+ self.tangent_trace = tangent_trace
+
+ def to_primal_tangent_pair(self, val):
+ if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
+ return (val.primal, val.tangent)
+ else:
+ tangent_zero = Zero.from_primal_value(val)
+ return (val, tangent_zero)
+
+ def process_primitive(self, primitive, args, params):
+ primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
+ tangent_nonzeros = [type(t) is not Zero for t in tangents_in]
+ if all(type(t) is Zero for t in tangents_in):
+ return primitive.bind_with_trace(self.parent_trace, primals_in, params)
+ lin = primitive_linearizations.get(primitive)
+ if lin is None:
+ lin = partial(fallback_linearize_rule, primitive)
+ with core.set_current_trace(self.parent_trace):
+ primal_out, tangent_nonzeros_out, residuals, linearized = lin(
+ tangent_nonzeros, *primals_in, **params)
+ with core.set_current_trace(self.tangent_trace):
+ tangent_out = linearized(residuals, *tangents_in)
+ if primitive.multiple_results:
+ return [maybe_linearize_tracer(self, x, nz, t)
+ for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)]
+ else:
+ return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out)
+
+def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
+ if is_nonzero:
+ assert not type(tangent) is Zero
+ return LinearizeTracer(trace, primal, tangent)
+ else:
+ assert type(tangent) is Zero
+ return primal
+
+def fallback_linearize_rule(prim, _, *args, **kwargs):
+ def call_prim(*args_):
+ return prim.bind(*args_, **kwargs)
+ with config.use_direct_linearize(False):
+ out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize(
+ lu.wrap_init(call_prim), *args, **kwargs)
+ def linearized(residuals, *tangents):
+ tangents_out = iter(core.eval_jaxpr(jaxpr, residuals, *tangents))
+ full_out = [pval.get_known() if pval.is_known() else next(tangents_out)
+ for pval in out_tangents_pvals]
+ assert next(tangents_out, None) is None
+ return full_out
+ return out_primals, [True for _ in out_primals], consts, linearized
+
+class LinearizeTracer(Tracer):
+ __slots__ = ['primal', 'tangent']
+
+ def __init__(self, trace, primal, tangent):
+ if config.enable_checks.value:
+ _primal_tangent_shapes_match(primal, tangent)
+ self._trace = trace
+ self.primal = primal
+ self.tangent = tangent
+
+ @property
+ def aval(self):
+ return get_aval(self.primal)
+
+ def full_lower(self):
+ if type(self.tangent) is Zero:
+ return core.full_lower(self.primal)
+ else:
+ return self
+
+ def to_concrete_value(self):
+ return core.to_concrete_value(self.primal)
+
# -------------------- Primitives --------------------
primitive_jvps : dict[core.Primitive, Callable] = {}
-
primitive_transposes: dict[core.Primitive, Callable] = {}
# transpose rules that internally perform reductions over the given named axes
reducing_transposes: dict[core.Primitive, Callable] = {}
-
+primitive_linearizations : dict[core.Primitive, Callable] = {}
def deflinear(primitive, transpose_rule):
primitive_jvps[primitive] = partial(linear_jvp, primitive)
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index bef465c6aa75..102e4f490b5c 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -184,13 +184,13 @@ def _is_ir_values(x: IrValues) -> bool:
if dtypes.int2 is not None:
assert dtypes.uint2 is not None
- _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(
- ir.IntegerType.get_signless, 2
- )
- _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(
- ir.IntegerType.get_unsigned, 2
- )
+ _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
+ _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)
+if dtypes.float8_e3m4 is not None:
+ _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
+if dtypes.float8_e4m3 is not None:
+ _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
@@ -1752,6 +1752,44 @@ def _emit_lowering_rule_as_fun(lowering_rule,
return func_op
+class HashableLiteral:
+ """Hashable wrapper of core.Literal, used for deduplicating IR constants."""
+
+ __slots__ = ["value", "data"]
+
+ value: core.Literal
+
+ # Copy of the value suitable for an equality comparison. We are careful to
+ # avoid floating point comparisons here, because in particular we don't want
+ # 0.0 and -0.0 to be considered equal, but we are fine with NaNs being equal.
+ data: bytes | int | bool | None
+
+ def __init__(self, value):
+ self.value = value
+ if isinstance(value.val, (np.generic, np.ndarray)):
+ self.data = value.val.tobytes()
+ elif isinstance(value.val, (bool, int)):
+ self.data = value.val
+ elif isinstance(value.val, float):
+ self.data = np.float64(value.val).tobytes()
+ elif isinstance(value.val, complex):
+ self.data = np.complex128(value.val).tobytes()
+ else:
+ self.data = None # Unhandled case.
+
+ def __hash__(self):
+ return hash(self.data)
+
+ def __eq__(self, other):
+ if type(self.value.val) != type(other.value.val):
+ return False
+ if self.value.aval != other.value.aval:
+ return False
+ if self.data is None:
+ return id(self) == id(other)
+ return self.data == other.data
+
+
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
name_stack: source_info_util.NameStack,
tokens: TokenSet,
@@ -1767,9 +1805,16 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
IR function, in the order of ctx.shape_poly_state.dim_vars.
"""
assert "gpu" not in ctx.platforms
+ cached_ir_consts: dict[HashableLiteral, IrValues] = {}
+
def read(v: core.Atom) -> IrValues:
if type(v) is core.Literal:
- return ir_constant(xla.canonicalize_dtype(v.val))
+ h = HashableLiteral(v)
+ c = cached_ir_consts.get(h)
+ if c is None:
+ c = ir_constant(xla.canonicalize_dtype(v.val))
+ cached_ir_consts[h] = c
+ return c
else:
assert isinstance(v, core.Var)
return env[v]
@@ -2474,6 +2519,18 @@ def _wrap_with_spmd_op(name: str,
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
+def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
+ # Don't emit a wsc under full manual mode to avoid increasing HLO size.
+ if aval.sharding.mesh._are_all_axes_collective:
+ return op
+ proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+ if sharding_proto is None else sharding_proto)
+ # TODO(yashkatariya): Enable this
+ # unspecified_dims = (set(range(aval.ndim))
+ # if aval.sharding.mesh._any_axis_collective else None)
+ return wrap_with_sharding_op(ctx, op, aval, proto)
+
+
def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
if config.use_shardy_partitioner.value:
op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 6c9e54441f8e..2164c1a914c9 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2747,11 +2747,11 @@ def _maybe_get_and_check_out_shardings(
return new_out_shardings
-def finalize_out_shardings(out_shardings, device_assignment):
+def finalize_shardings(shardings, device_assignment):
if len(device_assignment) == 1:
return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
- if isinstance(o, GSPMDSharding) else o for o in out_shardings]
- return out_shardings
+ if isinstance(o, GSPMDSharding) else o for o in shardings]
+ return shardings
@dataclasses.dataclass
@@ -2892,7 +2892,8 @@ def from_hlo(name: str,
in_shardings, out_shardings, global_in_avals, global_out_avals,
intermediate_shardings, context_mesh)
- out_shardings = finalize_out_shardings(out_shardings, da)
+ in_shardings = finalize_shardings(in_shardings, da)
+ out_shardings = finalize_shardings(out_shardings, da)
return UnloadedMeshExecutable(
xla_executable=xla_executable,
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index d15917b8b1da..76132ccdc99a 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -227,6 +227,11 @@ def scan(f, init, xs, length=None):
msg.format(', '.join(str(x) for x in xs_flat
if not hasattr(x, 'shape')))) from err
+ if (config.sharding_in_types.value and
+ not all(x.sharding.spec[0] is None for x in xs_flat)):
+ raise ValueError('0th dimension of all xs should be replicated. Got '
+ f'{", ".join(str(x.sharding.spec) for x in xs_flat)}')
+
if length is not None:
try:
length = int(length)
@@ -250,7 +255,8 @@ def scan(f, init, xs, length=None):
if config.disable_jit.value:
if length == 0:
- raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
+ raise ValueError("zero-length scan is not supported in disable_jit() "
+ "mode because the output type is unknown.")
carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x
@@ -424,7 +430,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
num_trips, remainder = 0, length
if unroll == 1:
xss = xs_
- yss = _map(partial(_empty_array, (length,)), y_avals)
+ yss = _map(partial(_empty_array, (length,), None), y_avals)
else:
if remainder:
if not reverse:
@@ -432,7 +438,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
else:
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
- yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)
+ yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals)
def cond_fun(while_carry):
i, _, _ = while_carry
@@ -477,8 +483,11 @@ def _split_leading(sz, x):
def _concat(a, b): return lax.concatenate([a, b], 0)
-def _empty_array(prefix, aval):
- return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape))
+def _empty_array(prefix, length_spec, aval):
+ sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec))
+ if config.sharding_in_types.value else None)
+ return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
+ sharding=sharding)
eval_jaxpr_p = core.Primitive('eval_jaxpr')
eval_jaxpr_p.multiple_results = True
@@ -486,11 +495,13 @@ def _stage_jaxpr(trace, *tracers, jaxpr):
params = dict(call_jaxpr=jaxpr)
return trace.default_process_primitive(core.closed_call_p, tracers, params)
pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
+
@eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf
-def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects
+def _stage_jaxpr_abstract_eval(*_, jaxpr):
+ return jaxpr.out_avals, jaxpr.effects
def _prepend_dim_to_aval(sz, aval):
- return core.unmapped_aval(sz, core.no_axis_name, 0, aval)
+ return core.unmapped_aval(sz, None, 0, aval)
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll, _split_transpose):
@@ -674,7 +685,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
extensive_res = _map(trace.new_instantiated_const, extensive_res)
# Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
- ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval)
+ ys_avals = [core.unmapped_aval(length, None, 0, y_aval)
for y_aval in y_avals]
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in itertools.chain(carry_avals, ys_avals)]
@@ -1041,7 +1052,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# Create residual variables.
intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
- ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
+ ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a)
for a in ext_avals_mapped]
newvar = core.gensym()
intensive_res = _map(newvar, intensive_avals)
@@ -1119,7 +1130,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
jaxpr.in_avals, [num_consts, num_carry])
carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)
- y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a)
+ y_avals = [core.unmapped_aval(length, None, 0, a)
for a in y_avals_mapped]
if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index c45d8f5c80b2..1b84797d630e 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -879,11 +879,11 @@ def __str__(self) -> str:
def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
match self:
case (
- DotAlgorithmPreset.DEFAULT |
- DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
- DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
- DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
- DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
+ DotAlgorithmPreset.DEFAULT
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_F32
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
return None
case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32:
@@ -906,14 +906,26 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
return self.lhs_precision_type
@property
- def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
+ def accumulation_type(self) -> DTypeLike | None:
match self:
case (
- DotAlgorithmPreset.DEFAULT |
- DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
- DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
+ DotAlgorithmPreset.DEFAULT
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
return None
+ case DotAlgorithmPreset.F16_F16_F16:
+ return np.float16
+ case DotAlgorithmPreset.BF16_BF16_BF16:
+ return dtypes.bfloat16
+ case DotAlgorithmPreset.F64_F64_F64:
+ return np.float64
+ case _:
+ return np.float32
+
+ @property
+ def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
+ match self:
case (
DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
@@ -921,16 +933,11 @@ def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
- case DotAlgorithmPreset.F16_F16_F16:
- return np.float16
case DotAlgorithmPreset.F16_F16_F32:
return (np.float32, np.float16)
- case DotAlgorithmPreset.BF16_BF16_BF16:
- return dtypes.bfloat16
- case DotAlgorithmPreset.F64_F64_F64:
- return np.float64
case _:
- return np.float32
+ accumulation_type = self.accumulation_type
+ return None if accumulation_type is None else (accumulation_type,)
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
@@ -941,26 +948,39 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
tf32 = ir.FloatTF32Type.get()
match self:
case (
- DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
- DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
- DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
- DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
+ DotAlgorithmPreset.ANY_F8_ANY_F8_F32
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY
+ | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
- fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
- np.dtype(dtypes.float8_e4m3fn),
- np.dtype(dtypes.float8_e4m3fnuz),
- np.dtype(dtypes.float8_e5m2),
- np.dtype(dtypes.float8_e5m2fnuz))
+ fp8_dtypes = [
+ np.dtype(dtypes.float8_e4m3b11fnuz),
+ np.dtype(dtypes.float8_e4m3fn),
+ np.dtype(dtypes.float8_e4m3fnuz),
+ np.dtype(dtypes.float8_e5m2),
+ np.dtype(dtypes.float8_e5m2fnuz),
+ ]
+ if dtypes.float8_e3m4 is not None:
+ fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
+ if dtypes.float8_e4m3 is not None:
+ fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
- f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.")
+ f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.'
+ )
lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype))
rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype))
acc = ir.F32Type.get()
return hlo.DotAlgorithm.get(
- lhs, rhs, acc, 1, 1, 1,
- self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM)
+ lhs,
+ rhs,
+ acc,
+ 1,
+ 1,
+ 1,
+ self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
+ )
case DotAlgorithmPreset.F16_F16_F16:
return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
case DotAlgorithmPreset.F16_F16_F32:
@@ -1286,6 +1306,36 @@ def pad(operand: ArrayLike, padding_value: ArrayLike,
Returns:
The ``operand`` array with padding value ``padding_value`` inserted in each
dimension according to the ``padding_config``.
+
+ Examples:
+ >>> from jax import lax
+ >>> import jax.numpy as jnp
+
+ Pad a 1-dimensional array with zeros. We'll specify two zeros in front and
+ three at the end:
+
+ >>> x = jnp.array([1, 2, 3, 4])
+ >>> lax.pad(x, 0, [(2, 3, 0)])
+ Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
+
+ Pad a 1-dimensional array with *interior* zeros; i.e. insert a single zero
+ between each value:
+
+ >>> lax.pad(x, 0, [(0, 0, 1)])
+ Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
+
+ Pad a 2-dimensional array with the value ``-1`` at front and end, with a pad
+ size of 2 in each dimension:
+
+ >>> x = jnp.array([[1, 2, 3],
+ ... [4, 5, 6]])
+ >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
+ Array([[-1, -1, -1, -1, -1, -1, -1],
+ [-1, -1, -1, -1, -1, -1, -1],
+ [-1, -1, 1, 2, 3, -1, -1],
+ [-1, -1, 4, 5, 6, -1, -1],
+ [-1, -1, -1, -1, -1, -1, -1],
+ [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
"""
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
@@ -1650,6 +1700,8 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
scalar_zero = np.zeros((), dtype=aval.dtype)
else:
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
+ if config.sharding_in_types.value:
+ return broadcast(scalar_zero, aval.shape, sharding=aval.sharding)
return broadcast(scalar_zero, aval.shape)
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
@@ -2092,11 +2144,11 @@ def broadcasting_sharding_rule(name, *avals):
mesh = None
for a in avals:
if a.sharding is not None:
- mesh = a.sharding.mesh
if mesh is not None and mesh != a.sharding.mesh:
raise ValueError(
f'Mesh for all inputs should be equal. Got one mesh: {mesh} and'
f' another mesh: {a.sharding.mesh}')
+ mesh = a.sharding.mesh
assert mesh is not None
shapes = [aval.shape for aval in avals if aval.shape]
@@ -2203,14 +2255,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
for op, in_aval in zip(ops, in_avals):
if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
out.append(op)
- elif in_aval.sharding.mesh.are_all_axes_collective:
- out.append(op)
else:
- # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains
- # CompilerShardingAxis, then specify `unspecified_dims` via
- # `wrap_with_sharding_op`.
- sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
- out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp))
+ proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
+ out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto))
return out
@@ -2226,10 +2273,7 @@ def _nary_lower_hlo(op: Callable, ctx,
out = op(*args)
if config.sharding_in_types.value:
- if aval_out.sharding.mesh.are_all_axes_collective:
- return [out]
- out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)]
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
else:
return [out]
@@ -2356,12 +2400,16 @@ def _sin_lowering(ctx, x):
return sine(ctx, x)
return _nary_lower_hlo(hlo.sine, ctx, x)
+def _sin_p_lin(_, x):
+ cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
+ return (sin_p.bind(x), True, cos_x, lambda cos_x_, t: mul(t, cos_x_))
+
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
+ad.primitive_linearizations[sin_p] = _sin_p_lin
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
-
def _cos_complex(x):
# cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
# see also _sin_complex
@@ -2582,15 +2630,12 @@ def _pow_jvp_rhs(g, ans, x, y):
def _pow_lower(ctx, x, y):
x_aval, y_aval = ctx.avals_in
- out_aval, = ctx.avals_out
- convert = mlir.lower_fun(
- partial(convert_element_type, new_dtype=out_aval.dtype), False)
- x_aval_ = x_aval.update(dtype=out_aval.dtype)
- y_aval_ = y_aval.update(dtype=out_aval.dtype)
- [x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
- [y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
- ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
- return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
+ if x_aval.dtype != y_aval.dtype:
+ out_aval, = ctx.avals_out
+ y_aval = y_aval.update(dtype=out_aval.dtype)
+ y = hlo.convert(mlir.aval_to_ir_type(y_aval), y)
+ ctx = ctx.replace(avals_in=[x_aval, y_aval])
+ return _nary_lower_hlo(hlo.power, ctx, x, y)
mlir.register_lowering(pow_p, _pow_lower)
def _integer_pow_dtype_rule(x, *, y):
@@ -2646,8 +2691,7 @@ def _integer_pow_lowering(ctx, x, *, y):
out, = lowering(ctx, x, y=y)
if config.sharding_in_types.value:
aval_out, = ctx.avals_out
- proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
@@ -3029,8 +3073,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
if config.sharding_in_types.value:
if sharding is not None:
assert aval_out.sharding == sharding
- proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
@@ -3658,9 +3701,8 @@ def maybe_convert_dtype(input_dtype, target_dtype):
return input_dtype
if not isinstance(target_dtype, tuple):
target_dtype = (target_dtype,)
- if any(input_dtype == d for d in target_dtype):
- return input_dtype
- return target_dtype[0]
+ return input_dtype if input_dtype in target_dtype else target_dtype[0]
+
if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
@@ -3671,10 +3713,15 @@ def maybe_convert_dtype(input_dtype, target_dtype):
out_dtype = maybe_convert_dtype(out_dtype, np.float32)
return lhs_dtype, rhs_dtype, out_dtype
else:
+ if isinstance(algorithm, DotAlgorithmPreset):
+ supported_output_types = algorithm.supported_output_types
+ else:
+ supported_output_types = (algorithm.accumulation_type,)
+
return (
maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
- maybe_convert_dtype(out_dtype, algorithm.accumulation_type),
+ maybe_convert_dtype(out_dtype, supported_output_types),
)
@@ -3684,6 +3731,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
+ if dtypes.float8_e3m4 is not None:
+ fp8_dtypes += (dtypes.float8_e3m4,)
+ if dtypes.float8_e4m3 is not None:
+ fp8_dtypes += (dtypes.float8_e4m3,)
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in
@@ -3765,8 +3816,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
if config.sharding_in_types.value:
if out_type is not None:
assert aval_out.sharding == out_type
- out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
+ result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
if accumulation_aval.dtype != aval_out.dtype:
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
return [result]
@@ -4231,8 +4281,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
if config.sharding_in_types.value:
if sharding is not None:
assert sharding == aval_out.sharding
- proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
@@ -4358,7 +4407,7 @@ def _concatenate_shape_rule(*operands, **kwargs):
raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))
shapes = [operand.shape[:dimension] + operand.shape[dimension+1:]
for operand in operands]
- if not shapes[:-1] == shapes[1:]:
+ if shapes[:-1] != shapes[1:]:
msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
"other than the one being concatenated: concatenating along "
"dimension {} for shapes {}.")
@@ -4369,6 +4418,13 @@ def _concatenate_shape_rule(*operands, **kwargs):
ex_shape = operands[0].shape
return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:]
+def _concatenate_sharding_rule(*operands, **kwargs):
+ if not all(o.sharding == operands[0].sharding for o in operands):
+ ss = ", ".join(str(o.sharding) for o in operands)
+ raise TypeError(
+ f"All operands should have the same sharding. Got shardings {ss}")
+ return operands[0].sharding
+
def _concatenate_dtype_rule(*operands, **kwargs):
check_same_dtypes('concatenate', *operands)
return operands[0].dtype
@@ -4409,14 +4465,19 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension):
raise NotImplementedError # TODO(mattjj)
concatenate_p = standard_primitive(
- _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate')
+ _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
+ sharding_rule=_concatenate_sharding_rule)
ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
pe.padding_rules[concatenate_p] = _concatenate_pad_rule
def _concatenate_lower(ctx, *xs, dimension):
- return [hlo.concatenate(xs, mlir.i64_attr(dimension))]
+ aval_out, = ctx.avals_out
+ out = hlo.concatenate(xs, mlir.i64_attr(dimension))
+ if config.sharding_in_types.value:
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+ return [out]
mlir.register_lowering(concatenate_p, _concatenate_lower)
@@ -4428,7 +4489,8 @@ def _pad_dtype_rule(operand, padding_value, *, padding_config):
return _input_dtype(operand, padding_value)
def _pad_shape_rule(operand, padding_value, *, padding_config):
- del padding_value
+ if np.ndim(padding_value) != 0:
+ raise ValueError(f"padding_value must be a scalar; got {np.shape(padding_value)=}")
op_shape = np.shape(operand)
if not len(padding_config) == np.ndim(operand):
raise ValueError("length of padding_config must equal the number of axes "
@@ -4446,6 +4508,15 @@ def _pad_shape_rule(operand, padding_value, *, padding_config):
raise ValueError(msg)
return result
+def _pad_sharding_rule(operand, padding_value, *, padding_config):
+ # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
+ # change this logic to `return operand.sharding` directly.
+ out_shape = _pad_shape_rule(operand, padding_value,
+ padding_config=padding_config)
+ return slicing._get_sharding_for_varying_out_shape(
+ out_shape, operand, 'padding')
+
+
def _pad_transpose(t, operand, padding_value, *, padding_config):
if type(t) is ad_util.Zero:
t_operand = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
@@ -4485,14 +4556,18 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
(operand_bdim,))
return select(mask, x, broadcasted_padding), operand_bdim
-pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad')
+pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
+ sharding_rule=_pad_sharding_rule)
ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule
def _pad_lower(ctx, x, padding_value, *, padding_config):
aval_out, = ctx.avals_out
low, high, interior = util.unzip3(padding_config)
- return [mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)]
+ out = mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)
+ if config.sharding_in_types.value:
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+ return [out]
mlir.register_lowering(pad_p, _pad_lower)
@@ -4514,6 +4589,12 @@ def _squeeze_dtype_rule(operand, *, dimensions):
def _squeeze_shape_rule(operand, *, dimensions):
return _compute_squeeze_shape(np.shape(operand), dimensions)
+def _squeeze_sharding_rule(operand, *, dimensions):
+ dims_set = set(dimensions)
+ new_spec = tuple(s for i, s in enumerate(operand.sharding.spec)
+ if i not in dims_set)
+ return NamedSharding(operand.sharding.mesh, P(*new_spec))
+
def _compute_squeeze_shape(shape, dimensions):
dims_set = set(dimensions)
if len(dims_set) != len(dimensions):
@@ -4542,7 +4623,7 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
return squeeze(operand, dimensions=dimensions), bdim_out
squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
- 'squeeze')
+ 'squeeze', sharding_rule=_squeeze_sharding_rule)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
pe.def_trivial_padding(squeeze_p)
@@ -4550,7 +4631,11 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
def _squeeze_lower(ctx, operand, *, dimensions):
del dimensions # Implied by the output aval.
- return [mlir.reshape(ctx, operand, ctx.avals_out[0])]
+ aval_out, = ctx.avals_out
+ out = mlir.reshape(ctx, operand, aval_out)
+ if config.sharding_in_types.value:
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+ return [out]
mlir.register_lowering(squeeze_p, _squeeze_lower)
@@ -4559,6 +4644,8 @@ def shape_as_value(shape: core.Shape):
"""Converts a shape that may contain Poly values into a JAX value."""
if len(shape) == 0:
return full((0,), np.array(0, np.int64))
+ if core.is_constant_shape(shape):
+ return np.asarray(shape, dtype=np.int64)
dims = [
expand_dims(convert_element_type(core.dimension_as_value(d), np.int64),
(0,))
@@ -4645,8 +4732,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
out = mlir.reshape(ctx, x, aval_out)
if config.sharding_in_types.value:
- proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
def _reshape_staging_rule(
@@ -4726,8 +4812,7 @@ def _transpose_lower(ctx, x, *, permutation):
permutation = [*permutation, *trailing_dims]
out = hlo.transpose(x, mlir.dense_int_array(permutation))
if config.sharding_in_types.value:
- proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
transpose_p = standard_primitive(
@@ -4868,8 +4953,7 @@ def _select_hlo_lowering_opaque(ctx, which, *cases):
def _add_shit_to_select(ctx, op, aval_out):
if config.sharding_in_types.value:
- proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return mlir.wrap_with_sharding_op(ctx, op, aval_out, proto)
+ return mlir.lower_sharding_under_shit(ctx, op, aval_out)
return op
def _select_hlo_lowering(ctx, which, *cases):
@@ -5241,8 +5325,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
with ir.InsertionPoint(reducer_region):
hlo.return_([reducer(*reducer_region.arguments)])
if config.sharding_in_types.value:
- out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, op.result, aval_out, out_sp)]
+ return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)]
return op.results
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
@@ -5941,8 +6024,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
out = mlir.iota(ctx, aval_out, dimension=dimension)
if config.sharding_in_types.value:
assert aval_out.sharding == sharding
- proto = sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
- return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(iota_p, _iota_lower)
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 0e0390abc78f..62cb72c69fd7 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -121,16 +121,46 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
- compute_right_eigenvectors: bool = True) -> list[Array]:
+ compute_right_eigenvectors: bool = True,
+ use_magma: bool | None = None) -> list[Array]:
"""Eigendecomposition of a general matrix.
- Nonsymmetric eigendecomposition is at present only implemented on CPU.
+ Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU,
+ the default implementation calls LAPACK directly on the host CPU, but an
+ experimental GPU implementation using `MAGMA <https://icl.utk.edu/magma/>`_
+ is also available. The MAGMA implementation is typically slower than the
+ equivalent LAPACK implementation for small matrices (less than about 2048),
+ but it may perform better for larger matrices.
+
+ To enable the MAGMA implementation, you must install MAGMA yourself (there
+ are Debian and conda-forge packages, or you can build from source). Then set
+ the ``use_magma`` argument to ``True``, or set the ``jax_use_magma``
+ configuration variable to ``"on"`` or ``"auto"``:
+
+ .. code-block:: python
+
+ jax.config.update('jax_use_magma', 'on')
+
+ JAX will try to ``dlopen`` the installed MAGMA shared library, raising an
+ error if it is not found. To explicitly specify the path to the MAGMA
+ library, set the environment variable ``JAX_GPU_MAGMA_PATH`` to the full
+ installation path.
+
+ If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will
+ be used if the library can be found, and the input matrix is sufficiently
+ large (>= 2048x2048).
Args:
x: A batch of square matrices with shape ``[..., n, n]``.
compute_left_eigenvectors: If true, the left eigenvectors will be computed.
compute_right_eigenvectors: If true, the right eigenvectors will be
computed.
+ use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
+ eigendecomposition is computed using MAGMA. If ``False``, the computation
+ is done using LAPACK on the host CPU. If ``None`` (default), the
+ behavior is controlled by the ``jax_use_magma`` flag. This argument
+ is only used on GPU.
+
Returns:
The eigendecomposition of ``x``, which is a tuple of the form
``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left
@@ -142,7 +172,8 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
for that batch element.
"""
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
- compute_right_eigenvectors=compute_right_eigenvectors)
+ compute_right_eigenvectors=compute_right_eigenvectors,
+ use_magma=use_magma)
def eigh(
@@ -678,12 +709,14 @@ def _symmetric_product_jax_fn(a, c, *, alpha, beta):
# Asymmetric eigendecomposition
-def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
+def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors,
+ use_magma):
return dispatch.apply_primitive(
eig_p,
operand,
compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
+ use_magma=use_magma,
)
def eig_lower(*args, **kw):
@@ -692,7 +725,8 @@ def eig_lower(*args, **kw):
"If your matrix is symmetric or Hermitian, you should use eigh instead.")
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
- compute_right_eigenvectors):
+ compute_right_eigenvectors, use_magma):
+ del use_magma # unused
if isinstance(operand, ShapedArray):
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
raise ValueError("Argument to nonsymmetric eigendecomposition must have "
@@ -716,7 +750,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,
return tuple(output)
def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
- compute_right_eigenvectors):
+ compute_right_eigenvectors, use_magma):
+ del use_magma # unused
operand_aval, = ctx.avals_in
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
@@ -763,18 +798,94 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
return output
+def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors,
+ compute_right_eigenvectors, use_magma):
+ gpu_solver.initialize_hybrid_kernels()
+ dtype = x.dtype
+ is_real = dtype == np.float32 or dtype == np.float64
+ if is_real:
+ target_name = f"{target_name_prefix}hybrid_eig_real"
+ complex_dtype = np.complex64 if dtype == np.float32 else np.complex128
+ else:
+ target_name = f"{target_name_prefix}hybrid_eig_comp"
+ assert dtype == np.complex64 or dtype == np.complex128
+ complex_dtype = dtype
+
+ batch_dims = x.shape[:-2]
+ n, m = x.shape[-2:]
+ assert n == m
+ num_batch_dims = len(batch_dims)
+
+ layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims)
+ out_types = [
+ api.ShapeDtypeStruct(batch_dims + (n,), dtype),
+ api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
+ api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
+ api.ShapeDtypeStruct(batch_dims, np.int32),
+ ]
+ out_layouts = [None, layout, layout, None]
+ if is_real:
+ out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types
+ out_layouts = [None] + out_layouts
+
+ magma = config.gpu_use_magma.value
+ if use_magma is not None:
+ magma = "on" if use_magma else "off"
+ fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout],
+ output_layouts=out_layouts)
+ *w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors,
+ right=compute_right_eigenvectors)
+ if is_real:
+ assert len(w) == 2
+ w = lax.complex(*w)
+ else:
+ assert len(w) == 1
+ w = w[0]
+ ok = lax.eq(info, lax.zeros_like_array(info))
+ ok = _broadcast_to(ok[..., None], w.shape)
+ w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j))
+ ok = _broadcast_to(ok[..., None], x.shape)
+ output = [w]
+ if compute_left_eigenvectors:
+ vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j))
+ output.append(vl)
+ if compute_right_eigenvectors:
+ vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j))
+ output.append(vr)
+ return output
+
+
+def _eig_gpu_lowering(target_name_prefix, ctx, operand, *,
+ compute_left_eigenvectors, compute_right_eigenvectors,
+ use_magma):
+ if ctx.is_forward_compat():
+ raise NotImplementedError(
+ "Export of nonsymmetric eigendecomposition on GPU is not supported "
+ "because of forward compatibility. The "
+ "'jax_export_ignore_forward_compatibility' configuration option can be "
+ "used to disable this check.")
+ rule = mlir.lower_fun(partial(
+ _eig_gpu_impl, target_name_prefix,
+ compute_left_eigenvectors=compute_left_eigenvectors,
+ compute_right_eigenvectors=compute_right_eigenvectors,
+ use_magma=use_magma), multiple_results=True)
+ return rule(ctx, operand)
+
+
def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
- compute_right_eigenvectors):
+ compute_right_eigenvectors, use_magma):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
- compute_right_eigenvectors=compute_right_eigenvectors),
+ compute_right_eigenvectors=compute_right_eigenvectors,
+ use_magma=use_magma),
(0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors))
def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
- compute_right_eigenvectors):
+ compute_right_eigenvectors, use_magma):
+ del use_magma # unused
if compute_left_eigenvectors or compute_right_eigenvectors:
raise NotImplementedError(
'The derivatives of eigenvectors are not implemented, only '
@@ -793,6 +904,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
eig_p.def_abstract_eval(eig_abstract_eval)
mlir.register_lowering(eig_p, eig_lower)
mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
+mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'),
+ platform='cuda')
+mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'),
+ platform='rocm')
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 3a1c1ef3bcf1..c8cea6a9df5b 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -24,9 +24,11 @@
from jax import tree_util
from jax._src import core
+from jax._src import config
from jax._src import dispatch
from jax._src import dtypes
-from jax._src import sharding_impls
+from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
+ NamedSharding, PartitionSpec as P)
from jax._src.core import AxisName, ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -635,9 +637,15 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
if len(pos_axes) != 0:
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
- out_avals = [
- ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes),
- arg.dtype) for arg in args]
+ if config.sharding_in_types.value:
+ out_avals = [
+ ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,
+ sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes))
+ for arg in args
+ ]
+ else:
+ out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype)
+ for arg in args]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
def _check_axis_names(axes):
@@ -673,10 +681,7 @@ def _positional_reduce(aval, arg):
_replica_groups(ctx.module_context.axis_env, named_axes,
axis_index_groups))
axis_context = ctx.module_context.axis_context
- is_spmd = isinstance(
- axis_context,
- (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
- )
+ is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext))
def all_reduce(aval, x):
if is_spmd:
@@ -694,7 +699,11 @@ def all_reduce(aval, x):
else:
op = hlo.AllReduceOp(
[x.type], [x], replica_groups=replica_groups, **other_args)
- scalar_aval = core.ShapedArray((), aval.dtype)
+ if config.sharding_in_types.value:
+ scalar_aval = core.ShapedArray(
+ (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P()))
+ else:
+ scalar_aval = core.ShapedArray((), aval.dtype)
scalar_type = mlir.aval_to_ir_type(scalar_aval)
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_block):
@@ -778,7 +787,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
axis_context = ctx.module_context.axis_context
is_manual = (
- isinstance(axis_context, sharding_impls.SPMDAxisContext)
+ isinstance(axis_context, SPMDAxisContext)
and axis_context.manual_axes
)
if is_manual:
@@ -896,7 +905,7 @@ def _all_to_all_lowering(
raise ValueError('Replica groups must be equally sized')
is_spmd = isinstance(
ctx.module_context.axis_context,
- (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
+ (SPMDAxisContext, ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
@@ -1129,10 +1138,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
x_aval, = ctx.avals_in
out_aval, = ctx.avals_out
axis_context = ctx.module_context.axis_context
- is_spmd = isinstance(
- axis_context,
- (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
- )
+ is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext))
if not tiled:
new_shape = list(x_aval.shape)
new_shape.insert(all_gather_dimension, 1)
@@ -1260,7 +1266,7 @@ def _reduce_scatter_lowering(
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
- (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
+ (SPMDAxisContext, ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
@@ -1489,7 +1495,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
- (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
+ (SPMDAxisContext, ShardingContext),
)
if is_spmd:
device_id = hlo.partition_id()
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 40a04ff11d2c..c6c85ce4f6a3 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1270,6 +1270,39 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
return tuple(core.stride_dim(d, window_size=1, window_stride=s)
for d, s in zip(diff, strides))
+def _get_sub_spec_size(mesh, sub_spec):
+ if isinstance(sub_spec, tuple):
+ return math.prod(mesh.shape[s] for s in sub_spec)
+ return mesh.shape[sub_spec]
+
+def _get_sharding_for_varying_out_shape(out_shape, operand, name):
+ """Returns a sharding when out_shape may not be the same as operand shape"""
+ mesh = operand.sharding.mesh
+ for op_sh, out_sh, op_spec in safe_zip(
+ operand.shape, out_shape, operand.sharding.spec):
+ if (op_sh != out_sh and op_spec is not None and
+ out_sh % _get_sub_spec_size(mesh, op_spec) != 0):
+ raise NotImplementedError(
+ f"{name} on sharded dims where out dim ({out_sh}) is not divisble by"
+ f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec"
+ f" ({op_spec}) is not implemented.")
+ # TODO(yashkatariya): Returning operand.sharding as is may or may not move
+ # data. So think about how to avoid it which might include creating a new
+ # mesh? For example:
+ # mesh = {'x': 4}
+ # x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))`
+ # ys = lax.split(x, [4, 4]) # This will create outputs of shape (4,)
+ # According to the current logic, ys[0].sharding.spec == P('x')
+ # which involves data movement.
+ return operand.sharding
+
+def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides):
+ # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
+ # change this logic to `return operand.sharding` directly.
+ out_shape = _slice_shape_rule(operand, start_indices=start_indices,
+ limit_indices=limit_indices, strides=strides)
+ return _get_sharding_for_varying_out_shape(out_shape, operand, 'slicing')
+
def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
assert ad.is_undefined_primal(operand)
operand_shape = operand.aval.shape
@@ -1308,7 +1341,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
out = slice(operand, new_start_indices, new_limit_indices, new_strides)
return out, bdim
-slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice')
+slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
+ sharding_rule=_slice_sharding_rule)
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
# TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries
@@ -1333,14 +1367,16 @@ def _slice_impl(x, start_indices, limit_indices, strides):
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
strides = strides or [1] * len(start_indices)
aval_out, = ctx.avals_out
- return [mlir.slice_op(ctx, x, aval_out,
- start_indices=start_indices, limit_indices=limit_indices, strides=strides)]
+ out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
+ limit_indices=limit_indices, strides=strides)
+ if config.sharding_in_types.value:
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+ return [out]
mlir.register_lowering(slice_p, _slice_lower)
-def _dynamic_slice_shape_rule(
- operand, *starts_and_dyn_sizes, slice_sizes):
+def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes):
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
if operand.ndim != len(start_indices):
msg = ("dynamic_slice start_indices must have length equal to the number "
@@ -1363,6 +1399,12 @@ def _dynamic_slice_shape_rule(
f" got indices {start_indices}")
return tuple(lax._merge_dyn_shape(slice_sizes, dyn))
+def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes):
+ out_shape = _dynamic_slice_shape_rule(
+ operand, *starts_and_dyn_sizes, slice_sizes=slice_sizes)
+ return _get_sharding_for_varying_out_shape(out_shape, operand, 'dynamic_slice')
+
+
def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes):
start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
if any(i.dtype != start_indices[0].dtype or
@@ -1466,7 +1508,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn,
dynamic_slice_p = standard_primitive(
_dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
- weak_type_rule=_argnum_weak_type(0))
+ weak_type_rule=_argnum_weak_type(0),
+ sharding_rule=_dynamic_slice_sharding_rule)
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
@@ -1480,7 +1523,10 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
aval_out, = ctx.avals_out
if dyn:
aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn))
- return [mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)]
+ out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)
+ if config.sharding_in_types.value:
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+ return [out]
mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)
@@ -1511,6 +1557,14 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
f"scalars, got indices {start_indices}")
return operand.shape
+def _dynamic_update_slice_sharding_rule(operand, update, *start_indices):
+ if operand.sharding != update.sharding:
+ raise TypeError(
+ "dynamic_update_slice update sharding must be equal to operand"
+ f" sharding, got update sharding {update.sharding} for operand sharding"
+ f" {operand.sharding}.")
+ return operand.sharding
+
def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
lax.check_same_dtypes("dynamic_update_slice", operand, update)
if any(i.dtype != start_indices[0].dtype or
@@ -1576,7 +1630,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
dynamic_update_slice_p = standard_primitive(
_dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
- 'dynamic_update_slice')
+ 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule)
ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
ad.primitive_transposes[dynamic_update_slice_p] = \
_dynamic_update_slice_transpose_rule
@@ -1585,8 +1639,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
- return [mlir.dynamic_update_slice(ctx, aval_out, x, update,
- start_indices=start_indices)]
+ out = mlir.dynamic_update_slice(ctx, aval_out, x, update,
+ start_indices=start_indices)
+ if config.sharding_in_types.value:
+ return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+ return [out]
mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower)
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 082c443fade4..3d0e1b0cccf5 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -107,6 +107,20 @@ class AxisTypes(enum.Enum):
User = enum.auto()
Collective = enum.auto()
+ def __repr__(self):
+ return self.name
+
+def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
+ if axis_types is None:
+ return {}
+ d = {}
+ for t, names in axis_types.items():
+ if isinstance(names, tuple):
+ for n in names:
+ d[n] = t
+ else:
+ d[names] = t
+ return d
_mesh_object_dict = {} # type: ignore
@@ -269,6 +283,10 @@ def shape_tuple(self):
def axis_sizes(self) -> tuple[int, ...]:
return self.devices.shape
+ @functools.cached_property
+ def _name_to_type(self):
+ return axis_names_to_types(self.axis_types)
+
@property
def size(self):
return math.prod(self.shape.values()) if self.devices.ndim else 0
@@ -390,6 +408,10 @@ def axis_names(self):
def axis_sizes(self) -> tuple[int, ...]:
return self._axis_sizes
+ @functools.cached_property
+ def _name_to_type(self):
+ return axis_names_to_types(self.axis_types)
+
@functools.cached_property
def size(self):
return math.prod(self._axis_sizes) if self._axis_sizes else 0
@@ -407,7 +429,7 @@ def empty(self):
return self.size == 0
@functools.cached_property
- def are_all_axes_collective(self) -> bool:
+ def _are_all_axes_collective(self) -> bool:
if self.axis_types is None:
return False
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
@@ -433,14 +455,22 @@ def local_mesh(self):
_raise_value_error("local_mesh")
def __enter__(self):
- raise RuntimeError("AbstractMesh is not a context manager")
+ mesh_context.stack.append(self)
+ mesh_context.mesh = self
+ jax_config.abstract_mesh_context_manager.set_local(
+ tuple(m for m in mesh_context.stack if m is not None))
+ return self
def __exit__(self, exc_type, exc_value, traceback):
- raise RuntimeError("AbstractMesh is not a context manager")
+ mesh_context.stack.pop()
+ mesh_context.mesh = mesh_context.stack[-1]
+ jax_config.abstract_mesh_context_manager.set_local(
+ tuple(m for m in mesh_context.stack if m is not None))
+ return False
@staticmethod
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
- jax_config.mesh_context_manager.set_local(mesh)
+ jax_config.abstract_mesh_context_manager.set_local(mesh)
return
@@ -448,3 +478,11 @@ def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
# property raises an exception unconditionally. Remove this once that is fixed.
def _raise_value_error(name):
raise ValueError(f"AbstractMesh does not implement {name}")
+
+
+class MeshContext(threading.local):
+ def __init__(self):
+ self.stack = [None]
+ self.mesh = self.stack[-1]
+
+mesh_context = MeshContext()
diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py
index eb1bb1609bbf..8086a97a3748 100644
--- a/jax/_src/nn/initializers.py
+++ b/jax/_src/nn/initializers.py
@@ -36,7 +36,6 @@
export = set_module('jax.nn.initializers')
-KeyArray = Array
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeFloat = Any
@@ -48,13 +47,13 @@
@typing.runtime_checkable
class Initializer(Protocol):
@staticmethod
- def __call__(key: KeyArray,
+ def __call__(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
raise NotImplementedError
@export
-def zeros(key: KeyArray,
+def zeros(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
"""An initializer that returns a constant array full of zeros.
@@ -69,7 +68,7 @@ def zeros(key: KeyArray,
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
@export
-def ones(key: KeyArray,
+def ones(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
"""An initializer that returns a constant array full of ones.
@@ -100,7 +99,7 @@ def constant(value: ArrayLike,
Array([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
- def init(key: KeyArray,
+ def init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
@@ -126,7 +125,7 @@ def uniform(scale: RealNumeric = 1e-2,
Array([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
- def init(key: KeyArray,
+ def init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
@@ -152,7 +151,7 @@ def normal(stddev: RealNumeric = 1e-2,
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
- def init(key: KeyArray,
+ def init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
@@ -189,7 +188,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
[-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
"""
- def init(key: KeyArray,
+ def init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
@@ -230,7 +229,7 @@ def _compute_fans(shape: Sequence[int],
fan_out = out_size * receptive_field_size
return fan_in, fan_out
-def _complex_uniform(key: KeyArray,
+def _complex_uniform(key: Array,
shape: Sequence[int],
dtype: DTypeLikeInexact) -> Array:
"""
@@ -244,7 +243,7 @@ def _complex_uniform(key: KeyArray,
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)
-def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
+def _complex_truncated_normal(key: Array, upper: ArrayLike,
shape: Sequence[int],
dtype: DTypeLikeInexact) -> Array:
"""
@@ -314,7 +313,7 @@ def variance_scaling(
dtype: the dtype of the weights.
"""
- def init(key: KeyArray,
+ def init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
shape = core.canonicalize_shape(shape)
@@ -599,7 +598,7 @@ def orthogonal(scale: RealNumeric = 1.0,
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
- def init(key: KeyArray,
+ def init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
@@ -654,7 +653,7 @@ def delta_orthogonal(
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
"""
- def init(key: KeyArray,
+ def init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py
index 90a17000cf16..ec67d7489f30 100644
--- a/jax/_src/numpy/index_tricks.py
+++ b/jax/_src/numpy/index_tricks.py
@@ -24,10 +24,14 @@
arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose
)
from jax._src.typing import Array, ArrayLike
+from jax._src.util import set_module
import numpy as np
+export = set_module('jax.numpy')
+
+
__all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"]
@@ -87,7 +91,7 @@ def __getitem__(self, key: slice | tuple[slice, ...]) -> Array:
return stack(output_arr, 0)
-mgrid = _Mgrid()
+mgrid = export(_Mgrid())
class _Ogrid:
@@ -129,7 +133,7 @@ def __getitem__(
return meshgrid(*output, indexing='ij', sparse=True)
-ogrid = _Ogrid()
+ogrid = export(_Ogrid())
_IndexType = Union[ArrayLike, str, slice]
@@ -279,7 +283,7 @@ class RClass(_AxisConcat):
op_name = "r_"
-r_ = RClass()
+r_ = export(RClass())
class CClass(_AxisConcat):
@@ -327,7 +331,7 @@ class CClass(_AxisConcat):
op_name = "c_"
-c_ = CClass()
+c_ = export(CClass())
s_ = np.s_
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 88ddc85a0a40..5f380fad902c 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -68,13 +68,16 @@
)
from jax._src.util import (
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
- ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
+ ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2,
+ tuple_replace)
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
PartitionSpec as P)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np
import opt_einsum
+export = set_module('jax.numpy')
+
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
try:
cuda_plugin_extension = importlib.import_module(
@@ -116,6 +119,7 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
printoptions = np.printoptions
set_printoptions = np.set_printoptions
+@export
def iscomplexobj(x: Any) -> bool:
"""Check if the input is a complex number or an array containing complex elements.
@@ -217,6 +221,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
+if dtypes.float8_e3m4 is not None:
+ float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
+if dtypes.float8_e4m3 is not None:
+ float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
@@ -327,6 +335,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
return clip(val, min_val, max_val).astype(dtype)
+@export
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array:
"""Load JAX arrays from npy files.
@@ -376,6 +385,7 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) ->
### implementations of numpy functions in terms of lax
+@export
@jit
def fmin(x1: ArrayLike, x2: ArrayLike) -> Array:
"""Return element-wise minimum of the input arrays.
@@ -427,6 +437,7 @@ def fmin(x1: ArrayLike, x2: ArrayLike) -> Array:
return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2)
+@export
@jit
def fmax(x1: ArrayLike, x2: ArrayLike) -> Array:
"""Return element-wise maximum of the input arrays.
@@ -476,6 +487,7 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array:
return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2)
+@export
def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
"""Return True if arg1 is equal or lower than arg2 in the type hierarchy.
@@ -522,6 +534,7 @@ def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
return dtypes.issubdtype(arg1, arg2)
+@export
def isscalar(element: Any) -> bool:
"""Return True if the input is a scalar.
@@ -620,6 +633,7 @@ def isscalar(element: Any) -> bool:
iterable = np.iterable
+@export
def result_type(*args: Any) -> DType:
"""Return the result of applying JAX promotion rules to the inputs.
@@ -663,6 +677,7 @@ def result_type(*args: Any) -> DType:
return dtypes.result_type(*args)
+@export
@jit
def trunc(x: ArrayLike) -> Array:
"""Round input to the nearest integer towards zero.
@@ -739,6 +754,7 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike,
return result[0, 0, out_order]
+@export
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
precision: PrecisionLike = None,
@@ -814,6 +830,7 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
precision=precision, preferred_element_type=preferred_element_type)
+@export
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
precision: PrecisionLike = None,
@@ -899,6 +916,7 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
precision=precision, preferred_element_type=preferred_element_type)
+@export
def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
range: None | Array | Sequence[ArrayLike] = None,
weights: ArrayLike | None = None) -> Array:
@@ -950,6 +968,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
return linspace(range[0], range[1], bins_int + 1, dtype=dtype)
+@export
def histogram(a: ArrayLike, bins: ArrayLike = 10,
range: Sequence[ArrayLike] | None = None,
weights: ArrayLike | None = None,
@@ -1031,6 +1050,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10,
return counts, bin_edges
+@export
def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
weights: ArrayLike | None = None,
@@ -1120,6 +1140,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] =
return hist, edges[0], edges[1]
+@export
def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
weights: ArrayLike | None = None,
@@ -1229,6 +1250,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
return hist, bin_edges_by_dim
+@export
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
"""Return a transposed version of an N-dimensional array.
@@ -1307,6 +1329,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
return lax.transpose(a, axes_)
+@export
def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array:
"""Permute the axes/dimensions of an array.
@@ -1336,6 +1359,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array:
return lax.transpose(a, axes)
+@export
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transpose the last two dimensions of an array.
@@ -1389,6 +1413,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
return lax.transpose(x, axes)
+@export
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
"""Rotate an array by 90 degrees counterclockwise in the plane specified by axes.
@@ -1472,6 +1497,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
return flip(transpose(m, perm), ax2)
+@export
def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
"""Reverse the order of elements of an array along the given axis.
@@ -1539,6 +1565,7 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis])
+@export
def fliplr(m: ArrayLike) -> Array:
"""Reverse the order of elements of an array along axis 1.
@@ -1565,6 +1592,7 @@ def fliplr(m: ArrayLike) -> Array:
return _flip(asarray(m), 1)
+@export
def flipud(m: ArrayLike) -> Array:
"""Reverse the order of elements of an array along axis 0.
@@ -1590,6 +1618,8 @@ def flipud(m: ArrayLike) -> Array:
util.check_arraylike("flipud", m)
return _flip(asarray(m), 0)
+
+@export
@jit
def iscomplex(x: ArrayLike) -> Array:
"""Return boolean array showing where the input is complex.
@@ -1613,6 +1643,8 @@ def iscomplex(x: ArrayLike) -> Array:
i = ufuncs.imag(x)
return lax.ne(i, _lax_const(i, 0))
+
+@export
@jit
def isreal(x: ArrayLike) -> Array:
"""Return boolean array showing where the input is real.
@@ -1637,6 +1669,7 @@ def isreal(x: ArrayLike) -> Array:
return lax.eq(i, _lax_const(i, 0))
+@export
@partial(jit, static_argnames=['deg'])
def angle(z: ArrayLike, deg: bool = False) -> Array:
"""Return the angle of a complex valued number or array.
@@ -1688,6 +1721,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array:
return ufuncs.degrees(result) if deg else result
+@export
@partial(jit, static_argnames=('n', 'axis'))
def diff(a: ArrayLike, n: int = 1, axis: int = -1,
prepend: ArrayLike | None = None,
@@ -1800,6 +1834,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
return arr
+@export
@jit
def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
to_begin: ArrayLike | None = None) -> Array:
@@ -1862,6 +1897,8 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype))))
return result
+
+@export
@partial(jit, static_argnames=("axis", "edge_order"))
def gradient(
f: ArrayLike,
@@ -1992,6 +2029,7 @@ def gradient_along_axis(a, h, axis):
return a_grad[0] if len(axis_tuple) == 1 else a_grad
+@export
def isrealobj(x: Any) -> bool:
"""Check if the input is not a complex number or an array containing complex elements.
@@ -2026,6 +2064,7 @@ def isrealobj(x: Any) -> bool:
return not iscomplexobj(x)
+@export
def reshape(
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(),
@@ -2129,6 +2168,7 @@ def reshape(
return asarray(a).reshape(shape, order=order)
+@export
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a: ArrayLike, order: str = "C") -> Array:
"""Flatten array into a 1-dimensional shape.
@@ -2182,6 +2222,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
return reshape(a, (size(a),), order)
+@export
def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
mode: str = 'raise', order: str = 'C') -> Array:
"""Convert multi-dimensional indices into flat indices.
@@ -2273,6 +2314,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
return result
+@export
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
"""Convert flat indices into multi-dimensional indices.
@@ -2336,6 +2378,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
for s, i in safe_zip(shape, out_indices))
+@export
@partial(jit, static_argnames=('new_shape',))
def resize(a: ArrayLike, new_shape: Shape) -> Array:
"""Return a new array with specified shape.
@@ -2387,6 +2430,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array:
return reshape(arr, new_shape)
+@export
def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
"""Remove one or more length-1 axes from array
@@ -2457,6 +2501,7 @@ def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
return lax.squeeze(a, axis)
+@export
def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array:
"""Insert dimensions of length 1 into array
@@ -2527,6 +2572,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array:
return lax.expand_dims(a, axis)
+@export
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
"""Swap two axes of an array.
@@ -2574,6 +2620,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
return lax.transpose(a, list(perm))
+@export
def moveaxis(a: ArrayLike, source: int | Sequence[int],
destination: int | Sequence[int]) -> Array:
"""Move an array axis to a new position
@@ -2639,6 +2686,7 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -
return lax.transpose(a, perm)
+@export
@partial(jit, static_argnames=('equal_nan',))
def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08,
equal_nan: bool = False) -> Array:
@@ -2783,6 +2831,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
return f
+@export
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: ArrayLike | str | None = None,
right: ArrayLike | str | None = None,
@@ -2865,6 +2914,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
) -> Array | tuple[Array, ...]: ...
+@export
def where(condition, x=None, y=None, /, *, size=None, fill_value=None):
"""Select elements from two arrays based on a condition.
@@ -2940,6 +2990,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None):
return util._where(condition, x, y)
+@export
def select(
condlist: Sequence[ArrayLike],
choicelist: Sequence[ArrayLike],
@@ -3007,6 +3058,7 @@ def select(
return lax.select_n(*broadcast_arrays(idx, *choicelist))
+@export
def bincount(x: ArrayLike, weights: ArrayLike | None = None,
minlength: int = 0, *, length: int | None = None
) -> Array:
@@ -3099,6 +3151,7 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ...
def broadcast_shapes(*shapes: Sequence[int | core.Tracer]
) -> tuple[int | core.Tracer, ...]: ...
+@export
def broadcast_shapes(*shapes):
"""Broadcast input shapes to a common output shape.
@@ -3139,6 +3192,7 @@ def broadcast_shapes(*shapes):
return lax.broadcast_shapes(*shapes)
+@export
def broadcast_arrays(*args: ArrayLike) -> list[Array]:
"""Broadcast arrays to a common shape.
@@ -3178,6 +3232,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]:
return util._broadcast_arrays(*args)
+@export
def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
"""Broadcast an array to a specified shape.
@@ -3254,6 +3309,7 @@ def _split(op: str, ary: ArrayLike,
for start, end in zip(split_indices[:-1], split_indices[1:])]
+@export
def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
axis: int = 0) -> list[Array]:
"""Split an array into sub-arrays.
@@ -3317,6 +3373,7 @@ def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
return _split("split", ary, indices_or_sections, axis=axis)
+@export
def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
"""Split an array into sub-arrays vertically.
@@ -3351,6 +3408,7 @@ def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike)
return _split("vsplit", ary, indices_or_sections, axis=0)
+@export
def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
"""Split an array into sub-arrays horizontally.
@@ -3391,6 +3449,7 @@ def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike)
return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1)
+@export
def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]:
"""Split an array into sub-arrays depth-wise.
@@ -3432,6 +3491,7 @@ def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike)
return _split("dsplit", ary, indices_or_sections, axis=2)
+@export
def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,
axis: int = 0) -> list[Array]:
"""Split an array into sub-arrays.
@@ -3457,6 +3517,7 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
return _split("array_split", ary, indices_or_sections, axis=axis)
+@export
@jit
def clip(
arr: ArrayLike | None = None,
@@ -3528,6 +3589,7 @@ def clip(
return asarray(arr)
+@export
@partial(jit, static_argnames=('decimals',))
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
"""Round input evenly to the given number of decimals.
@@ -3599,12 +3661,14 @@ def _round_float(x: ArrayLike) -> Array:
return _round_float(a)
+@export
@partial(jit, static_argnames=('decimals',))
def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
"""Alias of :func:`jax.numpy.round`"""
return round(a, decimals, out)
+@export
@jit
def fix(x: ArrayLike, out: None = None) -> Array:
"""Round input to the nearest integer towards zero.
@@ -3643,6 +3707,7 @@ def fix(x: ArrayLike, out: None = None) -> Array:
return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x))
+@export
@jit
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
posinf: ArrayLike | None = None,
@@ -3708,6 +3773,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
return out
+@export
@partial(jit, static_argnames=('equal_nan',))
def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array:
@@ -3756,6 +3822,7 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
return reductions.all(isclose(a, b, rtol, atol, equal_nan))
+@export
def nonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]:
@@ -3863,6 +3930,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
return out
+@export
def flatnonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array:
"""Return indices of nonzero elements in a flattened array
@@ -3908,6 +3976,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None,
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]
+@export
@partial(jit, static_argnames=('axis',))
def unwrap(p: ArrayLike, discont: ArrayLike | None = None,
axis: int = -1, period: ArrayLike = 2 * pi) -> Array:
@@ -4337,6 +4406,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str,
"not implemented modes")
+@export
def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray],
mode: str | Callable[..., Any] = "constant", **kwargs) -> Array:
"""Add padding to an array.
@@ -4493,6 +4563,7 @@ def pad_func(row: Array, pad_width: tuple[int, int],
### Array-creation functions
+@export
def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array:
"""Join arrays along a new axis.
@@ -4559,6 +4630,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike],
return concatenate(new_arrays, axis=axis, dtype=dtype)
+@export
@partial(jit, static_argnames="axis")
def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
"""Unstack an array along an axis.
@@ -4599,6 +4671,8 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
)
return tuple(moveaxis(x, axis, 0))
+
+@export
def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array:
"""Construct an array by repeating ``A`` along specified dimensions.
@@ -4662,6 +4736,7 @@ def _concatenate_array(arr: ArrayLike, axis: int | None,
return lax.reshape(arr, shape, dimensions)
+@export
def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
axis: int | None = 0, dtype: DTypeLike | None = None) -> Array:
"""Join arrays along an existing axis.
@@ -4725,6 +4800,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
return arrays_out[0]
+@export
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array:
"""Join arrays along an existing axis.
@@ -4765,6 +4841,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array:
return jax.numpy.concatenate(arrays, axis=axis)
+@export
def vstack(tup: np.ndarray | Array | Sequence[ArrayLike],
dtype: DTypeLike | None = None) -> Array:
"""Vertically stack arrays.
@@ -4825,6 +4902,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike],
return concatenate(arrs, axis=0, dtype=dtype)
+@export
def hstack(tup: np.ndarray | Array | Sequence[ArrayLike],
dtype: DTypeLike | None = None) -> Array:
"""Horizontally stack arrays.
@@ -4885,6 +4963,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike],
return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype)
+@export
def dstack(tup: np.ndarray | Array | Sequence[ArrayLike],
dtype: DTypeLike | None = None) -> Array:
"""Stack arrays depth-wise.
@@ -4945,6 +5024,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike],
return concatenate(arrs, axis=2, dtype=dtype)
+@export
def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array:
"""Stack arrays column-wise.
@@ -5005,6 +5085,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array:
return concatenate(arrs, axis=1)
+@export
def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
"""Construct an array by stacking slices of choice arrays.
@@ -5129,6 +5210,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]:
return asarray(xs), 1
+@export
@jit
def block(arrays: ArrayLike | list[ArrayLike]) -> Array:
"""Create an array from a list of blocks.
@@ -5212,6 +5294,7 @@ def atleast_1d(x: ArrayLike, /) -> Array:
@overload
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
...
+@export
@jit
def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
"""Convert inputs to arrays with at least 1 dimension.
@@ -5266,6 +5349,7 @@ def atleast_2d(x: ArrayLike, /) -> Array:
@overload
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
...
+@export
@jit
def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
"""Convert inputs to arrays with at least 2 dimensions.
@@ -5329,6 +5413,7 @@ def atleast_3d(x: ArrayLike, /) -> Array:
@overload
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
...
+@export
@jit
def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
"""Convert inputs to arrays with at least 3 dimensions.
@@ -5405,6 +5490,7 @@ def _supports_buffer_protocol(obj):
return True
+@export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array:
@@ -5597,6 +5683,7 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
return x
+@export
def astype(x: ArrayLike, dtype: DTypeLike | None,
/, *, copy: bool = False,
device: xc.Device | Sharding | None = None) -> Array:
@@ -5662,6 +5749,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
return _array_copy(result) if copy else result
+@export
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
*, copy: bool | None = None,
device: xc.Device | Sharding | None = None) -> Array:
@@ -5743,6 +5831,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
+@export
def copy(a: ArrayLike, order: str | None = None) -> Array:
"""Return a copy of the array.
@@ -5791,6 +5880,7 @@ def copy(a: ArrayLike, order: str | None = None) -> Array:
return array(a, copy=True, order=order)
+@export
def zeros_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None, *,
@@ -5833,6 +5923,7 @@ def zeros_like(a: ArrayLike | DuckTypedArray,
return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device))
+@export
def ones_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None, *,
@@ -5875,6 +5966,7 @@ def ones_like(a: ArrayLike | DuckTypedArray,
return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device))
+@export
def empty_like(prototype: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = None,
shape: Any = None, *,
@@ -5924,6 +6016,7 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No
return device
+@export
def full(shape: Any, fill_value: ArrayLike,
dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
@@ -5972,6 +6065,7 @@ def full(shape: Any, fill_value: ArrayLike,
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
+@export
def full_like(a: ArrayLike | DuckTypedArray,
fill_value: ArrayLike, dtype: DTypeLike | None = None,
shape: Any = None, *,
@@ -6028,6 +6122,7 @@ def full_like(a: ArrayLike | DuckTypedArray,
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
+@export
def zeros(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array full of zeros.
@@ -6064,6 +6159,7 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *,
return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
+@export
def ones(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an array full of ones.
@@ -6100,6 +6196,7 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *,
return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
+@export
def empty(shape: Any, dtype: DTypeLike | None = None, *,
device: xc.Device | Sharding | None = None) -> Array:
"""Create an empty array.
@@ -6143,6 +6240,7 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
"with a single tuple argument for the shape?")
+@export
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
"""Check if two arrays are element-wise equal.
@@ -6184,6 +6282,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
return reductions.all(eq)
+@export
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
"""Check if two arrays are element-wise equal.
@@ -6224,6 +6323,7 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
# General np.from* style functions mostly delegate to numpy.
+@export
def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float,
count: int = -1, offset: int = 0) -> Array:
r"""Convert a buffer into a 1-D JAX array.
@@ -6271,6 +6371,7 @@ def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float,
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))
+@export
def fromfile(*args, **kwargs):
"""Unimplemented JAX wrapper for jnp.fromfile.
@@ -6289,6 +6390,7 @@ def fromfile(*args, **kwargs):
"https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")
+@export
def fromiter(*args, **kwargs):
"""Unimplemented JAX wrapper for jnp.fromiter.
@@ -6307,6 +6409,7 @@ def fromiter(*args, **kwargs):
"https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions")
+@export
def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
copy: bool | None = None) -> Array:
"""Construct a JAX array via DLPack.
@@ -6367,6 +6470,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
return from_dlpack(x, device=device, copy=copy)
+@export
def fromfunction(function: Callable[..., Array], shape: Any,
*, dtype: DTypeLike = float, **kwargs) -> Array:
"""Create an array from a function applied over indices.
@@ -6453,6 +6557,7 @@ def fromfunction(function: Callable[..., Array], shape: Any,
return function(*(arange(s, dtype=dtype) for s in shape), **kwargs)
+@export
def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array:
"""Convert a string of text into 1-D JAX array.
@@ -6481,6 +6586,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))
+@export
def eye(N: DimSize, M: DimSize | None = None,
k: int | ArrayLike = 0,
dtype: DTypeLike | None = None,
@@ -6560,6 +6666,7 @@ def _eye(N: DimSize, M: DimSize | None = None,
return (i + offset == j).astype(dtype)
+@export
def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
"""Create a square identity matrix
@@ -6593,6 +6700,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
return eye(n, dtype=dtype)
+@export
def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
step: ArrayLike | None = None, dtype: DTypeLike | None = None,
*, device: xc.Device | Sharding | None = None) -> Array:
@@ -6760,6 +6868,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
dtype: DTypeLike | None = None,
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
+@export
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
@@ -6885,6 +6994,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
return (result, delta) if retstep else result
+@export
def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, base: ArrayLike = 10.0,
dtype: DTypeLike | None = None, axis: int = 0) -> Array:
@@ -6970,6 +7080,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
return lax.convert_element_type(ufuncs.power(base, lin), dtype)
+@export
def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True,
dtype: DTypeLike | None = None, axis: int = 0) -> Array:
"""Generate geometrically-spaced values.
@@ -7044,6 +7155,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
return lax.convert_element_type(res, dtype)
+@export
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
indexing: str = 'xy') -> list[Array]:
"""Construct N-dimensional grid arrays from N 1-dimensional vectors.
@@ -7125,6 +7237,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
return output
+@export
@jit
def i0(x: ArrayLike) -> Array:
r"""Calculate modified Bessel function of first kind, zeroth order.
@@ -7174,6 +7287,7 @@ def _i0_jvp(primals, tangents):
primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents)
return primal_out, where(primals[0] == 0, 0.0, tangent_out)
+@export
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
"""Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.
@@ -7237,6 +7351,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: bool = False) -> Array | tuple[Array, ...]: ...
+@export
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: bool = False) -> Array | tuple[Array, ...]:
"""Generate arrays of grid indices.
@@ -7287,6 +7402,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
return stack(output, 0) if output else array([], dtype=dtype)
+@export
def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
total_repeat_length: int | None = None) -> Array:
"""Construct an array from repeated elements.
@@ -7431,6 +7547,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
return take(a, gather_indices, axis=axis)
+@export
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
@@ -7490,6 +7607,7 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
+@export
def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array:
r"""Return an array with ones on and below the diagonal and zeros elsewhere.
@@ -7546,6 +7664,7 @@ def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None
return lax_internal._tri(dtype, (N, M), k)
+@export
@partial(jit, static_argnames=('k',))
def tril(m: ArrayLike, k: int = 0) -> Array:
r"""Return lower triangle of an array.
@@ -7607,6 +7726,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
+@export
@partial(jit, static_argnames=('k',))
def triu(m: ArrayLike, k: int = 0) -> Array:
r"""Return upper triangle of an array.
@@ -7672,6 +7792,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
+@export
@partial(jit, static_argnames=('axis1', 'axis2', 'dtype'))
def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@@ -7737,6 +7858,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int
return reductions.sum(a, axis=(-2, -1), dtype=dtype)
+@export
def mask_indices(n: int,
mask_func: Callable[[ArrayLike, int], Array],
k: int = 0, *, size: int | None = None) -> tuple[Array, Array]:
@@ -7796,6 +7918,7 @@ def _triu_size(n, m, k):
return mk * (mk + 1) // 2 + mk * (m - k - mk)
+@export
def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]:
"""Return the indices of upper triangle of an array of size ``(n, m)``.
@@ -7854,6 +7977,7 @@ def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array
return i, j
+@export
def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]:
"""Return the indices of lower triangle of an array of size ``(n, m)``.
@@ -7912,6 +8036,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array
return i, j
+@export
def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
"""Return the indices of upper triangle of a given array.
@@ -7969,6 +8094,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
return triu_indices(arr_shape[0], k=k, m=arr_shape[1])
+@export
def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
"""Return the indices of lower triangle of a given array.
@@ -8026,6 +8152,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
return tril_indices(arr_shape[0], k=k, m=arr_shape[1])
+@export
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *,
inplace: bool = True) -> Array:
"""Return a copy of the array with the diagonal overwritten.
@@ -8107,6 +8234,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *,
return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n))
+@export
def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
"""Return indices for accessing the main diagonal of a multidimensional array.
@@ -8142,6 +8270,8 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]:
.format(ndim))
return (lax.iota(int_, n),) * ndim
+
+@export
def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]:
"""Return indices for accessing the main diagonal of a given array.
@@ -8183,6 +8313,8 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]:
return diag_indices(s[0], ndim=nd)
+
+@export
@partial(jit, static_argnames=('offset', 'axis1', 'axis2'))
def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
axis2: int = 1) -> Array:
@@ -8234,6 +8366,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
return a[..., i, j] if offset >= 0 else a[..., j, i]
+@export
def diag(v: ArrayLike, k: int = 0) -> Array:
"""Returns the specified diagonal or constructs a diagonal array.
@@ -8297,6 +8430,8 @@ def _diag(v, k):
else:
raise ValueError("diag input must be 1d or 2d")
+
+@export
def diagflat(v: ArrayLike, k: int = 0) -> Array:
"""Return a 2-D array with the flattened input array laid out on the diagonal.
@@ -8353,6 +8488,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array:
# TODO(jakevdp): add support for N-dimensional inputs as in NumPy v2.2
+@export
def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array:
"""Trim leading and/or trailing zeros of the input array.
@@ -8407,6 +8543,8 @@ def trim_zeros_tol(filt, tol, trim='fb'):
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt[start:len(filt) - end]
+
+@export
@partial(jit, static_argnames=('axis',))
def append(
arr: ArrayLike, values: ArrayLike, axis: int | None = None
@@ -8461,6 +8599,7 @@ def append(
return concatenate([arr, values], axis=axis)
+@export
def delete(
arr: ArrayLike,
obj: ArrayLike | slice,
@@ -8585,6 +8724,7 @@ def delete(
return a[tuple(slice(None) for i in range(axis)) + (mask,)]
+@export
def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
axis: int | None = None) -> Array:
"""Insert entries into an array at specified indices.
@@ -8684,6 +8824,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike,
return out
+@export
def apply_along_axis(
func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs
) -> Array:
@@ -8761,6 +8902,7 @@ def apply_along_axis(
return func(arr)
+@export
def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
axes: Sequence[int]) -> Array:
"""Apply a function repeatedly over specified axes.
@@ -8819,6 +8961,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
### Tensor contraction operations
+@export
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def dot(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
@@ -8908,6 +9051,7 @@ def dot(a: ArrayLike, b: ArrayLike, *,
output_weak_type)
+@export
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def matmul(a: ArrayLike, b: ArrayLike, *,
precision: PrecisionLike = None,
@@ -9031,6 +9175,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
+@export
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def vdot(
a: ArrayLike, b: ArrayLike, *,
@@ -9079,6 +9224,7 @@ def vdot(
preferred_element_type=preferred_element_type)
+@export
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
@@ -9134,6 +9280,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
signature="(n),(n)->()")(x1_arr, x2_arr)
+@export
def tensordot(a: ArrayLike, b: ArrayLike,
axes: int | Sequence[int] | Sequence[Sequence[int]] = 2,
*, precision: PrecisionLike = None,
@@ -9279,6 +9426,7 @@ def einsum(
out_type=None,
) -> Array: ...
+@export
def einsum(
subscripts, /,
*operands,
@@ -9554,6 +9702,7 @@ def einsum_path(
optimize: bool | str | list[tuple[int, ...]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...
+@export
def einsum_path(
subscripts, /,
*operands,
@@ -9787,6 +9936,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
output_weak_type)
+@export
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def inner(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
@@ -9843,6 +9993,7 @@ def inner(
preferred_element_type=preferred_element_type)
+@export
@partial(jit, inline=True)
def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
"""Compute the outer product of two arrays.
@@ -9877,6 +10028,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
return ravel(a)[:, None] * ravel(b)[None, :]
+@export
@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis'))
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
axis: int | None = None):
@@ -9977,6 +10129,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
return moveaxis(c, 0, axisc)
+@export
@jit
def kron(a: ArrayLike, b: ArrayLike) -> Array:
"""Compute the Kronecker product of two input arrays.
@@ -10022,6 +10175,7 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array:
return reshape(lax.mul(a_reshaped, b_reshaped), out_shape)
+@export
@partial(jit, static_argnames=('N', 'increasing'))
def vander(
x: ArrayLike, N: int | None = None, increasing: bool = False
@@ -10085,6 +10239,7 @@ def vander(
### Misc
+@export
def argwhere(
a: ArrayLike,
*,
@@ -10150,6 +10305,7 @@ def argwhere(
return result.reshape(result.shape[0], ndim(a))
+@export
def argmax(a: ArrayLike, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array:
"""Return the index of the maximum value of an array.
@@ -10205,6 +10361,7 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
return expand_dims(result, dims) if keepdims else result
+@export
def argmin(a: ArrayLike, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array:
"""Return the index of the minimum value of an array.
@@ -10260,6 +10417,7 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
return expand_dims(result, dims) if keepdims else result
+@export
def nanargmax(
a: ArrayLike,
axis: int | None = None,
@@ -10327,6 +10485,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False):
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)
+@export
def nanargmin(
a: ArrayLike,
axis: int | None = None,
@@ -10387,6 +10546,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)
+@export
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def sort(
a: ArrayLike,
@@ -10450,6 +10610,7 @@ def sort(
return lax.rev(result, dimensions=[dimension]) if descending else result
+@export
@jit
def sort_complex(a: ArrayLike) -> Array:
"""Return a sorted copy of complex array.
@@ -10487,6 +10648,7 @@ def sort_complex(a: ArrayLike) -> Array:
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
+@export
@partial(jit, static_argnames=('axis',))
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
"""Sort a sequence of keys in lexicographic order.
@@ -10564,6 +10726,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A
return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1]
+@export
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def argsort(
a: ArrayLike,
@@ -10644,6 +10807,7 @@ def argsort(
return lax.rev(indices, dimensions=[dimension]) if descending else indices
+@export
@partial(jit, static_argnames=['kth', 'axis'])
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
"""Returns a partially-sorted copy of an array.
@@ -10714,6 +10878,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
return swapaxes(out, -1, axis)
+@export
@partial(jit, static_argnames=['kth', 'axis'])
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
"""Returns indices that partially sort an array.
@@ -10818,6 +10983,8 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array:
dimension=ax)
return a
+
+@export
def roll(a: ArrayLike, shift: ArrayLike | Sequence[int],
axis: int | Sequence[int] | None = None) -> Array:
"""Roll the elements of an array along a specified axis.
@@ -10871,6 +11038,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int],
return _roll_static(arr, shift, axis)
+@export
@partial(jit, static_argnames=('axis', 'start'))
def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array:
"""Roll the specified axis to a given position.
@@ -10936,6 +11104,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array:
return moveaxis(a, axis, start)
+@export
@partial(jit, static_argnames=('axis', 'bitorder'))
def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array:
"""Pack array of bits into a uint8 array.
@@ -11020,6 +11189,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar
return swapaxes(packed, axis, -1)
+@export
@partial(jit, static_argnames=('axis', 'count', 'bitorder'))
def unpackbits(
a: ArrayLike,
@@ -11111,6 +11281,7 @@ def unpackbits(
return swapaxes(unpacked, axis, -1)
+@export
def take(
a: ArrayLike,
indices: ArrayLike,
@@ -11268,6 +11439,7 @@ def _normalize_index(index, axis_size):
return lax.select(index < 0, lax.add(index, axis_size_val), index)
+@export
@partial(jit, static_argnames=('axis', 'mode', 'fill_value'))
def take_along_axis(
arr: ArrayLike,
@@ -11455,6 +11627,106 @@ def replace(tup, val):
mode="fill" if mode is None else mode, fill_value=fill_value)
+_indices = indices # argument below named 'indices' shadows the function
+
+
+def _make_along_axis_idx(shape, indices, axis):
+ return tuple_replace(_indices(shape, sparse=True), axis, indices)
+
+
+@export
+@partial(jit, static_argnames=('axis', 'inplace', 'mode'))
+def put_along_axis(
+ arr: ArrayLike,
+ indices: ArrayLike,
+ values: ArrayLike,
+ axis: int | None,
+ inplace: bool = True,
+ *,
+ mode: str | None = None,
+) -> Array:
+ """Put values into the destination array by matching 1d index and data slices.
+
+ JAX implementation of :func:`numpy.put_along_axis`.
+
+ The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
+ is not possible for JAX's immutable arrays. The JAX version returns a modified
+ copy of the input, and adds the ``inplace`` parameter which must be set to
+  ``False`` by the user as a reminder of this API difference.
+
+ Args:
+ arr: array into which values will be put.
+ indices: array of indices at which to put values.
+ values: array of values to put into the array.
+ axis: the axis along which to put values. If not specified, the array will
+ be flattened before indexing is applied.
+ inplace: must be set to False to indicate that the input is not modified
+ in-place, but rather a modified copy is returned.
+ mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
+ see :attr:`jax.numpy.ndarray.at`.
+
+ Returns:
+    A copy of ``arr`` with specified entries updated.
+
+ See Also:
+ - :func:`jax.numpy.put`: put elements into an array at given indices.
+ - :func:`jax.numpy.place`: place elements into an array via boolean mask.
+ - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
+ - :func:`jax.numpy.take`: extract values from an array at given indices.
+ - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.
+
+ Examples:
+ >>> from jax import numpy as jnp
+ >>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
+ >>> i = jnp.argmax(a, axis=1, keepdims=True)
+ >>> print(i)
+ [[1]
+ [0]]
+ >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
+ >>> print(b)
+ [[10 99 20]
+ [99 40 50]]
+ """
+ if inplace:
+ raise ValueError(
+        "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays "
+ "are immutable. Pass inplace=False to instead return an updated array.")
+
+ util.check_arraylike("put_along_axis", arr, indices, values)
+ arr = asarray(arr)
+ indices = asarray(indices)
+ values = asarray(values)
+
+ original_axis = axis
+ original_arr_shape = arr.shape
+
+ if axis is None:
+ arr = arr.ravel()
+ axis = 0
+
+ if not arr.ndim == indices.ndim:
+ raise ValueError(
+ "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
+ f"{arr.ndim=} and {indices.ndim=}."
+ )
+
+ try:
+ values = broadcast_to(values, indices.shape)
+ except ValueError:
+ raise ValueError(
+ "put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
+ f"{values.shape=} and {indices.shape=}."
+ )
+
+ idx = _make_along_axis_idx(arr.shape, indices, axis)
+ result = arr.at[idx].set(values, mode=mode)
+
+ if original_axis is None:
+ result = result.reshape(original_arr_shape)
+
+ return result
+
+
### Indexing
def _is_integer_index(idx: Any) -> bool:
@@ -11844,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
"arrays within JIT compiled functions).")
raise IndexError(msg)
- start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
+ start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis])
slice_shape.append(slice_size)
if core.definitely_equal(step, 1):
@@ -12047,66 +12319,8 @@ def _canonicalize_tuple_index(arr_ndim, idx):
idx = tuple(idx) + colons
return idx
-def _preprocess_slice(
- s: slice,
- axis_size: core.DimSize
- ) -> tuple[core.DimSize, core.DimSize, core.DimSize]:
- """Computes the start index, step, and size of the slice `x[s]`."""
- # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
- # "this is harder to get right than you may think"
- # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275)
- def convert_to_index(d: DimSize) -> DimSize:
- # Convert np.array and jax.Array to int, leave symbolic dimensions alone
- try:
- return operator.index(d)
- except:
- return d
-
- # Must resolve statically if step is {<0, ==0, >0}
- step = convert_to_index(s.step) if s.step is not None else 1
- try:
- if step == 0:
- raise ValueError("slice step cannot be zero")
- step_gt_0 = (step > 0)
- except core.InconclusiveDimensionOperation as e:
- raise core.InconclusiveDimensionOperation(
- f"In slice with non-constant elements the step ({step}) must " +
- f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
-
- def clamp_index(i: DimSize, which: str):
- try:
- i_ge_0 = (i >= 0)
- except core.InconclusiveDimensionOperation as e:
- raise core.InconclusiveDimensionOperation(
- f"In slice with non-constant elements the {which} ({i}) must " +
- f"be resolved statically if it is >= 0.\nDetails: {e}")
- if i_ge_0:
- if step_gt_0:
- return core.min_dim(axis_size, i)
- else:
- return core.min_dim(axis_size - 1, i)
- else:
- if step_gt_0:
- return core.max_dim(0, axis_size + i)
- else:
- return core.max_dim(-1, axis_size + i)
-
- if s.start is None:
- start = 0 if step_gt_0 else axis_size - 1
- else:
- start = clamp_index(convert_to_index(s.start), "start")
-
- if s.stop is None:
- stop = axis_size if step_gt_0 else -1
- else:
- stop = clamp_index(convert_to_index(s.stop), "stop")
-
- gap = step if step_gt_0 else - step
- distance = (stop - start) if step_gt_0 else (start - stop)
- slice_size = core.max_dim(0, distance + gap - 1) // gap
- return start, step, slice_size
-
+@export
def blackman(M: int) -> Array:
"""Return a Blackman window of size M.
@@ -12137,6 +12351,7 @@ def blackman(M: int) -> Array:
return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1))
+@export
def bartlett(M: int) -> Array:
"""Return a Bartlett window of size M.
@@ -12167,6 +12382,7 @@ def bartlett(M: int) -> Array:
return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1)
+@export
def hamming(M: int) -> Array:
"""Return a Hamming window of size M.
@@ -12197,6 +12413,7 @@ def hamming(M: int) -> Array:
return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1))
+@export
def hanning(M: int) -> Array:
"""Return a Hanning window of size M.
@@ -12227,6 +12444,7 @@ def hanning(M: int) -> Array:
return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1)))
+@export
def kaiser(M: int, beta: ArrayLike) -> Array:
"""Return a Kaiser window of size M.
@@ -12269,6 +12487,8 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]:
where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0)))
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
+
+@export
@jit
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
"""Compute the greatest common divisor of two arrays.
@@ -12315,6 +12535,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
return gcd
+@export
@jit
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
"""Compute the least common multiple of two arrays.
@@ -12362,6 +12583,7 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
ufuncs.multiply(x1, ufuncs.floor_divide(x2, d)))
+@export
def extract(condition: ArrayLike, arr: ArrayLike,
*, size: int | None = None, fill_value: ArrayLike = 0) -> Array:
"""Return the elements of an array that satisfy a condition.
@@ -12423,6 +12645,7 @@ def extract(condition: ArrayLike, arr: ArrayLike,
return compress(ravel(condition), ravel(arr), size=size, fill_value=fill_value)
+@export
def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None,
*, size: int | None = None, fill_value: ArrayLike = 0, out: None = None) -> Array:
"""Compress an array along a given axis using a boolean condition.
@@ -12517,6 +12740,7 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None,
return moveaxis(result, 0, axis)
+@export
@partial(jit, static_argnames=('rowvar', 'bias', 'ddof'))
def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
bias: bool = False, ddof: int | None = None,
@@ -12675,6 +12899,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True,
return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze()
+@export
@partial(jit, static_argnames=('rowvar',))
def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array:
r"""Compute the Pearson correlation coefficients.
@@ -12804,6 +13029,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt
return comparisons.sum(dtype=dtype, axis=0)
+@export
@partial(jit, static_argnames=('side', 'method'))
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array:
@@ -12893,6 +13119,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
return impl(asarray(a), asarray(v), side, dtype) # type: ignore
+@export
@partial(jit, static_argnames=('right', 'method'))
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
*, method: str | None = None) -> Array:
@@ -12948,6 +13175,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
)
+@export
def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike],
funclist: list[ArrayLike | Callable[..., Array]],
*args, **kw) -> Array:
@@ -13055,6 +13283,7 @@ def _tile_to_size(arr: Array, size: int) -> Array:
return arr[:size] if arr.size > size else arr
+@export
def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
inplace: bool = True) -> Array:
"""Update array elements based on a mask.
@@ -13130,6 +13359,7 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape)
+@export
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
mode: str | None = None, *, inplace: bool = True) -> Array:
"""Put elements into an array at given indices.
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index 03f864919887..be6828c36e6a 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -35,10 +35,13 @@
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
-from jax._src.util import canonicalize_axis
+from jax._src.util import canonicalize_axis, set_module
from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg
+export = set_module('jax.numpy.linalg')
+
+
class EighResult(NamedTuple):
eigenvalues: jax.Array
eigenvectors: jax.Array
@@ -67,6 +70,7 @@ def _H(x: ArrayLike) -> Array:
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
+@export
@partial(jit, static_argnames=['upper'])
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
"""Compute the Cholesky decomposition of a matrix.
@@ -191,6 +195,7 @@ def svd(
...
+@export
@partial(
jit,
static_argnames=(
@@ -311,6 +316,7 @@ def svd(
)
+@export
@partial(jit, static_argnames=('n',))
def matrix_power(a: ArrayLike, n: int) -> Array:
"""Raise a square matrix to an integer power.
@@ -392,6 +398,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
return result
+@export
@jit
def matrix_rank(
M: ArrayLike, rtol: ArrayLike | None = None, *,
@@ -496,6 +503,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
return sign_diag * sign_taus, log_abs_det
+@export
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
"""
@@ -675,6 +683,7 @@ def _det_jvp(primals, tangents):
return y, jnp.trace(z, axis1=-1, axis2=-2)
+@export
@jit
def det(a: ArrayLike) -> Array:
"""
@@ -711,6 +720,7 @@ def det(a: ArrayLike) -> Array:
raise ValueError(msg.format(a_shape))
+@export
def eig(a: ArrayLike) -> tuple[Array, Array]:
"""
Compute the eigenvalues and eigenvectors of a square array.
@@ -731,7 +741,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
- This differs from :func:`numpy.linalg.eig` in that the return type of
:func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128
for 64-bit input.
- - At present, non-symmetric eigendecomposition is only implemented on the CPU backend.
+ - At present, non-symmetric eigendecomposition is only implemented on the CPU and
+ GPU backends. For more details about the GPU implementation, see the
+ documentation for :func:`jax.lax.linalg.eig`.
See also:
- :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix.
@@ -754,6 +766,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
return w, v
+@export
@jit
def eigvals(a: ArrayLike) -> Array:
"""
@@ -791,6 +804,7 @@ def eigvals(a: ArrayLike) -> Array:
compute_right_eigenvectors=False)[0]
+@export
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: str | None = None,
symmetrize_input: bool = True) -> EighResult:
@@ -846,6 +860,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
return EighResult(w, v)
+@export
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
"""
@@ -882,6 +897,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires.
+@export
def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
hermitian: bool = False, *,
rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
@@ -995,6 +1011,7 @@ def _pinv_jvp(rtol, hermitian, primals, tangents):
return p, p_dot
+@export
@jit
def inv(a: ArrayLike) -> Array:
"""Return the inverse of a square matrix
@@ -1055,6 +1072,7 @@ def inv(a: ArrayLike) -> Array:
arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))
+@export
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x: ArrayLike, ord: int | str | None = None,
axis: None | tuple[int, ...] | int = None,
@@ -1220,6 +1238,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ...
+@export
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
"""Compute the QR decomposition of an array
@@ -1303,6 +1322,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
return QRResult(q, r)
+@export
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
"""Solve a linear system of equations
@@ -1406,6 +1426,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
+@export
def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
"""
@@ -1446,6 +1467,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
return _jit_lstsq(a, b, rcond)
+@export
def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
r"""Compute the cross-product of two 3D vectors
@@ -1491,6 +1513,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
return jnp.cross(x1, x2, axis=axis)
+@export
def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Compute the outer product of two 1-dimensional arrays.
@@ -1521,6 +1544,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return x1[:, None] * x2[None, :]
+@export
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
"""Compute the norm of a matrix or stack of matrices.
@@ -1551,6 +1575,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') ->
return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))
+@export
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transpose a matrix or stack of matrices.
@@ -1606,6 +1631,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))
+@export
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
ord: int | str = 2) -> Array:
"""Compute the vector norm of a vector or batch of vectors.
@@ -1650,6 +1676,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
return norm(x, axis=axis, keepdims=keepdims, ord=ord)
+@export
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
@@ -1700,6 +1727,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
preferred_element_type=preferred_element_type)
+@export
def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
@@ -1760,6 +1788,7 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
preferred_element_type=preferred_element_type)
+@export
def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
precision: PrecisionLike = None,
@@ -1841,6 +1870,7 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
preferred_element_type=preferred_element_type)
+@export
def svdvals(x: ArrayLike, /) -> Array:
"""Compute the singular values of a matrix.
@@ -1865,6 +1895,7 @@ def svdvals(x: ArrayLike, /) -> Array:
return svd(x, compute_uv=False, hermitian=False)
+@export
def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
"""Extract the diagonal of an matrix or stack of matrices.
@@ -1905,6 +1936,7 @@ def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1)
+@export
def tensorinv(a: ArrayLike, ind: int = 2) -> Array:
"""Compute the tensor inverse of an array.
@@ -1947,6 +1979,7 @@ def tensorinv(a: ArrayLike, ind: int = 2) -> Array:
return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape)
+@export
def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array:
"""Solve the tensor equation a x = b for x.
@@ -1996,6 +2029,7 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None)
return solve(a_arr, b_arr.ravel()).reshape(out_shape)
+@export
def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array:
"""Efficiently compute matrix products between a sequence of arrays.
@@ -2088,6 +2122,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -
optimize='optimal', precision=precision)
+@export
@partial(jit, static_argnames=['p'])
def cond(x: ArrayLike, p=None):
"""Compute the condition number of a matrix.
@@ -2147,6 +2182,7 @@ def cond(x: ArrayLike, p=None):
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)
+@export
def trace(x: ArrayLike, /, *,
offset: int = 0, dtype: DTypeLike | None = None) -> Array:
"""Compute the trace of a matrix.
diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py
index 10cc90575cef..19388b903e5d 100644
--- a/jax/_src/numpy/polynomial.py
+++ b/jax/_src/numpy/polynomial.py
@@ -33,6 +33,10 @@
from jax._src.numpy.util import (
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
from jax._src.typing import Array, ArrayLike
+from jax._src.util import set_module
+
+
+export = set_module('jax.numpy')
@jit
@@ -57,6 +61,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan))
+@export
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
r"""Returns the roots of a polynomial given the coefficients ``p``.
@@ -116,6 +121,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
return _roots_with_zeros(p_arr, num_leading_zeros)
+@export
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
full: bool = False, w: ArrayLike | None = None, cov: bool = False
@@ -287,6 +293,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
return c
+@export
@jit
def poly(seq_of_zeros: ArrayLike) -> Array:
r"""Returns the coefficients of a polynomial for the given sequence of roots.
@@ -369,6 +376,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array:
return a
+@export
@partial(jit, static_argnames=['unroll'])
def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
r"""Evaluates the polynomial at specific values.
@@ -432,6 +440,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
return y
+@export
@jit
def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the sum of the two polynomials.
@@ -489,6 +498,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr)
+@export
@partial(jit, static_argnames=('m',))
def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array:
r"""Returns the coefficients of the integration of specified order of a polynomial.
@@ -557,6 +567,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array
return true_divide(concatenate((p_arr, k_arr)), coeff)
+@export
@partial(jit, static_argnames=('m',))
def polyder(p: ArrayLike, m: int = 1) -> Array:
r"""Returns the coefficients of the derivative of specified order of a polynomial.
@@ -607,6 +618,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array:
return p_arr[:-m] * coeff[::-1]
+@export
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
r"""Returns the product of two polynomials.
@@ -673,6 +685,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
return convolve(a1_arr, a2_arr, mode='full')
+@export
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
r"""Returns the quotient and remainder of polynomial division.
@@ -732,6 +745,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
return q, u_arr
+@export
@jit
def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the difference of two polynomials.
diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py
index 08f11d0cb6ad..69d6843f5155 100644
--- a/jax/_src/numpy/reductions.py
+++ b/jax/_src/numpy/reductions.py
@@ -37,9 +37,11 @@
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
- NumpyComplexWarning)
+ set_module, NumpyComplexWarning)
+export = set_module('jax.numpy')
+
_all = builtins.all
_lax_const = lax_internal._const
@@ -82,7 +84,7 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
ReductionOp = Callable[[Any, Any], Any]
-def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
+def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
*, has_identity: bool = True,
preproc: Callable[[ArrayLike], ArrayLike] | None = None,
bool_op: ReductionOp | None = None,
@@ -215,13 +217,14 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
- return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
+ return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)
+@export
def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
@@ -296,17 +299,19 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
promote_integers=promote_integers)
+
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array:
- return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric,
+ return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, promote_integers=promote_integers)
+@export
def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
@@ -386,11 +391,12 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
- return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
+ return _reduction(a, "max", lax.max, -np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
+@export
def max(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@@ -468,11 +474,12 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None,
def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
- return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
+ return _reduction(a, "min", lax.min, np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
+@export
def min(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@@ -548,10 +555,11 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None,
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
- return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
+ return _reduction(a, "all", lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
+@export
def all(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
r"""Test whether all array elements along a given axis evaluate to True.
@@ -604,10 +612,11 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None,
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
- return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
+ return _reduction(a, "any", lax.bitwise_or, False, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
+@export
def any(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
r"""Test whether any of the array elements along a given axis evaluate to True.
@@ -664,7 +673,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
arr = lax_internal.asarray(a)
init_val = np.array(-1, dtype=dtype or arr.dtype)
- return _reduction(arr, name="reduce_bitwise_and", np_fun=None, op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
+ return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@@ -673,7 +682,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
- return _reduction(a, name="reduce_bitwise_or", np_fun=None, op=lax.bitwise_or, init_val=0, preproc=_require_integer,
+ return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@@ -682,7 +691,7 @@ def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
- return _reduction(a, name="reduce_bitwise_xor", np_fun=None, op=lax.bitwise_xor, init_val=0, preproc=_require_integer,
+ return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@@ -691,7 +700,7 @@ def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
- return _reduction(a, name="reduce_logical_and", np_fun=None, op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool,
+ return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@@ -700,7 +709,7 @@ def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
- return _reduction(a, name="reduce_logical_or", np_fun=None, op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool,
+ return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@@ -709,11 +718,44 @@ def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
- return _reduction(a, name="reduce_logical_xor", np_fun=None, op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool,
+ return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
+def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+ out: None = None, keepdims: bool = False,
+ initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
+ """Compute log(sum(exp(a))) while avoiding precision loss."""
+ if out is not None:
+ raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.")
+ dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce")
+ a_arr, = promote_dtypes_inexact(a)
+ pos_dims, dims = _reduction_dims(a_arr, axis)
+ amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf)
+ amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
+ amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
+ exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
+ sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where)
+ result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype))
+ return result if initial is None else lax.logaddexp(initial, result)
+
+
+def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+ out: None = None, keepdims: bool = False,
+ initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
+ """Compute log2(sum(2 ** a)) via logsumexp."""
+ if out is not None:
+ raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.")
+ dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce")
+ ln2 = float(np.log(2))
+ if initial is not None:
+ initial *= ln2
+ return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims,
+ where=where, initial=initial) / ln2
+
+
+@export
def amin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@@ -721,6 +763,7 @@ def amin(a: ArrayLike, axis: Axis = None, out: None = None,
return min(a, axis=axis, out=out, keepdims=keepdims,
initial=initial, where=where)
+@export
def amax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
@@ -740,6 +783,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]):
return size
+@export
def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
@@ -843,6 +887,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, *
@overload
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ...
+@export
def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]:
"""Compute the weighed average.
@@ -953,6 +998,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
return avg
+@export
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
@@ -1093,6 +1139,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
return _upcast_f16(computation_dtype), np.dtype(dtype)
+@export
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
@@ -1185,6 +1232,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))
+@export
def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array:
r"""Return the peak-to-peak range along a given axis.
@@ -1236,6 +1284,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
return lax.sub(x, y)
+@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
@@ -1295,6 +1344,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
return out
+@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@@ -1377,6 +1427,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None,
initial=initial, where=where)
+@export
@partial(api.jit, static_argnames=('axis', 'keepdims'))
def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@@ -1459,6 +1510,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None,
initial=initial, where=where)
+@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@@ -1542,6 +1594,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
initial=initial, where=where)
+@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
@@ -1625,6 +1678,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
initial=initial, where=where)
+@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
keepdims: bool = False, where: ArrayLike | None = None) -> Array:
@@ -1716,6 +1770,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
return td
+@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
@@ -1818,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
return lax.convert_element_type(result, dtype)
+@export
@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None,
ddof: int = 0, keepdims: bool = False,
@@ -1939,6 +1995,7 @@ def _cumulative_reduction(
return result
+@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def cumsum(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@@ -1975,6 +2032,7 @@ def cumsum(a: ArrayLike, axis: int | None = None,
return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out)
+@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def cumprod(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@@ -2010,6 +2068,7 @@ def cumprod(a: ArrayLike, axis: int | None = None,
return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out)
+@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def nancumsum(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@@ -2059,6 +2118,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None,
fill_nan=True, fill_value=0)
+@export
@partial(api.jit, static_argnames=('axis', 'dtype'))
def nancumprod(a: ArrayLike, axis: int | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
@@ -2115,6 +2175,7 @@ def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None,
a, axis, dtype, out, promote_integers=True)
+@export
def cumulative_sum(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
@@ -2176,6 +2237,7 @@ def cumulative_sum(
return out
+@export
def cumulative_prod(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
@@ -2239,6 +2301,7 @@ def cumulative_prod(
# Quantiles
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
+@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
@@ -2295,6 +2358,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False)
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
+@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
@@ -2475,7 +2539,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)
+
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
+@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
@@ -2531,7 +2597,9 @@ def percentile(a: ArrayLike, q: ArrayLike,
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
+
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
+@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
@@ -2591,6 +2659,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
method=method, keepdims=keepdims)
+@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,
@@ -2642,6 +2711,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
keepdims=keepdims, method='midpoint')
+@export
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False,
diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py
index 6491a7617d8d..0d5ea905becc 100644
--- a/jax/_src/numpy/setops.py
+++ b/jax/_src/numpy/setops.py
@@ -35,10 +35,12 @@
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import check_arraylike, promote_dtypes
-from jax._src.util import canonicalize_axis
+from jax._src.util import canonicalize_axis, set_module
from jax._src.typing import Array, ArrayLike
+export = set_module('jax.numpy')
+
_lax_const = lax_internal._const
@@ -88,6 +90,7 @@ def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]:
return arr, num_unique1 + num_unique2
+@export
def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set difference of two 1D arrays.
@@ -175,6 +178,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
+@export
def union1d(ar1: ArrayLike, ar2: ArrayLike,
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set union of two 1D arrays.
@@ -278,6 +282,7 @@ def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *,
return where(arange(len(vals)) < num_results, vals, fill_value)
+@export
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *,
size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set-wise xor of elements in two arrays.
@@ -417,6 +422,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as
return vals
+@export
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return_indices: bool = False, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]:
@@ -524,6 +530,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return int1d
+@export
def isin(element: ArrayLike, test_elements: ArrayLike,
assume_unique: bool = False, invert: bool = False, *,
method='auto') -> Array:
@@ -652,6 +659,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
return ret[0] if len(ret) == 1 else ret
+@export
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, axis: int | None = None,
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None):
@@ -863,6 +871,7 @@ class _UniqueInverseResult(NamedTuple):
inverse_indices: Array
+@export
def unique_all(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> _UniqueAllResult:
"""Return unique values from x, along with indices, inverse indices, and counts.
@@ -945,6 +954,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
+@export
def unique_counts(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> _UniqueCountsResult:
"""Return unique values from x, along with counts.
@@ -1005,6 +1015,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
return _UniqueCountsResult(values=values, counts=counts)
+@export
def unique_inverse(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> _UniqueInverseResult:
"""Return unique values from x, along with indices, inverse indices, and counts.
@@ -1070,6 +1081,7 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None,
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
+@export
def unique_values(x: ArrayLike, /, *, size: int | None = None,
fill_value: ArrayLike | None = None) -> Array:
"""Return unique values from x, along with indices, inverse indices, and counts.
diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py
index 27e2973b212b..5dbd67e62a9f 100644
--- a/jax/_src/numpy/ufunc_api.py
+++ b/jax/_src/numpy/ufunc_api.py
@@ -33,6 +33,8 @@
import numpy as np
+export = set_module("jax.numpy")
+
_AT_INPLACE_WARNING = """\
Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like
np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
@@ -40,7 +42,7 @@
"""
-@set_module('jax.numpy')
+@export
class ufunc:
"""Universal functions which operation element-by-element on arrays.
@@ -586,6 +588,7 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array:
return result.reshape(*np.shape(A), *np.shape(B))
+@export
def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
*, identity: Any = None) -> ufunc:
"""Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py
index 93e116fa4b6a..de8688e491ba 100644
--- a/jax/_src/numpy/ufuncs.py
+++ b/jax/_src/numpy/ufuncs.py
@@ -38,6 +38,10 @@
promote_shapes, _where, check_no_float0s)
from jax._src.numpy.ufunc_api import ufunc
from jax._src.numpy import reductions
+from jax._src.util import set_module
+
+
+export = set_module('jax.numpy')
_lax_const = lax._const
@@ -75,6 +79,7 @@ def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc:
return decorator
+@export
@partial(jit, inline=True)
def fabs(x: ArrayLike, /) -> Array:
"""Compute the element-wise absolute values of the real-valued input.
@@ -119,18 +124,21 @@ def fabs(x: ArrayLike, /) -> Array:
return lax.abs(*promote_args_inexact('fabs', x))
+@export
@partial(jit, inline=True)
def bitwise_invert(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.invert`."""
return lax.bitwise_not(*promote_args('bitwise_invert', x))
+@export
@partial(jit, inline=True)
def bitwise_not(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.invert`."""
return lax.bitwise_not(*promote_args('bitwise_not', x))
+@export
@partial(jit, inline=True)
def invert(x: ArrayLike, /) -> Array:
"""Compute the bitwise inversion of an input.
@@ -223,6 +231,7 @@ def negative(x: ArrayLike, /) -> Array:
return lax.neg(*promote_args('negative', x))
+@export
@partial(jit, inline=True)
def positive(x: ArrayLike, /) -> Array:
"""Return element-wise positive values of the input.
@@ -271,6 +280,7 @@ def positive(x: ArrayLike, /) -> Array:
return lax.asarray(*promote_args('positive', x))
+@export
@partial(jit, inline=True)
def sign(x: ArrayLike, /) -> Array:
r"""Return an element-wise indication of sign of the input.
@@ -321,6 +331,7 @@ def sign(x: ArrayLike, /) -> Array:
return lax.sign(*promote_args('sign', x))
+@export
@partial(jit, inline=True)
def floor(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer downwards.
@@ -359,6 +370,7 @@ def floor(x: ArrayLike, /) -> Array:
return lax.floor(*promote_args_inexact('floor', x))
+@export
@partial(jit, inline=True)
def ceil(x: ArrayLike, /) -> Array:
"""Round input to the nearest integer upwards.
@@ -397,6 +409,7 @@ def ceil(x: ArrayLike, /) -> Array:
return lax.ceil(*promote_args_inexact('ceil', x))
+@export
@partial(jit, inline=True)
def exp(x: ArrayLike, /) -> Array:
"""Calculate element-wise exponential of the input.
@@ -438,6 +451,7 @@ def exp(x: ArrayLike, /) -> Array:
return lax.exp(*promote_args_inexact('exp', x))
+@export
@partial(jit, inline=True)
def log(x: ArrayLike, /) -> Array:
"""Calculate element-wise natural logarithm of the input.
@@ -475,6 +489,7 @@ def log(x: ArrayLike, /) -> Array:
return lax.log(*promote_args_inexact('log', x))
+@export
@partial(jit, inline=True)
def expm1(x: ArrayLike, /) -> Array:
"""Calculate ``exp(x)-1`` of each element of the input.
@@ -519,6 +534,7 @@ def expm1(x: ArrayLike, /) -> Array:
return lax.expm1(*promote_args_inexact('expm1', x))
+@export
@partial(jit, inline=True)
def log1p(x: ArrayLike, /) -> Array:
"""Calculates element-wise logarithm of one plus input, ``log(x+1)``.
@@ -559,6 +575,7 @@ def log1p(x: ArrayLike, /) -> Array:
return lax.log1p(*promote_args_inexact('log1p', x))
+@export
@partial(jit, inline=True)
def sin(x: ArrayLike, /) -> Array:
"""Compute a trigonometric sine of each element of input.
@@ -590,6 +607,7 @@ def sin(x: ArrayLike, /) -> Array:
return lax.sin(*promote_args_inexact('sin', x))
+@export
@partial(jit, inline=True)
def cos(x: ArrayLike, /) -> Array:
"""Compute a trigonometric cosine of each element of input.
@@ -620,6 +638,7 @@ def cos(x: ArrayLike, /) -> Array:
return lax.cos(*promote_args_inexact('cos', x))
+@export
@partial(jit, inline=True)
def tan(x: ArrayLike, /) -> Array:
"""Compute a trigonometric tangent of each element of input.
@@ -650,6 +669,7 @@ def tan(x: ArrayLike, /) -> Array:
return lax.tan(*promote_args_inexact('tan', x))
+@export
@partial(jit, inline=True)
def arcsin(x: ArrayLike, /) -> Array:
r"""Compute element-wise inverse of trigonometric sine of input.
@@ -691,6 +711,7 @@ def arcsin(x: ArrayLike, /) -> Array:
return lax.asin(*promote_args_inexact('arcsin', x))
+@export
@partial(jit, inline=True)
def arccos(x: ArrayLike, /) -> Array:
"""Compute element-wise inverse of trigonometric cosine of input.
@@ -733,6 +754,7 @@ def arccos(x: ArrayLike, /) -> Array:
return lax.acos(*promote_args_inexact('arccos', x))
+@export
@partial(jit, inline=True)
def arctan(x: ArrayLike, /) -> Array:
"""Compute element-wise inverse of trigonometric tangent of input.
@@ -773,6 +795,7 @@ def arctan(x: ArrayLike, /) -> Array:
return lax.atan(*promote_args_inexact('arctan', x))
+@export
@partial(jit, inline=True)
def sinh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic sine of input.
@@ -827,6 +850,7 @@ def sinh(x: ArrayLike, /) -> Array:
return lax.sinh(*promote_args_inexact('sinh', x))
+@export
@partial(jit, inline=True)
def cosh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic cosine of input.
@@ -880,6 +904,7 @@ def cosh(x: ArrayLike, /) -> Array:
return lax.cosh(*promote_args_inexact('cosh', x))
+@export
@partial(jit, inline=True)
def arcsinh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic sine of input.
@@ -929,6 +954,7 @@ def arcsinh(x: ArrayLike, /) -> Array:
return lax.asinh(*promote_args_inexact('arcsinh', x))
+@export
@jit
def arccosh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic cosine of input.
@@ -984,6 +1010,7 @@ def arccosh(x: ArrayLike, /) -> Array:
return result
+@export
@partial(jit, inline=True)
def tanh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic tangent of input.
@@ -1037,6 +1064,7 @@ def tanh(x: ArrayLike, /) -> Array:
return lax.tanh(*promote_args_inexact('tanh', x))
+@export
@partial(jit, inline=True)
def arctanh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic tangent of input.
@@ -1085,6 +1113,7 @@ def arctanh(x: ArrayLike, /) -> Array:
return lax.atanh(*promote_args_inexact('arctanh', x))
+@export
@partial(jit, inline=True)
def sqrt(x: ArrayLike, /) -> Array:
"""Calculates element-wise non-negative square root of the input array.
@@ -1117,6 +1146,7 @@ def sqrt(x: ArrayLike, /) -> Array:
return lax.sqrt(*promote_args_inexact('sqrt', x))
+@export
@partial(jit, inline=True)
def cbrt(x: ArrayLike, /) -> Array:
"""Calculates element-wise cube root of the input array.
@@ -1144,6 +1174,7 @@ def cbrt(x: ArrayLike, /) -> Array:
"""
return lax.cbrt(*promote_args_inexact('cbrt', x))
+
def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array:
"""Implementation of jnp.add.at."""
if a.dtype == bool:
@@ -1152,6 +1183,7 @@ def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array:
return a.at[indices].add(b).astype(bool)
return a.at[indices].add(b)
+
@binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at)
def add(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Add two arrays element-wise.
@@ -1182,6 +1214,7 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array:
x, y = promote_args("add", x, y)
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
+
def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array:
"""Implementation of jnp.multiply.at."""
if a.dtype == bool:
@@ -1191,6 +1224,7 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array:
else:
return a.at[indices].mul(b)
+
@binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at)
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Multiply two arrays element-wise.
@@ -1221,6 +1255,7 @@ def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
x, y = promote_args("multiply", x, y)
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
+
@binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and)
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise AND operation elementwise.
@@ -1250,6 +1285,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
+
@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or)
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise OR operation elementwise.
@@ -1279,6 +1315,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
+
@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor)
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Compute the bitwise XOR operation elementwise.
@@ -1309,6 +1346,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))
+@export
@partial(jit, inline=True)
def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise.
@@ -1364,12 +1402,14 @@ def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.shift_left(*promote_args_numeric("left_shift", x, y))
+@export
@partial(jit, inline=True)
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.left_shift`."""
return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y))
+@export
@partial(jit, inline=True)
def equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Returns element-wise truth value of ``x == y``.
@@ -1419,6 +1459,7 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.eq(*promote_args("equal", x, y))
+@export
@partial(jit, inline=True)
def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Returns element-wise truth value of ``x != y``.
@@ -1472,6 +1513,7 @@ def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array:
"""Implementation of jnp.subtract.at."""
return a.at[indices].subtract(b)
+
@binary_ufunc(identity=None, at=_subtract_at)
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Subtract two arrays element-wise.
@@ -1502,6 +1544,7 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.sub(*promote_args("subtract", x, y))
+@export
@partial(jit, inline=True)
def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Compute the arctangent of x1/x2, choosing the correct quadrant.
@@ -1557,6 +1600,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.atan2(*promote_args_inexact("arctan2", x1, x2))
+@export
@partial(jit, inline=True)
def minimum(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise minimum of the input arrays.
@@ -1617,6 +1661,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.min(*promote_args("minimum", x, y))
+@export
@partial(jit, inline=True)
def maximum(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise maximum of the input arrays.
@@ -1676,6 +1721,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.max(*promote_args("maximum", x, y))
+@export
@partial(jit, inline=True)
def float_power(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Calculate element-wise base ``x`` exponential of ``y``.
@@ -1722,6 +1768,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.pow(*promote_args_inexact("float_power", x, y))
+@export
@partial(jit, inline=True)
def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise next floating point value after ``x`` towards ``y``.
@@ -1749,6 +1796,7 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
+@export
@partial(jit, inline=True)
def spacing(x: ArrayLike, /) -> Array:
"""Return the spacing between ``x`` and the next adjacent number.
@@ -1856,6 +1904,7 @@ def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))
+@export
@partial(jit, inline=True)
def logical_not(x: ArrayLike, /) -> Array:
"""Compute NOT bool(x) element-wise.
@@ -1901,6 +1950,8 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array],
lax_op(x.real, y.real))
return lax_op(x, y)
+
+@export
@partial(jit, inline=True)
def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x >= y``.
@@ -1946,6 +1997,7 @@ def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y))
+@export
@partial(jit, inline=True)
def greater(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x > y``.
@@ -1992,6 +2044,7 @@ def greater(x: ArrayLike, y: ArrayLike, /) -> Array:
return _complex_comparison(lax.gt, *promote_args("greater", x, y))
+@export
@partial(jit, inline=True)
def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x <= y``.
@@ -2038,6 +2091,7 @@ def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
return _complex_comparison(lax.le, *promote_args("less_equal", x, y))
+@export
@partial(jit, inline=True)
def less(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Return element-wise truth value of ``x < y``.
@@ -2083,42 +2137,58 @@ def less(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return _complex_comparison(lax.lt, *promote_args("less", x, y))
+
# Array API aliases
+@export
@partial(jit, inline=True)
def acos(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arccos`"""
return arccos(*promote_args('acos', x))
+
+@export
@partial(jit, inline=True)
def acosh(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arccosh`"""
return arccosh(*promote_args('acosh', x))
+
+@export
@partial(jit, inline=True)
def asin(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arcsin`"""
return arcsin(*promote_args('asin', x))
+
+@export
@partial(jit, inline=True)
def asinh(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arcsinh`"""
return arcsinh(*promote_args('asinh', x))
+
+@export
@partial(jit, inline=True)
def atan(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arctan`"""
return arctan(*promote_args('atan', x))
+
+@export
@partial(jit, inline=True)
def atanh(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arctanh`"""
return arctanh(*promote_args('atanh', x))
+
+@export
@partial(jit, inline=True)
def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.arctan2`"""
return arctan2(*promote_args('atan2', x1, x2))
+
+@export
@jit
def bitwise_count(x: ArrayLike, /) -> Array:
r"""Counts the number of 1 bits in the binary representation of the absolute value
@@ -2154,6 +2224,8 @@ def bitwise_count(x: ArrayLike, /) -> Array:
# Following numpy we take the absolute value and return uint8.
return lax.population_count(abs(x)).astype('uint8')
+
+@export
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Right shift the bits of ``x1`` to the amount specified in ``x2``.
@@ -2205,12 +2277,14 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax_fn(x1, x2)
+@export
@partial(jit, inline=True)
def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.right_shift`."""
return right_shift(x1, x2)
+@export
@partial(jit, inline=True)
def absolute(x: ArrayLike, /) -> Array:
r"""Calculate the absolute value element-wise.
@@ -2246,12 +2320,14 @@ def absolute(x: ArrayLike, /) -> Array:
return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
+@export
@partial(jit, inline=True)
def abs(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.absolute`."""
return absolute(x)
+@export
@jit
def rint(x: ArrayLike, /) -> Array:
"""Rounds the elements of x to the nearest integer
@@ -2291,6 +2367,7 @@ def rint(x: ArrayLike, /) -> Array:
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
+@export
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Copies the sign of each element in ``x2`` to the corresponding element in ``x1``.
@@ -2330,6 +2407,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
+@export
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculates the division of x1 by x2 element-wise
@@ -2368,11 +2446,13 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.div(x1, x2)
+@export
def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.true_divide`."""
return true_divide(x1, x2)
+@export
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculates the floor division of x1 by x2 element-wise
@@ -2427,6 +2507,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _float_divmod(x1, x2)[0]
+@export
@jit
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]:
"""Calculates the integer quotient and remainder of x1 by x2 element-wise
@@ -2481,6 +2562,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
return lax.round(div), mod
+@export
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculate element-wise base ``x1`` exponential of ``x2``.
@@ -2565,6 +2647,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
# Handle cases #2 and #3 under a jit:
return _power(x1, x2)
+@export
def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.power`"""
return power(x1, x2)
@@ -2604,7 +2687,7 @@ def _pow_int_int(x1, x2):
return acc
-@jit
+@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp)
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Compute ``log(exp(x1) + exp(x2))`` avoiding overflow.
@@ -2630,17 +2713,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax_other.logaddexp(x1, x2)
-def _wrap_between(x, _a):
- """Wraps `x` between `[-a, a]`."""
- a = _constant_like(x, _a)
- two_a = _constant_like(x, 2 * _a)
- zero = _constant_like(x, 0)
- rem = lax.rem(lax.add(x, a), two_a)
- rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
- return lax.sub(rem, a)
-
-
-@jit
+@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp2)
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.
@@ -2668,35 +2741,11 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
Array(True, dtype=bool)
"""
x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
- return _logaddexp2(x1, x2)
-
-
-@custom_jvp
-def _logaddexp2(x1, x2):
- amax = lax.max(x1, x2)
- if dtypes.issubdtype(x1.dtype, np.floating):
- delta = lax.sub(x1, x2)
- return lax.select(lax._isnan(delta),
- lax.add(x1, x2), # NaNs or infinities of the same sign.
- lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
- _constant_like(x1, np.log(2)))))
- else:
- delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
- out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2))))
- return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
-
-
-@_logaddexp2.defjvp
-def _logaddexp2_jvp(primals, tangents):
- x1, x2 = primals
- t1, t2 = tangents
- x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
- primal_out = logaddexp2(x1, x2)
- tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
- lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
- return primal_out, tangent_out
+ ln2 = float(np.log(2))
+ return logaddexp(x1 * ln2, x2 * ln2) / ln2
+@export
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
"""Calculates the base-2 logarithm of ``x`` element-wise.
@@ -2719,6 +2768,7 @@ def log2(x: ArrayLike, /) -> Array:
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
+@export
@partial(jit, inline=True)
def log10(x: ArrayLike, /) -> Array:
"""Calculates the base-10 logarithm of x element-wise
@@ -2742,6 +2792,7 @@ def log10(x: ArrayLike, /) -> Array:
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
+@export
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
"""Calculate element-wise base-2 exponential of input.
@@ -2776,6 +2827,7 @@ def exp2(x: ArrayLike, /) -> Array:
return lax.exp2(x)
+@export
@jit
def signbit(x: ArrayLike, /) -> Array:
"""Return the sign bit of array elements.
@@ -2848,6 +2900,7 @@ def _normalize_float(x):
return lax.bitcast_convert_type(x1, int_type), x2
+@export
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Compute x1 * 2 ** x2
@@ -2897,6 +2950,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(isinf(x1) | (x1 == 0), x1, x)
+@export
@jit
def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
"""Split floating point values into mantissa and twos exponent.
@@ -2950,6 +3004,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]:
return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
+@export
@jit
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Returns element-wise remainder of the division.
@@ -2997,11 +3052,13 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
+@export
def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.remainder`"""
return remainder(x1, x2)
+@export
@jit
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Calculate element-wise floating-point modulo operation.
@@ -3043,6 +3100,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.rem(*promote_args_numeric("fmod", x1, x2))
+@export
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
"""Calculate element-wise square of the input array.
@@ -3092,6 +3150,7 @@ def square(x: ArrayLike, /) -> Array:
return lax.square(x)
+@export
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
r"""Convert angles from degrees to radians.
@@ -3126,6 +3185,7 @@ def deg2rad(x: ArrayLike, /) -> Array:
return lax.mul(x, _lax_const(x, np.pi / 180))
+@export
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
r"""Convert angles from radians to degrees.
@@ -3161,15 +3221,19 @@ def rad2deg(x: ArrayLike, /) -> Array:
return lax.mul(x, _lax_const(x, 180 / np.pi))
+@export
def degrees(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.rad2deg`"""
return rad2deg(x)
+
+@export
def radians(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.deg2rad`"""
return deg2rad(x)
+@export
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
"""Return element-wise complex-conjugate of the input.
@@ -3199,11 +3263,13 @@ def conjugate(x: ArrayLike, /) -> Array:
return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)
+@export
def conj(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.conjugate`"""
return conjugate(x)
+@export
@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
"""Return element-wise imaginary of part of the complex argument.
@@ -3235,6 +3301,7 @@ def imag(val: ArrayLike, /) -> Array:
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
+@export
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
"""Return element-wise real part of the complex argument.
@@ -3266,6 +3333,7 @@ def real(val: ArrayLike, /) -> Array:
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
+@export
@jit
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
"""Return element-wise fractional and integral parts of the input array.
@@ -3299,6 +3367,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
return x - whole, whole
+@export
@partial(jit, inline=True)
def isfinite(x: ArrayLike, /) -> Array:
"""Return a boolean array indicating whether each element of input is finite.
@@ -3339,6 +3408,7 @@ def isfinite(x: ArrayLike, /) -> Array:
return lax.full_like(x, True, dtype=np.bool_)
+@export
@jit
def isinf(x: ArrayLike, /) -> Array:
"""Return a boolean array indicating whether each element of input is infinite.
@@ -3394,6 +3464,7 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
return lax.full_like(x, False, dtype=np.bool_)
+@export
def isposinf(x, /, out=None):
"""
Return boolean array indicating whether each element of input is positive infinite.
@@ -3427,6 +3498,7 @@ def isposinf(x, /, out=None):
return _isposneginf(np.inf, x, out)
+@export
def isneginf(x, /, out=None):
"""
Return boolean array indicating whether each element of input is negative infinite.
@@ -3460,6 +3532,7 @@ def isneginf(x, /, out=None):
return _isposneginf(-np.inf, x, out)
+@export
@partial(jit, inline=True)
def isnan(x: ArrayLike, /) -> Array:
"""Returns a boolean array indicating whether each element of input is ``NaN``.
@@ -3494,6 +3567,7 @@ def isnan(x: ArrayLike, /) -> Array:
return lax.ne(x, x)
+@export
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Compute the heaviside step function.
@@ -3543,6 +3617,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))
+@export
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""
@@ -3591,6 +3666,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return _where(idx_inf, _lax_const(x, np.inf), x)
+@export
@partial(jit, inline=True)
def reciprocal(x: ArrayLike, /) -> Array:
"""Calculate element-wise reciprocal of the input.
@@ -3624,6 +3700,7 @@ def reciprocal(x: ArrayLike, /) -> Array:
return lax.integer_pow(x, -1)
+@export
@jit
def sinc(x: ArrayLike, /) -> Array:
r"""Calculate the normalized sinc function.
diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py
index e7a0e2142327..f1e6d399b97b 100644
--- a/jax/_src/numpy/vectorize.py
+++ b/jax/_src/numpy/vectorize.py
@@ -23,9 +23,11 @@
from jax._src import config
from jax import lax
from jax._src.numpy import lax_numpy as jnp
-from jax._src.util import safe_map as map, safe_zip as zip
+from jax._src.util import set_module, safe_map as map, safe_zip as zip
+export = set_module('jax.numpy')
+
# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html
_DIMENSION_NAME = r'\w+'
_CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME)
@@ -185,6 +187,7 @@ def new_func(*args, **kwargs):
return new_func, dynamic_args, dynamic_kwargs
+@export
def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
"""Define a vectorized function with broadcasting.
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 72ed07674f1f..acbf0d4f7ed5 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -219,6 +219,10 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any):
def __repr__(self) -> str:
return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'
+ @property
+ def sharding(self):
+ return self.inner_aval.sharding
+
def update_weak_type(self, weak_type):
return AbstractMemoryRef(
self.inner_aval.update_weak_type(weak_type), self.memory_space)
@@ -873,7 +877,7 @@ def get_grid_mapping(
)
# The inputs for the index maps
index_map_avals = (
- (index_map_grid_aval,) * len(grid_spec.grid))
+ (index_map_grid_aval.update(sharding=None),) * len(grid_spec.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))
num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)
diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py
index 1bcf704b3579..b83c36159555 100644
--- a/jax/_src/pallas/cost_estimate.py
+++ b/jax/_src/pallas/cost_estimate.py
@@ -16,9 +16,12 @@
import math
from typing import Any, Sequence
+import jax
from jax._src import core as jax_core
-from jax._src.pallas import core as pallas_core
+from jax._src import custom_derivatives
from jax._src import linear_util as lu
+from jax._src import pjit
+from jax._src.pallas import core as pallas_core
from jax._src.interpreters import partial_eval as pe
from jax._src.util import safe_map
from jax._src.util import safe_zip
@@ -71,22 +74,28 @@ def cost_estimate_jaxpr(
bytes_accessed=total_cost.bytes_accessed,
)
-def cost_estimate(fun, *args) -> pallas_core.CostEstimate:
+def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate:
"""Computes a cost estimate for the given function.
Args:
fun: The function to compute the cost estimate for.
*args: The arguments to the function. Can be jax.ShapeDtypeStruct or
jax.Array.
+ **kwargs: The keyword arguments to the function.
Returns:
A pallas_core.CostEstimate object containing the cost estimate.
"""
- wrapped_fun = lu.wrap_init(lambda *args, **kwargs: (fun(*args, **kwargs),))
- avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in args]
+ flattened_args, treedef = jax.tree.flatten(args)
+ def _partial_fun(*flat_args):
+ return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs)
+ wrapped_fun = lu.wrap_init(
+ lambda *args, **kwargs: (_partial_fun(*args, **kwargs),))
+ avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))
- input_bytes = sum(math.prod(a.shape) * a.dtype.itemsize for a in args)
+ input_bytes = sum(
+ math.prod(a.shape) * a.dtype.itemsize for a in flattened_args)
output_bytes = sum(
math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars)
return pallas_core.CostEstimate(
@@ -213,3 +222,24 @@ def dot_general_cost_rule(ctx: Context,
bytes_accessed=0,
)
register_cost_rule(lax.dot_general_p, dot_general_cost_rule)
+
+# Higher-order primitives
+def _pjit_cost_rule(ctx, *, jaxpr: jax_core.ClosedJaxpr, **_):
+ del ctx
+ inner_cost = cost_estimate_jaxpr(jaxpr)
+ return CostEstimate(
+ flops=inner_cost.flops,
+ transcendentals=inner_cost.transcendentals,
+ bytes_accessed=inner_cost.bytes_accessed,
+ )
+register_cost_rule(pjit.pjit_p, _pjit_cost_rule)
+
+def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_):
+ del ctx
+ inner_cost = cost_estimate_jaxpr(fun_jaxpr)
+ return CostEstimate(
+ flops=inner_cost.flops,
+ transcendentals=inner_cost.transcendentals,
+ bytes_accessed=inner_cost.bytes_accessed,
+ )
+register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule)
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 3dbb410be29f..1f0062cad0f9 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -42,6 +42,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
+from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
@@ -1315,12 +1316,20 @@ def _masked_swap_lowering_rule(
ctx: LoweringRuleContext, *args_flat, args_tree, **_
):
ref, transforms, val, mask = args_tree.unflatten(args_flat)
- ref_aval, transforms_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in)
+ ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten(
+ ctx.avals_in
+ )
(*prev_transforms, idx) = transforms
(*_, idx_aval) = transforms_avals
if mask is not None:
- raise NotImplementedError
+ if val_aval.dtype.itemsize != 4:
+ raise NotImplementedError("masked swap with non-32-bit data")
+ if val_aval.shape != mask_aval.shape:
+ raise ValueError(
+ "Expected value and mask to have the same shape, but got"
+ f" value shape {val_aval.shape} vs. mask shape {mask_aval.shape}."
+ )
ref_block_shape, *_ = ctx.block_shapes
ref, ref_block_shape = _transform_ref(
@@ -1351,6 +1360,8 @@ def _masked_swap_lowering_rule(
need_stride = not all((s is None or s == 1) for s in strides)
if is_smem_store:
+ if mask is not None:
+ raise ValueError("SMEM store does not support masks")
if val_aval.shape:
raise ValueError("Can only store scalars to SMEM")
result = memref.load(ref, starts)
@@ -1380,7 +1391,7 @@ def _masked_swap_lowering_rule(
1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
for b in ref_block_shape
]
- mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
+ mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None)
mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
_dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True))
if need_stride:
@@ -1399,9 +1410,15 @@ def _masked_swap_lowering_rule(
result = _maybe_cast_load_to_bool(val_aval, result)
if need_stride:
+ if mask is not None:
+ raise NotImplementedError("masked swap with strided store")
tpu.StridedStoreOp(val, ref, starts, strides)
- else:
+ elif jaxlib_version <= (0, 4, 35):
+ if mask is not None:
+ raise NotImplementedError("masked swap with vector store")
vector.StoreOp(val, ref, starts)
+ else:
+ tpu.VectorStoreOp(val, ref, starts, [], mask=mask)
return result
@@ -3243,3 +3260,68 @@ def _lower_fun(shape):
lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering
+
+
+def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
+ operand, padding_value = args
+ padding_config = kwargs["padding_config"]
+
+ out_type: ir.VectorType = aval_to_ir_type(ctx.avals_in[0])
+ if not isinstance(out_type, ir.VectorType):
+ raise NotImplementedError("Only vector types are supported.")
+
+ for axis, (low, high, interior) in enumerate(padding_config):
+ if low == 0 and high == 0 and interior == 0:
+ continue
+
+ def _pad(val):
+ shape = list(operand.type.shape)
+ shape[axis] = val
+ pad_vec_type = ir.VectorType.get(
+ shape,
+ operand.type.element_type,
+ )
+
+ if isinstance(padding_value, ir.OpResult):
+ pad = vector.BroadcastOp(
+ pad_vec_type,
+ padding_value,
+ ).result
+ else:
+ scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value)
+ pad = arith.ConstantOp(
+ pad_vec_type,
+ ir.DenseElementsAttr.get_splat(
+ pad_vec_type,
+ scalar_attr,
+ ),
+ ).result
+ return pad
+
+ if low != 0:
+ pad_low = _pad(low)
+ new_shape = out_type.shape
+ new_shape[axis] += low
+ out_type = ir.VectorType.get(
+ new_shape,
+ out_type.element_type,
+ )
+ operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis)
+
+ if high != 0:
+ pad_high = _pad(high)
+ new_shape = out_type.shape
+ new_shape[axis] += high
+ out_type = ir.VectorType.get(
+ new_shape,
+ out_type.element_type,
+ )
+ operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis)
+
+ if interior > 0:
+ raise NotImplementedError("Not implemented: interior padding")
+
+ return operand
+
+
+lowering_rules[lax.pad_p] = _pad_lowering_rule
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 6d30cdb0d4a3..66437839cce2 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1473,6 +1473,44 @@ def _scan_lowering_rule(
return for_out
+@register_lowering_rule(lax.while_p)
+def _while_lowering_rule(
+ ctx: LoweringRuleContext,
+ *args,
+ cond_jaxpr,
+ body_jaxpr,
+ cond_nconsts,
+ body_nconsts,
+):
+ # First try to lower via a simpler fori loop, which may optimize better.
+ fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
+ cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
+ )
+ del cond_jaxpr, body_jaxpr
+ if fori_jaxpr is None:
+ raise NotImplementedError(err)
+
+ if fori_jaxpr.constvars:
+ raise NotImplementedError
+
+ lb_aval, ub_aval, *_ = ctx.avals_in[body_nconsts:]
+ # Reflect the changes of the pattern matcher to the context.
+ avals_in = (
+ *ctx.avals_in[cond_nconsts:body_nconsts],
+ ctx.avals_in[body_nconsts], # the index
+ *ctx.avals_in[body_nconsts + 2:],
+ )
+
+ avals_out = tuple(ctx.avals_out[2:])
+ ctx = ctx.replace(avals_in=avals_in, avals_out=avals_out)
+ _, consts, (lb, ub, *args) = util.split_list(args, [cond_nconsts, body_nconsts])
+
+ lb, ub = _ensure_ir_value(lb, lb_aval.dtype), _ensure_ir_value(ub, ub_aval.dtype)
+ length = arith_dialect.subi(ub, lb)
+
+ for_out = _lower_jaxpr_to_for_loop(ctx, fori_jaxpr, lb, length, consts, *args, has_loop_index=True)
+ return (ub, ub, *for_out)
+
@register_lowering_rule(lax.cond_p)
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
index_aval, *_arg_avals = ctx.avals_in
@@ -1501,6 +1539,31 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
return list(switch_op.results)
+@register_lowering_rule(lax.bitcast_convert_type_p)
+def _bitcast_convert_type_lowering_rule(
+ ctx: LoweringRuleContext, operand, *, new_dtype
+):
+ # TODO(petebu) Handle case where src and dst types have different bitwidths
+ [operand_aval] = ctx.avals_in
+ operand = _ensure_fa(operand, operand_aval.dtype)
+ src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype)
+ dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype)
+ assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType))
+ assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType))
+ if src_elem_type.width != dst_elem_type.width:
+ raise NotImplementedError(
+ f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they"
+ " have different widths"
+ )
+ if ir.IntegerType.isinstance(dst_elem_type):
+ output_is_signed = mgpu_utils.is_signed(new_dtype)
+ else:
+ output_is_signed = None
+ return mgpu.FragmentedArray.bitcast(
+ operand, dst_elem_type, output_is_signed=output_is_signed
+ )
+
+
def _bcast(
x: ir.Value,
y: ir.Value,
diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py
index 91e1e1c45429..9fcca6acdacc 100644
--- a/jax/_src/pallas/mosaic_gpu/pipeline.py
+++ b/jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -46,7 +46,16 @@ class BufferedRef:
spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
is_index_invariant: bool = dataclasses.field(metadata={"static": True})
gmem_ref: pallas_core.AbstractMemoryRef
- smem_ref: pallas_core.AbstractMemoryRef # [num_slots, *spec.block_shape]
+ # ``None`` if the ref is pinned to GMEM; otherwise, has shape
+ # [num_slots, *spec.block_shape].
+ smem_ref: pallas_core.AbstractMemoryRef | None
+
+ def get_ref_for_slot(
+ self, slot: int | jax.Array
+ ) -> pallas_core.AbstractMemoryRef:
+ if self.smem_ref is None:
+ return self.gmem_ref
+ return self.smem_ref.at[slot]
def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
index_map = self.spec.index_map
@@ -59,14 +68,20 @@ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
)
def copy_in(self, slot, grid_indices, barrier_ref):
+ if not _in_smem(self.spec):
+ return
+ assert self.smem_ref is not None
gmem_slices = self.compute_gmem_slice(grid_indices)
gpu_primitives.copy_gmem_to_smem(
self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands
self.smem_ref.at[slot],
- barrier=barrier_ref.at[slot],
+ barrier_ref.at[slot],
)
def copy_out(self, slot, grid_indices, predicate=None):
+ if not _in_smem(self.spec):
+ return
+ assert self.smem_ref is not None
gmem_slices = self.compute_gmem_slice(grid_indices)
gpu_primitives.copy_smem_to_gmem(
self.smem_ref.at[slot],
@@ -88,8 +103,8 @@ def _uses_arguments(
def _is_index_invariant(
spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
) -> bool:
- index_map = spec.index_map
- assert index_map is not None
+ if (index_map := spec.index_map) is None:
+ return True
return not any(_uses_arguments(index_map, len(grid)))
@@ -105,6 +120,10 @@ def _inc_grid_by_1(
return tuple(reversed(next_indices))
+def _in_smem(spec: pallas_core.BlockSpec) -> bool:
+ return spec.memory_space in (None, gpu_core.SMEM)
+
+
# ``pl.Slice`` uses a different pytree encoding, depending on whether the
# start/size are static or dynamic. This leads to pytree structure mismatch
# in the pipeline body. So, we define a different ``Slice`` class below.
@@ -125,26 +144,48 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore
def emit_pipeline(
- body,
+ body: Callable[..., None],
*,
grid: pallas_core.StaticGrid,
in_specs: Sequence[pallas_core.BlockSpec] = (),
out_specs: Sequence[pallas_core.BlockSpec] = (),
max_concurrent_steps: int = 1,
+ delay_release: int = 0,
):
- """Creates a function to emit a manual pipeline within a Pallas kernel."""
+ """Creates a function to emit a manual pipeline within a Pallas kernel.
+
+ Args:
+ body: The pipeline body.
+ grid: The grid to use for the pipeline.
+ in_specs: The block specs for the inputs.
+ out_specs: The block specs for the outputs.
+ max_concurrent_steps: The maximum number of sequential stages that are
+ active concurrently. Defaults to 1.
+ delay_release: The number of steps to wait before reusing the input/output
+ references. Defaults to 0, and must be strictly smaller than
+ ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you
+ don't await the WGMMA in the body.
+ """
num_steps = math.prod(grid)
+ if max_concurrent_steps <= delay_release:
+ raise ValueError(
+ "max_concurrent_steps must be greater than delay_release, but"
+ f" {max_concurrent_steps=}, {delay_release=}"
+ )
+
# Shrink ``max_concurrent_steps`` if the total number of steps is lower to
- # reduce the size of the allocated buffers below.
+ # reduce the size of the refs allocated in SMEM.
if max_concurrent_steps > num_steps:
max_concurrent_steps = num_steps
+ delay_release = 0 # No need to delay anything.
def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
if any(
spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore
for idx in range(1, len(grid) + 1)
+ if spec.block_shape is not None
):
raise NotImplementedError(
f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
@@ -153,14 +194,12 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
in_smem_refs, out_smem_refs = util.split_list(
- map(
- lambda spec, ref: gpu_core.SMEM(
- (max_concurrent_steps, *spec.block_shape), # type: ignore
- ref.dtype,
- ),
- it.chain(in_specs, out_specs),
- gmem_refs,
- ),
+ [
+ gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype) # type: ignore
+ if _in_smem(spec)
+ else None
+ for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs)
+ ],
[len(in_specs)],
)
return pl.run_scoped(
@@ -173,7 +212,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
out_smem_refs=out_smem_refs,
barrier_ref=gpu_core.Barrier(
# TODO(slebedev): Change this to arrive only once.
- len(in_specs),
+ sum(map(_in_smem, in_specs)),
num_barriers=max_concurrent_steps,
),
)
@@ -207,12 +246,15 @@ def loop_body(step, carry):
# Wait for the current GMEM->SMEM copy to complete.
gpu_primitives.barrier_wait(barrier_ref.at[slot])
# Wait for the previous output SMEM->GMEM copy to complete.
- gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)
+ gpu_primitives.wait_smem_to_gmem(
+ max_concurrent_steps - (1 + delay_release), wait_read_only=True
+ )
with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
- body(
- *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
- )
+ body(*(
+ bref.get_ref_for_slot(slot)
+ for bref in it.chain(in_brefs, out_brefs)
+ ))
if not all(bref.is_index_invariant for bref in out_brefs):
gpu_primitives.commit_smem()
@@ -243,10 +285,10 @@ def loop_body(step, carry):
predicate=lax.bitwise_or(slices_changed, is_last_step),
)
- fetch_step = step + max_concurrent_steps
+ fetch_step = step + (max_concurrent_steps - delay_release)
fetch_slot = slot # (x + y) % y == x % y
jax.lax.cond(
- fetch_step < num_steps,
+ lax.bitwise_and(fetch_step >= delay_release, fetch_step < num_steps),
lambda: map(
lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref),
in_brefs,
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 5fc4ed5e7afc..0f25f9808ac1 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -218,7 +218,6 @@ def _copy_gmem_to_smem_lowering(
def copy_gmem_to_smem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
- *,
barrier: pallas_core.AbstractMemoryRef,
) -> None:
"""Asynchronously copies a GMEM reference to a SMEM reference.
@@ -364,20 +363,30 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None:
@wait_smem_to_gmem_p.def_effectful_abstract_eval
-def _wait_smem_to_gmem_abstract_eval(n):
- del n # Unused.
+def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only):
+ del n, wait_read_only # Unused.
return (), {gpu_core._memory_effect}
@lowering.register_lowering_rule(wait_smem_to_gmem_p)
-def _wait_smem_to_gmem_lowering(ctx: lowering.LoweringRuleContext, n):
- ctx.launch_ctx.await_async_copy(allow_groups=n)
+def _wait_smem_to_gmem_lowering(
+ ctx: lowering.LoweringRuleContext, n, *, wait_read_only
+):
+ ctx.launch_ctx.await_async_copy(
+ allow_groups=n, await_read_only=wait_read_only
+ )
return ()
-def wait_smem_to_gmem(n: int) -> None:
- """Waits until there are no more than ``n`` SMEM->GMEM copies in flight."""
- wait_smem_to_gmem_p.bind(n)
+def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None:
+ """Waits until there are no more than ``n`` SMEM->GMEM copies in flight.
+
+ Args:
+ n: The maximum number of copies in flight to wait for.
+ wait_read_only: If ``True``, wait for the in flight copies to finish
+ reading from SMEM. The writes to GMEM are not waited for.
+ """
+ wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only)
# WGMMA on an accumulator reference
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index f7bd0dd4e4d7..729d0e617a87 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -72,10 +72,6 @@
pallas_call_p.multiple_results = True
def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
- if start_idx is None:
- assert is_indexing is None
- return value
- assert is_indexing is not None
start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx)
output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
@@ -84,10 +80,6 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
is_indexing):
- if start_idx is None:
- assert is_indexing is None
- return update
- assert is_indexing is not None
start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx)
broadcast_dims = tuple(i for i, b in enumerate(is_indexing)
if not b)
@@ -234,8 +226,7 @@ def _pallas_call_impl_interpret(
for bm in grid_mapping.block_mappings
]
block_shapes = [
- None if iid is None
- else tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
+ tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
]
@@ -284,8 +275,9 @@ def body(carry):
aval = jax_core.get_aval(s)
s.aval = aval.update(dtype=jnp.int32)
start_indices = [
- None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
- for bm in grid_mapping.block_mappings]
+ bm.compute_start_indices_interpret(loop_idx, *scalars)
+ for bm in grid_mapping.block_mappings
+ ]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes,
carry_consts_ins, is_indexing_dim)
with pallas_core.grid_env(local_grid_env):
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index f1844c7ba13b..4f16e0013f25 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -16,6 +16,7 @@
from collections import defaultdict
from collections.abc import Callable, Sequence, Iterable
+import contextlib
import dataclasses
from functools import partial
import inspect
@@ -184,16 +185,19 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
args_flat = [*init_states, *args_flat]
try:
- if (core.trace_state_clean() and
- not config.debug_key_reuse.value and
- not config.data_dependent_tracing_fallback.value):
- args_flat = map(core.full_lower, args_flat)
- core.check_eval_args(args_flat)
- out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
- else:
- out_flat = pjit_p.bind(*args_flat, **p.params)
- compiled = None
- profiler = None
+ # TODO(yashkatariya): Maybe thread this into pjit params like resource_env
+ # and set the context manager down the stack?
+ with p.abstract_mesh:
+ if (core.trace_state_clean() and
+ not config.debug_key_reuse.value and
+ not config.data_dependent_tracing_fallback.value):
+ args_flat = map(core.full_lower, args_flat)
+ core.check_eval_args(args_flat)
+ out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
+ else:
+ out_flat = pjit_p.bind(*args_flat, **p.params)
+ compiled = None
+ profiler = None
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
@@ -329,9 +333,10 @@ def cache_miss(*args, **kwargs):
if config.no_tracing.value:
raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
"`jit`, but 'no_tracing' is set")
- outs, out_flat, out_tree, args_flat, jaxpr, \
- attrs_tracked, executable, pgle_profiler = _python_pjit_helper(
- fun, jit_info, *args, **kwargs)
+
+ (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable,
+ pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
+
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
jaxpr.consts, jit_info.abstracted_axes,
@@ -494,10 +499,10 @@ def trace(*args, **kwargs) -> stages.Traced:
donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
- pgle_profiler=None)
+ pgle_profiler=None)
return stages.Traced(
p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
- lower_callable, args_flat, p.arg_names, p.num_consts)
+ lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)
wrapped = _cpp_pjit(fun, jit_info)
wrapped.lower = lower
@@ -533,6 +538,7 @@ class PjitParams(NamedTuple):
arg_names: tuple[str, ...] | None
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
+ abstract_mesh: AbstractMesh
def _infer_params_impl(
@@ -637,10 +643,15 @@ def _infer_params_impl(
in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
attr_token = _attr_token(flat_fun, in_type)
- jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
- flat_fun, in_type, attr_token, dbg,
- HashableFunction(res_paths, closure=()),
- IgnoreKey(ji.inline))
+
+ abstract_mesh = (
+ get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None
+ else mesh_lib.mesh_context.mesh)
+ with abstract_mesh:
+ jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
+ flat_fun, in_type, attr_token, dbg,
+ HashableFunction(res_paths, closure=()),
+ IgnoreKey(ji.inline))
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
@@ -680,7 +691,27 @@ def _infer_params_impl(
)
return PjitParams(consts, params, in_avals, in_tree, out_tree(),
donated_invars, dbg.arg_names if dbg else None, len(consts),
- attrs_tracked), args_flat
+ attrs_tracked, abstract_mesh), args_flat
+
+
+def get_abstract_mesh(in_avals):
+ if not config.sharding_in_types.value:
+ return contextlib.nullcontext()
+ m = None
+ for a in in_avals:
+ # TODO(yashkatariya): Remove this when mesh context can be set by the user.
+ if a.sharding is None: # type: ignore
+ continue
+ if m is not None and m != a.sharding.mesh:
+ raise ValueError(
+ f'Mesh for all inputs should be equal. Got one mesh: {m} and'
+ f' another mesh: {a.sharding.mesh}')
+ m = a.sharding.mesh # type: ignore
+ # TODO(yashkatariya): Remove this when mesh context can be set by the user.
+ if m is None:
+ return contextlib.nullcontext()
+ assert m is not None
+ return m
class InferParamsCacheEntry:
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index d2df5d8bbace..2256e12da1d4 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -891,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
return tuple(x)
-_threefry2x32_lowering_rule = mlir.lower_fun(
+# Since the unrolled lowering is large, emit it as an out-of-line function.
+_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False),
- multiple_results=True)
+ multiple_results=True))
_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py
index 9859eb64cda2..6bbcdd08471f 100644
--- a/jax/_src/public_test_util.py
+++ b/jax/_src/public_test_util.py
@@ -90,6 +90,14 @@ def default_tolerance():
np.dtype(np.complex128): 1e-5,
}
+# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
+if _dtypes.float8_e3m4 is not None:
+ _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
+ default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
+if _dtypes.float8_e4m3 is not None:
+ _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
+ default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
+
def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
_dtypes.float8_e5m2fnuz,
_dtypes.bfloat16,
]
+
+ if _dtypes.float8_e4m3 is not None:
+ custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
+ if _dtypes.float8_e3m4 is not None:
+ custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
+
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
return x.astype(np.float32)
diff --git a/jax/_src/random.py b/jax/_src/random.py
index dc9fc18aff38..4313d9036eda 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -55,8 +55,6 @@
Shape = Sequence[int]
PRNGImpl = prng.PRNGImpl
-KeyArray = Array
-KeyArrayLike = ArrayLike
UINT_DTYPES = prng.UINT_DTYPES
@@ -69,8 +67,8 @@ def _isnan(x: ArrayLike) -> Array:
return lax.ne(x, x)
-def _check_prng_key(name: str, key: KeyArrayLike, *,
- allow_batched: bool = False) -> tuple[KeyArray, bool]:
+def _check_prng_key(name: str, key: ArrayLike, *,
+ allow_batched: bool = False) -> tuple[Array, bool]:
if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
wrapped_key = key
wrapped = False
@@ -113,7 +111,7 @@ def _return_prng_keys(was_wrapped, key):
return prng.random_unwrap(key) if was_wrapped else key
-def _random_bits(key: KeyArray, bit_width: int, shape: Shape) -> Array:
+def _random_bits(key: Array, bit_width: int, shape: Shape) -> Array:
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
return prng.random_bits(key, bit_width=bit_width, shape=shape)
@@ -188,7 +186,7 @@ def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl:
def _key(ctor_name: str, seed: int | ArrayLike,
- impl_spec: PRNGSpecDesc | None) -> KeyArray:
+ impl_spec: PRNGSpecDesc | None) -> Array:
impl = resolve_prng_impl(impl_spec)
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
raise TypeError(
@@ -200,7 +198,7 @@ def _key(ctor_name: str, seed: int | ArrayLike,
return prng.random_seed(seed, impl=impl)
def key(seed: int | ArrayLike, *,
- impl: PRNGSpecDesc | None = None) -> KeyArray:
+ impl: PRNGSpecDesc | None = None) -> Array:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The result is a scalar array containing a key, whose dtype indicates
@@ -220,7 +218,7 @@ def key(seed: int | ArrayLike, *,
return _key('key', seed, impl)
def PRNGKey(seed: int | ArrayLike, *,
- impl: PRNGSpecDesc | None = None) -> KeyArray:
+ impl: PRNGSpecDesc | None = None) -> Array:
"""Create a legacy PRNG key given an integer seed.
This function produces old-style legacy PRNG keys, which are arrays
@@ -248,7 +246,7 @@ def PRNGKey(seed: int | ArrayLike, *,
return _return_prng_keys(True, _key('PRNGKey', seed, impl))
-def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
+def fold_in(key: ArrayLike, data: IntegerArray) -> Array:
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
@@ -267,7 +265,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
return _return_prng_keys(wrapped, key_out)
-def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray:
+def _split(key: Array, num: int | tuple[int, ...] = 2) -> Array:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split(); we no longer need to wait
# to always enable_custom_prng
@@ -278,7 +276,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray:
shape = tuple(num) if isinstance(num, Sequence) else (num,)
return prng.random_split(key, shape=shape)
-def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
+def split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
@@ -293,21 +291,22 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
return _return_prng_keys(wrapped, _split(typed_key, num))
-def _key_impl(keys: KeyArray) -> PRNGImpl:
+def _key_impl(keys: Array) -> str | PRNGSpec:
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
keys_dtype = typing.cast(prng.KeyTy, keys.dtype)
- return keys_dtype._impl
+ impl = keys_dtype._impl
+ return impl.name if impl.name in prng.prngs else PRNGSpec(impl)
-def key_impl(keys: KeyArrayLike) -> PRNGSpec:
+def key_impl(keys: ArrayLike) -> str | PRNGSpec:
typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True)
- return PRNGSpec(_key_impl(typed_keys))
+ return _key_impl(typed_keys)
-def _key_data(keys: KeyArray) -> Array:
+def _key_data(keys: Array) -> Array:
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
return prng.random_unwrap(keys)
-def key_data(keys: KeyArrayLike) -> Array:
+def key_data(keys: ArrayLike) -> Array:
"""Recover the bits of key data underlying a PRNG key array."""
keys, _ = _check_prng_key("key_data", keys, allow_batched=True)
return _key_data(keys)
@@ -344,7 +343,7 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None:
raise ValueError(msg.format(name, shape_, shape))
-def bits(key: KeyArrayLike,
+def bits(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeUInt | None = None) -> Array:
"""Sample uniform bits in the form of unsigned integers.
@@ -373,7 +372,7 @@ def bits(key: KeyArrayLike,
return _random_bits(key, bit_width, shape)
-def uniform(key: KeyArrayLike,
+def uniform(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float,
minval: RealArray = 0.,
@@ -443,7 +442,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
lax.reshape(floats * (maxval - minval) + minval, shape))
-def randint(key: KeyArrayLike,
+def randint(key: ArrayLike,
shape: Shape,
minval: IntegerArray,
maxval: IntegerArray,
@@ -532,7 +531,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
return lax.add(minval, lax.convert_element_type(random_offset, dtype))
-def permutation(key: KeyArrayLike,
+def permutation(key: ArrayLike,
x: int | ArrayLike,
axis: int = 0,
independent: bool = False) -> Array:
@@ -595,7 +594,7 @@ def _shuffle(key, x, axis) -> Array:
return x
-def choice(key: KeyArrayLike,
+def choice(key: ArrayLike,
a: int | ArrayLike,
shape: Shape = (),
replace: bool = True,
@@ -676,7 +675,7 @@ def choice(key: KeyArrayLike,
arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:])
-def normal(key: KeyArrayLike,
+def normal(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample standard normal random values with given shape and float dtype.
@@ -729,7 +728,7 @@ def _normal_real(key, shape, dtype) -> Array:
return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u))
-def multivariate_normal(key: KeyArrayLike,
+def multivariate_normal(key: ArrayLike,
mean: RealArray,
cov: RealArray,
shape: Shape | None = None,
@@ -812,7 +811,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
return result
-def truncated_normal(key: KeyArrayLike,
+def truncated_normal(key: ArrayLike,
lower: RealArray,
upper: RealArray,
shape: Shape | None = None,
@@ -878,7 +877,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
-def bernoulli(key: KeyArrayLike,
+def bernoulli(key: ArrayLike,
p: RealArray = np.float32(0.5),
shape: Shape | None = None) -> Array:
r"""Sample Bernoulli random values with given shape and mean.
@@ -923,7 +922,7 @@ def _bernoulli(key, p, shape) -> Array:
return uniform(key, shape, lax.dtype(p)) < p
-def beta(key: KeyArrayLike,
+def beta(key: ArrayLike,
a: RealArray,
b: RealArray,
shape: Shape | None = None,
@@ -984,7 +983,7 @@ def _beta(key, a, b, shape, dtype) -> Array:
return gamma_a_scaled / (gamma_a_scaled + gamma_b_scaled)
-def cauchy(key: KeyArrayLike,
+def cauchy(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Cauchy random values with given shape and float dtype.
@@ -1023,7 +1022,7 @@ def _cauchy(key, shape, dtype) -> Array:
return lax.tan(lax.mul(pi, lax.sub(u, _lax_const(u, 0.5))))
-def dirichlet(key: KeyArrayLike,
+def dirichlet(key: ArrayLike,
alpha: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -1095,7 +1094,7 @@ def _softmax(x, axis) -> Array:
return unnormalized / unnormalized.sum(axis, keepdims=True)
-def exponential(key: KeyArrayLike,
+def exponential(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Exponential random values with given shape and float dtype.
@@ -1134,7 +1133,7 @@ def _exponential(key, shape, dtype) -> Array:
return lax.neg(lax.log1p(lax.neg(u)))
-def _gamma_one(key: KeyArray, alpha, log_space) -> Array:
+def _gamma_one(key: Array, alpha, log_space) -> Array:
# Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang
# The algorithm can also be founded in:
# https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
@@ -1262,7 +1261,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
multiple_results=False), platform='cpu')
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
-def gamma(key: KeyArrayLike,
+def gamma(key: ArrayLike,
a: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -1309,7 +1308,7 @@ def gamma(key: KeyArrayLike,
return _gamma(key, a, shape=shape, dtype=dtype)
-def loggamma(key: KeyArrayLike,
+def loggamma(key: ArrayLike,
a: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -1451,7 +1450,7 @@ def _poisson(key, lam, shape, dtype) -> Array:
return lax.select(lam == 0, jnp.zeros_like(result), result)
-def poisson(key: KeyArrayLike,
+def poisson(key: ArrayLike,
lam: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeInt = int) -> Array:
@@ -1496,7 +1495,7 @@ def poisson(key: KeyArrayLike,
return _poisson(key, lam, shape, dtype)
-def gumbel(key: KeyArrayLike,
+def gumbel(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
"""Sample Gumbel random values with given shape and float dtype.
@@ -1532,7 +1531,7 @@ def _gumbel(key, shape, dtype) -> Array:
uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
-def categorical(key: KeyArrayLike,
+def categorical(key: ArrayLike,
logits: RealArray,
axis: int = -1,
shape: Shape | None = None) -> Array:
@@ -1574,7 +1573,7 @@ def categorical(key: KeyArrayLike,
axis=axis)
-def laplace(key: KeyArrayLike,
+def laplace(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Laplace random values with given shape and float dtype.
@@ -1611,7 +1610,7 @@ def _laplace(key, shape, dtype) -> Array:
return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
-def logistic(key: KeyArrayLike,
+def logistic(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample logistic random values with given shape and float dtype.
@@ -1647,7 +1646,7 @@ def _logistic(key, shape, dtype):
return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x)))
-def pareto(key: KeyArrayLike,
+def pareto(key: ArrayLike,
b: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -1696,7 +1695,7 @@ def _pareto(key, b, shape, dtype) -> Array:
return lax.exp(e / b)
-def t(key: KeyArrayLike,
+def t(key: ArrayLike,
df: RealArray,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
@@ -1748,7 +1747,7 @@ def _t(key, df, shape, dtype) -> Array:
return n * jnp.sqrt(half_df / g)
-def chisquare(key: KeyArrayLike,
+def chisquare(key: ArrayLike,
df: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -1800,7 +1799,7 @@ def _chisquare(key, df, shape, dtype) -> Array:
return chi2
-def f(key: KeyArrayLike,
+def f(key: ArrayLike,
dfnum: RealArray,
dfden: RealArray,
shape: Shape | None = None,
@@ -1864,7 +1863,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array:
return f
-def rademacher(key: KeyArrayLike,
+def rademacher(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeInt = int) -> Array:
r"""Sample from a Rademacher distribution.
@@ -1899,7 +1898,7 @@ def _rademacher(key, shape, dtype) -> Array:
return (2 * bernoulli_samples - 1).astype(dtype)
-def maxwell(key: KeyArrayLike,
+def maxwell(key: ArrayLike,
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample from a one sided Maxwell distribution.
@@ -1939,7 +1938,7 @@ def _maxwell(key, shape, dtype) -> Array:
return jnp.linalg.norm(norm_rvs, axis=-1)
-def double_sided_maxwell(key: KeyArrayLike,
+def double_sided_maxwell(key: ArrayLike,
loc: RealArray,
scale: RealArray,
shape: Shape = (),
@@ -1991,7 +1990,7 @@ def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array:
return random_sign * maxwell_rvs * scale + loc
-def weibull_min(key: KeyArrayLike,
+def weibull_min(key: ArrayLike,
scale: RealArray,
concentration: RealArray,
shape: Shape = (),
@@ -2037,7 +2036,7 @@ def _weibull_min(key, scale, concentration, shape, dtype) -> Array:
def orthogonal(
- key: KeyArrayLike,
+ key: ArrayLike,
n: int,
shape: Shape = (),
dtype: DTypeLikeFloat = float
@@ -2072,7 +2071,7 @@ def orthogonal(
return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2]))
def generalized_normal(
- key: KeyArrayLike,
+ key: ArrayLike,
p: float,
shape: Shape = (),
dtype: DTypeLikeFloat = float
@@ -2107,7 +2106,7 @@ def generalized_normal(
return r * g ** (1 / p)
def ball(
- key: KeyArrayLike,
+ key: ArrayLike,
d: int,
p: float = 2,
shape: Shape = (),
@@ -2139,7 +2138,7 @@ def ball(
return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None]
-def rayleigh(key: KeyArrayLike,
+def rayleigh(key: ArrayLike,
scale: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -2192,7 +2191,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array:
ray = lax.mul(scale, sqrt_u)
return ray
-def wald(key: KeyArrayLike,
+def wald(key: ArrayLike,
mean: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -2250,7 +2249,7 @@ def _wald(key, mean, shape, dtype) -> Array:
w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x)
return w
-def geometric(key: KeyArrayLike,
+def geometric(key: ArrayLike,
p: RealArray,
shape: Shape | None = None,
dtype: DTypeLikeInt = int) -> Array:
@@ -2303,7 +2302,7 @@ def _geometric(key, p, shape, dtype) -> Array:
return g.astype(dtype)
-def triangular(key: KeyArrayLike,
+def triangular(key: ArrayLike,
left: RealArray,
mode: RealArray,
right: RealArray,
@@ -2367,7 +2366,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array:
return tri
-def lognormal(key: KeyArrayLike,
+def lognormal(key: ArrayLike,
sigma: RealArray = np.float32(1),
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
@@ -2572,7 +2571,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array:
def binomial(
- key: KeyArray,
+ key: Array,
n: RealArray,
p: RealArray,
shape: Shape | None = None,
diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py
index 605cde19b1e7..2fffe6381b97 100644
--- a/jax/_src/scipy/special.py
+++ b/jax/_src/scipy/special.py
@@ -66,6 +66,7 @@ def gammaln(x: ArrayLike) -> Array:
return lax.lgamma(x)
+@jit
def gammasgn(x: ArrayLike) -> Array:
r"""Sign of the gamma function.
@@ -81,6 +82,13 @@ def gammasgn(x: ArrayLike) -> Array:
Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
Because :math:`\Gamma(x)` is never zero, no condition is required for this case.
+ * if :math:`x = -\infty`, NaN is returned.
+ * if :math:`x = \pm 0`, :math:`\pm 1` is returned.
+ * if :math:`x` is a negative integer, NaN is returned. The sign of gamma
+ at a negative integer depends on from which side the pole is approached.
+ * if :math:`x = \infty`, :math:`1` is returned.
+ * if :math:`x` is NaN, NaN is returned.
+
Args:
x: arraylike, real valued.
@@ -92,8 +100,14 @@ def gammasgn(x: ArrayLike) -> Array:
- :func:`jax.scipy.special.gammaln`: the natural log of the gamma function
"""
x, = promote_args_inexact("gammasgn", x)
+ typ = x.dtype.type
floor_x = lax.floor(x)
- return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0)
+ x_negative = x < 0
+ return jnp.select(
+ [(x_negative & (x == floor_x)) | jnp.isnan(x),
+ (x_negative & (floor_x % 2 != 0)) | ((x == 0) & jnp.signbit(x))],
+ [typ(np.nan), typ(-1.0)],
+ typ(1.0))
def gamma(x: ArrayLike) -> Array:
@@ -115,6 +129,13 @@ def gamma(x: ArrayLike) -> Array:
\Gamma(n) = (n - 1)!
+ * if :math:`z = -\infty`, NaN is returned.
+ * if :math:`x = \pm 0`, :math:`\pm \infty` is returned.
+ * if :math:`x` is a negative integer, NaN is returned. The sign of gamma
+ at a negative integer depends on from which side the pole is approached.
+ * if :math:`x = \infty`, :math:`\infty` is returned.
+ * if :math:`x` is NaN, NaN is returned.
+
Args:
x: arraylike, real valued.
@@ -127,7 +148,8 @@ def gamma(x: ArrayLike) -> Array:
- :func:`jax.scipy.special.gammasgn`: the sign of the gamma function
Notes:
- Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs.
+ Unlike the scipy version, JAX's ``gamma`` does not support complex-valued
+ inputs.
"""
x, = promote_args_inexact("gamma", x)
return gammasgn(x) * lax.exp(lax.lgamma(x))
diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py
index f410d08e4f3d..4343c080251c 100644
--- a/jax/_src/scipy/stats/gamma.py
+++ b/jax/_src/scipy/stats/gamma.py
@@ -51,12 +51,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
- :func:`jax.scipy.stats.gamma.logsf`
"""
x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
+ ok = lax.ge(x, loc)
one = _lax_const(x, 1)
- y = lax.div(lax.sub(x, loc), scale)
+ y = jnp.where(ok, lax.div(lax.sub(x, loc), scale), one)
log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
shape_terms = lax.add(gammaln(a), lax.log(scale))
log_probs = lax.sub(log_linear_term, shape_terms)
- return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
+ return jnp.where(ok, log_probs, -jnp.inf)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 9b847f15d86a..8abe58e52a74 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -137,9 +137,12 @@ def named_sharding_to_xla_hlo_sharding(
mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}
special_axes = {}
- if self._manual_axes:
+ mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items()
+ if t == mesh_lib.AxisTypes.Collective}
+ manual_axes = self._manual_axes.union(mesh_manual_axes)
+ if manual_axes:
axis_names = self.mesh.axis_names
- for manual_axis in self._manual_axes:
+ for manual_axis in manual_axes:
special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
replicated_mesh_axes = []
@@ -360,8 +363,10 @@ def is_fully_replicated(self) -> bool:
def with_memory_kind(self, kind: str) -> NamedSharding:
return NamedSharding(self.mesh, self.spec, memory_kind=kind)
- def _normalized_spec(self, ndim: int) -> PartitionSpec:
- return self.spec._normalized_spec(ndim)
+ def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
+ if not isinstance(spec, PartitionSpec):
+ spec = PartitionSpec(*spec)
+ return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 92c680009c93..b6f3b63d3de4 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -30,6 +30,7 @@
"""
from __future__ import annotations
+import contextlib
import functools
from collections.abc import Sequence
from dataclasses import dataclass
@@ -716,13 +717,14 @@ class Traced(Stage):
"_args_flat", "_arg_names", "_num_consts"]
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
- lower_callable, args_flat=None, arg_names=None,
- num_consts: int = 0):
+ lower_callable, abstract_mesh=contextlib.nullcontext(),
+ args_flat=None, arg_names=None, num_consts: int = 0):
self.jaxpr = jaxpr
self.args_info = args_info
self.fun_name = fun_name
self._out_tree = out_tree
self._lower_callable = lower_callable
+ self._abstract_mesh = abstract_mesh
self._args_flat = args_flat
self._arg_names = arg_names
self._num_consts = num_consts
@@ -743,7 +745,10 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
self._lower_callable, lowering_platforms=lowering_platforms,
lowering_parameters=_private_parameters)
try:
- lowering = new_callable()
+ # TODO(yashkatariya): Maybe thread this into pjit params like resource_env
+ # and set the context manager down the stack?
+ with self._abstract_mesh:
+ lowering = new_callable()
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
msg = pjit._device_assignment_mismatch_error(
diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
index 538f3f8e4888..2da93e3d8e80 100644
--- a/jax/_src/state/indexing.py
+++ b/jax/_src/state/indexing.py
@@ -46,11 +46,11 @@ def __post_init__(self):
@property
def is_dynamic_start(self):
- return not isinstance(self.start, int)
+ return not core.is_dim(self.start)
@property
def is_dynamic_size(self):
- return not isinstance(self.size, int)
+ return not core.is_dim(self.size)
def tree_flatten(self):
# If `start` is statically known, we treat it as static information
@@ -72,10 +72,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice:
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
- start, stop, step = slc.indices(size)
+ start, step, size = core.canonicalize_slice(slc, size)
if step < 1:
raise ValueError(f"slice must have a step >= 1 (found: {step})")
- return cls(start, max((stop - start + step - 1) // step, 0), step)
+ return cls(start, size, step)
def dslice(
diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py
index 0897e778d079..14d42ad0809c 100644
--- a/jax/_src/state/primitives.py
+++ b/jax/_src/state/primitives.py
@@ -214,7 +214,10 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
if isinstance(ref_aval.inner_aval, core.ShapedArray):
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
- out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype)
+ # TODO(yashkatariya): Transform the sharding too instead of setting it to
+ # None.
+ out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
+ sharding=None)
else:
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 78de511d4ec4..c5a713743fb8 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -44,6 +44,7 @@
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes as _dtypes
+from jax._src import lib as _jaxlib
from jax._src import linear_util as lu
from jax._src import monitoring
from jax._src import pjit as pjit_lib
@@ -451,13 +452,25 @@ def assert_num_jit_and_pmap_compilations(times):
f"but executed {count[0]}")
+def jaxlib_version() -> tuple[int, ...]:
+ return _jaxlib.version
+
+
def device_under_test():
return _TEST_DUT.value or xla_bridge.get_backend().platform
def supported_dtypes():
if device_under_test() == "tpu":
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
- np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
+ np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64,
+ _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz,
+ _dtypes.float8_e5m2}
+ elif device_under_test() == "gpu":
+ types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
+ np.uint8, np.uint16, np.uint32, np.uint64,
+ _dtypes.bfloat16, np.float16, np.float32, np.float64,
+ np.complex64, np.complex128, _dtypes.float8_e4m3fn,
+ _dtypes.float8_e5m2}
elif device_under_test() == "METAL":
types = {np.int32, np.uint32, np.float32}
else:
@@ -965,6 +978,31 @@ def fn(shape, dtype):
size=shape, replace=False)
return fn
+def rand_indices_unique_along_axis(rng):
+ """Sample an array of given shape containing indices up to dim (exclusive),
+ such that the indices are unique along the given axis.
+ Optionally, convert some of the resulting indices to negative indices."""
+ def fn(dim, shape, axis, allow_negative=True):
+ batch_size = math.prod(shape[:axis] + shape[axis:][1:])
+ idx = [
+ rng.choice(dim, size=shape[axis], replace=False)
+ for _ in range(batch_size)
+ ]
+ idx = np.array(idx).reshape(batch_size, shape[axis])
+ idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],))
+ idx = np.moveaxis(idx, -1, axis)
+
+ # assert that indices are unique along the given axis
+ count = partial(np.bincount, minlength=dim)
+ assert (np.apply_along_axis(count, axis, idx) <= 1).all()
+
+ if allow_negative:
+ mask = rng.choice([False, True], idx.shape)
+ idx[mask] -= dim
+ return idx
+
+ return fn
+
def rand_bool(rng):
def generator(shape, dtype):
return _cast_to_shape(
@@ -1439,10 +1477,19 @@ def supported(self, dtypes):
@_cached_property
def custom_floats(self):
- return [np.dtype(t) for t in [
- _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
- _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz,
- _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]]
+ float_dtypes = [
+ _dtypes.bfloat16,
+ _dtypes.float8_e4m3b11fnuz,
+ _dtypes.float8_e4m3fn,
+ _dtypes.float8_e4m3fnuz,
+ _dtypes.float8_e5m2,
+ _dtypes.float8_e5m2fnuz,
+ ]
+ if _dtypes.float8_e3m4 is not None:
+ float_dtypes += [_dtypes.float8_e3m4]
+ if _dtypes.float8_e4m3 is not None:
+ float_dtypes += [_dtypes.float8_e4m3]
+ return self.supported(float_dtypes)
@_cached_property
def floating(self):
diff --git a/jax/_src/util.py b/jax/_src/util.py
index fce342c493ed..8dcc5eaa5804 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -453,6 +453,10 @@ def tuple_update(t, idx, val):
assert 0 <= idx < len(t), (idx, len(t))
return t[:idx] + (val,) + t[idx+1:]
+def tuple_replace(tupl, index, item):
+ # unlike tuple_update, works with negative indices as well
+ return tupl[:index] + (item,) + tupl[index:][1:]
+
class HashableFunction:
"""Decouples function equality and hash from its identity.
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index fd989d052917..e1ee37f3d24d 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -623,9 +623,37 @@ def to_layout(self, new_layout: FragmentedLayout):
)
def _pointwise(self, op, *other, output_is_signed: bool | None = None):
- is_signed = (
- output_is_signed if output_is_signed is not None else self.is_signed
- )
+ if isinstance(self.layout, WGSplatFragLayout):
+      # Find either the largest operand or an operand that has a
+      # concrete layout, and base the layout computation on that.
+ widest_idx = None
+ for i, o in enumerate(other):
+ if not isinstance(o, FragmentedArray):
+ continue
+ elif not isinstance(o.layout, WGSplatFragLayout):
+ widest_idx = i
+ break
+ elif not o.layout.can_broadcast_to(self.layout.shape):
+ # Note: equal shapes can be broadcast to each other. Using
+ # the negation we make sure to only consider strictly larger
+ # shapes so that we don't end up ping ponging between equal
+ # shapes.
+ widest_idx = i
+
+ if widest_idx is not None:
+ # We need to retain the order of arguments that the op
+ # expects.
+ def _op(wide_o, self_o, *args):
+ pre_wide = args[:widest_idx - 1]
+ post_wide = args[widest_idx - 1:]
+ return op(self_o, *pre_wide, wide_o, *post_wide)
+ return other[widest_idx]._pointwise(
+ _op,
+ self,
+ *other[:widest_idx],
+ *other[widest_idx + 1:],
+ output_is_signed=output_is_signed,
+ )
other_arrs = []
for o in other:
@@ -636,7 +664,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
raise NotImplementedError(o)
o = FragmentedArray.splat(
- o, shape=self.shape, layout=self.layout, is_signed=is_signed
+ o, shape=self.shape, layout=self.layout, is_signed=self.is_signed
)
if isinstance(o.layout, WGSplatFragLayout):
@@ -646,7 +674,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
o.registers.flat[0],
shape=self.shape,
layout=self.layout,
- is_signed=is_signed,
+ is_signed=o.is_signed,
)
else:
if self.layout != o.layout:
@@ -659,8 +687,13 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None):
for idx, reg in np.ndenumerate(self.registers):
new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs))
+ reg_ty = new_regs.flat[0].type
+ if ir.VectorType.isinstance(reg_ty):
+ reg_ty = ir.VectorType(reg_ty).element_type
+ if output_is_signed is None and ir.IntegerType.isinstance(reg_ty):
+ output_is_signed = self.is_signed
return FragmentedArray(
- _registers=new_regs, _layout=self.layout, _is_signed=is_signed
+ _registers=new_regs, _layout=self.layout, _is_signed=output_is_signed
)
def __pos__(self):
@@ -928,7 +961,9 @@ def fast_instr(x):
raise NotImplementedError(x.type)
return fast_instr
- def bitcast(self, elt: ir.Type):
+ def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
+ if elt == self.mlir_dtype:
+ return self
reg_type = self.registers.flat[0].type
if ir.VectorType.isinstance(reg_type):
reg_shape = ir.VectorType(reg_type).shape
@@ -936,7 +971,9 @@ def bitcast(self, elt: ir.Type):
else:
ty = elt
- return self._pointwise(lambda x: arith.bitcast(ty, x))
+ return self._pointwise(
+ lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed
+ )
def __getitem__(self, idx):
if self.layout != WGMMA_LAYOUT:
diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py
index 337581c54b86..0594e9239be7 100644
--- a/jax/experimental/mosaic/gpu/profiler.py
+++ b/jax/experimental/mosaic/gpu/profiler.py
@@ -36,12 +36,15 @@
try:
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
except ImportError:
- pass
+ has_registrations = False
else:
- for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations():
- xla_client.register_custom_call_target(
- name, handler, platform="CUDA", api_version=1
- )
+ # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36.
+ has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations")
+ if has_registrations:
+ for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations():
+ xla_client.register_custom_call_target(
+ name, handler, platform="CUDA", api_version=1
+ )
# ruff: noqa: F405
# mypy: ignore-errors
@@ -80,6 +83,11 @@ def measure(
Returns:
The return value of ``f`` and the elapsed time in milliseconds.
"""
+ if not has_registrations:
+ raise RuntimeError(
+ "This function requires jaxlib >=0.4.36 with CUDA support."
+ )
+
if not (args or kwargs):
# We require at least one argument and at least one output to ensure
# that there is a data dependency between `_event_record` calls in
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index b716456eceb3..0ce1140cfa07 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -296,6 +296,12 @@ def globaltimer(kind: Literal["low", "high"] | None = None):
def bytewidth(ty: ir.Type):
+  # The actual width of TF32 is 19 bits. However, we need to treat it as
+  # 32 bits for compatibility reasons. TF32 used to be 32 bits wide in upstream
+ # MLIR, but it changed in
+ # https://github.com/llvm/llvm-project/commit/67a1fdb014790a38a205d28e1748634de34471dd.
+ if ir.FloatTF32Type.isinstance(ty):
+ return 4
if ir.IntegerType.isinstance(ty):
return ir.IntegerType(ty).width // 8
if ir.FloatType.isinstance(ty):
diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py
index 34cb5328f36a..7e6527ad999a 100644
--- a/jax/experimental/pallas/__init__.py
+++ b/jax/experimental/pallas/__init__.py
@@ -30,6 +30,7 @@
from jax._src.pallas.core import no_block_spec as no_block_spec
from jax._src.pallas.core import Unblocked as Unblocked
from jax._src.pallas.core import unblocked as unblocked
+from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost
from jax._src.pallas.pallas_call import pallas_call as pallas_call
from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p
from jax._src.pallas.primitives import atomic_add as atomic_add
diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md
new file mode 100644
index 000000000000..40b109d102d5
--- /dev/null
+++ b/jax/experimental/pallas/g3doc/debugging.md
@@ -0,0 +1,207 @@
+# Debugging Pallas
+
+
+
+
+
+[TOC]
+
+This document contains a collection of tips and tricks for debugging Pallas
+programs. For any specific requests or ideas for improvement, please create
+a ticket on https://github.com/jax-ml/jax/issues.
+
+## Debugging Tools
+
+### Interpret (HLO) Mode
+
+Passing in `interpret=True` into `pl.pallas_call` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPU kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas.
+
+Note that interpret mode will not be able to fully replicate the behavior of programs that use communication (DMAs) between devices. This is because low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations.
+
+### debug_print
+
+The `pl.debug_print` function can be used to print runtime values inside of a kernel. The implementation is currently limited to scalar values, but we are working on lifting this limitation.
+
+For TPUs only, the kernel must be compiled with the 'xla_tpu_enable_log_recorder' option.
+
+
+```python
+kernel = pl.pallas_call(...)
+compiled_kernel = (
+ jax.jit(kernel)
+ .lower(x)
+ .compile({'xla_tpu_enable_log_recorder': 'true'})
+ )
+result = compiled_kernel(x)
+```
+
+### Runtime Asserts
+
+Checkify can be used to insert runtime asserts, nan checks, out of bounds errors, etc. inside of a kernel.
+Pallas implements two options for assertions: a *hard assert* which will crash the TPU if failed, and a *functionalized assertion* which will simulate a runtime assertion that can be thrown
+as a Python error after the kernel has successfully executed.
+
+#### Hard assertion
+
+Hard assertions can be inserted with `checkify.check`
+and running your program with the `--jax_pallas_enable_runtime_assert` flag.
+
+Your code will look like the following:
+
+```python
+from jax.experimental import checkify
+
+def kernel(...):
+ checkify.check(x > y, "Check x > y failed") # Will halt if x <= y
+```
+
+This will print a relatively lengthy dump which resembles the following:
+
+```
+E1001 15:22:33.275768 4353 real_program_continuator.cc:1350] 0x0x0_TC0: [Physical location: dldgr4:pe1:1] generic::internal: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x169 (from TensorCoreSequencer:1:0x213): Check x > y failed HLO: main; HLO computation: main.3
+```
+
+The benefit of a hard assertion is that it is guaranteed to either pass or
+halt the TPU. The kernel will never proceed past the assertion if it fails.
+However, the downside is that if the assertion fails you will
+likely have to restart the program in order to run any other TPU operations,
+and there is no Python error thrown that can be caught.
+
+#### Functionalized assertion
+Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op like so:
+
+```python
+from jax.experimental import checkify
+
+def kernel(...):
+ checkify.check(x > y, "Check x > y failed") # Will throw an error if x <= y
+
+kernel = pl.pallas_call(...)
+checkified_kernel = checkify.checkify(kernel,
+ errors=checkify.all_checks)
+error, result = checkified_kernel(x)
+error.throw()
+```
+
+This will throw a Python error if any checks failed, such as if a NaN occurred
+or if an out-of-bounds index was accessed.
+
+The benefit of a functionalized assert is that it will throw Python errors
+that can be caught, and it will not interfere with downstream TPU operations.
+However, it requires the kernel to successfully complete, meaning if your
+error would have caused a TPU crash, the crash would still happen and
+the error would not be thrown.
+
+
+### Dumping Jaxprs
+
+Passing in `debug=True` into `pl.pallas_call` will print out the Jaxpr of the kernel as well as the lowered Mosaic code.
+
+```python
+def kernel(x_ref, y_ref, o_ref):
+ o_ref[...] = x_ref[...] + y_ref[...]
+
+x = jnp.ones((8, 128), dtype=jnp.float32)
+pl.pallas_call(
+ kernel,
+ out_shape=jax.ShapeDTypeStruct((8, 128), jnp.float32)
+ debug=True,
+ name="my_call",
+)(x, x)
+```
+
+This will output:
+
+```
+The kernel jaxpr for the pallas_call my_call for kernel function kernel at ...:1000:
+{ lambda ; a:MemRef{float32[8,128]} b:MemRef{float32[8,128]} c:MemRef{float32[8,128]}. let
+ d:f32[8,128] <- a[:,:]
+ e:f32[8,128] <- b[:,:]
+ f:f32[8,128] = add d e
+ c[:,:] <- f
+ in () }
+
+The Mosaic module for the pallas_call my_call for kernel function kernel at ...:1000:
+module {
+ func.func @main(%arg0: memref<8x128xf32, #tpu.memory_space>, %arg1: memref<8x128xf32, #tpu.memory_space>, %arg2: memref<8x128xf32, #tpu.memory_space>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} {
+ %c0 = arith.constant 0 : index
+ %c0_0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0_0] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32>
+ %c0_1 = arith.constant 0 : index
+ %c0_2 = arith.constant 0 : index
+ %1 = vector.load %arg1[%c0_1, %c0_2] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32>
+ %2 = arith.addf %0, %1 : vector<8x128xf32>
+ %c0_3 = arith.constant 0 : index
+ %c0_4 = arith.constant 0 : index
+ %3 = vector.load %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32>
+ vector.store %2, %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32>
+ return
+ }
+}
+```
+
+### Dumping Mosaic Passes
+
+Mosaic is the underlying TPU compiler for Pallas. It can be useful to dump Mosaic if you are running into errors that are originating from the Mosaic compiler to see what code is actually being generated.
+
+Passing the `--xla_mosaic_dump_to=` argument will dump the output of all intermediate Mosaic passes. The names of the files contain either the parameter `name` passed to the `pallas_call`, or the name of the kernel function. A useful option is to dump to Sponge with `--test_arg=--xla_mosaic_dump_to=sponge` after which you will see all passes under the “Artifacts” tab in sponge.
+
+### Static Verification
+
+The static verification tool can be used to automatically detect race conditions in distributed kernels.
+Because this tool uses formal verification, it is best used for small kernels (<=2 devices).
+
+Verification can be performed by running your kernel with the `--jax_pallas_dump_promela_to=` flag,
+which will output a Promela dump file. Afterwards, the dump file can be
+analyzed using the [`spin`](https://spinroot.com) tool. For example, with a dump named `dump.pml`, run:
+
+```
+spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan
+```
+
+
+
+## Useful Command line flags
+
+* OOB Checks: `--xla_mosaic_on_device_checks=bounds`
+* Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true`
+
+* Dump Mosaic: `--xla_mosaic_dump_to=`
+* Enable trace markers in XProf: `--xla_enable_transpose_trace`
+
+## Common Errors
+
+### INTERNAL Mosaic failed to compile TPU Kernel
+
+`INTERNAL Mosaic failed to compile TPU Kernel: Not implemented X`
+
+This error means that you hit an unimplemented case in the underlying Mosaic compiler.
+Our recommended course of action here is to file a ticket if one does not already
+exist for your specific error.
+
+In some cases, your error may be due to an operation which cannot be implemented
+efficiently in the compiler, in which case your best course of action is to find a workaround. This
+is most commonly seen in `layout` and `shape_cast` errors. The important tip
+to remember regarding layouts is that the last 2 dimensions of arrays in Pallas
+are physically tiled into registers, so any reshapes, slicing, transposes, etc.
+on the last 2 dimensions may trigger a relayout.
+
+
+### VerificationError
+
+A verification error indicates that Pallas produced invalid code for Mosaic.
+
+This is a bug in Pallas, so please file a bug under https://github.com/jax-ml/jax/issues.
+
+### LoweringError
+
+This is a catch-all error type during Pallas to Mosaic lowering and can have many causes.
+In most cases the error message should hint at what is wrong.
+
+For specific errors:
+
+* `Mixed dtype operands in cmp` when using `jnp.mod`: Use lax.rem instead of jnp.mod
+
+
diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 56db5379d5e2..1c5b4d9f741b 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -89,7 +89,7 @@ def _compute_wg():
plgpu.copy_gmem_to_smem(
q_ref.at[pl.ds(q_seq_base, block_q), q_head],
qo_smem,
- barrier=q_barriers.at[wg_idx],
+ q_barriers.at[wg_idx],
)
plgpu.barrier_wait(q_barriers.at[wg_idx])
@@ -166,17 +166,17 @@ def _memory_wg():
kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
for i in range(max_concurrent_steps):
s = (pl.ds(i * block_kv, block_kv), kv_head)
- plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], barrier=k_barriers.at[i])
- plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], barrier=v_barriers.at[i])
+ plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
+ plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i])
def kv_loop(kv_step, _):
tma_step = kv_step + max_concurrent_steps
tma_slot = lax.rem(kv_step, max_concurrent_steps)
s = (pl.ds(tma_step * block_kv, block_kv), kv_head)
plgpu.barrier_wait(k_consumed_barrier)
- plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], barrier=k_barriers.at[tma_slot])
+ plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot])
plgpu.barrier_wait(v_consumed_barrier)
- plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], barrier=v_barriers.at[tma_slot])
+ plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot])
lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None)
def kv_epilogue(i, _):
diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py
index 9b122fcc03ef..0cb3d798d09e 100644
--- a/jax/experimental/pallas/ops/tpu/flash_attention.py
+++ b/jax/experimental/pallas/ops/tpu/flash_attention.py
@@ -574,26 +574,23 @@ def _fwd_cost_estimate(
q: jax.Array,
k: jax.Array,
v: jax.Array,
+ ab: jax.Array | None,
+ segment_ids: SegmentIds | None,
*,
+ causal: bool,
+ sm_scale: jax.Array | None,
kernel_inputs_specs,
kernel_outputs_specs,
) -> pl.CostEstimate | None:
- b, h, tq, dqk = q.shape
- tk = k.shape[-2]
- dv = v.shape[-1]
-
- # Simplify flop computation to include only matmul operations.
- qk_flops = 2 * tq * tk * dqk
- av_flops = 2 * tq * tk * dv
- per_head_flops = qk_flops + av_flops
- flops = b * h * per_head_flops
-
- transcendentals = b * tq * tk * h
+ body_cost = pl.estimate_cost(
+ mha_reference,
+ q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale
+ )
input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs))
output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs))
return pl.CostEstimate(
- flops=flops,
- transcendentals=transcendentals,
+ flops=body_cost.flops,
+ transcendentals=body_cost.transcendentals,
bytes_accessed=input_bytes + output_bytes,
)
@@ -790,6 +787,10 @@ def kv_segment_ids_index_map(
q,
k,
v,
+ ab,
+ segment_ids,
+ causal=causal,
+ sm_scale=sm_scale,
kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids),
kernel_outputs_specs=out_shape,
),
diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py
new file mode 100644
index 000000000000..d1e6bf1fd93d
--- /dev/null
+++ b/jax/experimental/pallas/ops/tpu/random/threefry.py
@@ -0,0 +1,156 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implementation of the Threefry PRNG as a Pallas kernel."""
+from typing import Sequence
+import jax
+from jax import lax
+from jax._src import prng
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+import jax.numpy as jnp
+import numpy as np
+
+Shape = Sequence[int]
+
+BLOCK_SIZE = (256, 256)
+
+_round_up = lambda x, y: (x + y - 1) // y * y
+
+
+def blocked_iota(block_shape: Shape,
+ total_shape: Shape):
+ """Computes a sub-block of a larger shaped iota.
+
+ Args:
+ block_shape: The output block shape of the iota.
+ total_shape: The total shape of the input tensor.
+ Returns:
+ Result of the blocked iota.
+ """
+ iota_data = jnp.zeros(block_shape, dtype=jnp.uint32)
+ multiplier = 1
+ for dim in range(len(block_shape)-1, -1, -1):
+ block_mult = 1
+ counts_lo = lax.broadcasted_iota(
+ dtype=jnp.uint32, shape=block_shape, dimension=dim
+ )
+ iota_data += counts_lo * multiplier * block_mult
+ multiplier *= total_shape[dim]
+ return iota_data
+
+
+def _compute_scalar_offset(iteration_index,
+ total_size: Shape,
+ block_size: Shape):
+ ndims = len(iteration_index)
+ dim_size = 1
+ total_idx = 0
+ for i in range(ndims-1, -1, -1):
+ dim_idx = iteration_index[i] * block_size[i]
+ total_idx += dim_idx * dim_size
+ dim_size *= total_size[i]
+ return total_idx
+
+
+def threefry_2x32_count(key,
+ shape: Shape,
+ unpadded_shape: Shape,
+ block_size: tuple[int, int]):
+ """Generates random bits using the Threefry hash function.
+
+ This function is a fusion of prng.shaped_iota and prng.threefry_2x32 from
+ the JAX core library.
+
+ Args:
+ key: A threefry key of shape (2,).
+ shape: The shape of the output. Must be divisible by `block_size`.
+ unpadded_shape: If `shape` is padded, then this is the shape of the
+ output tensor if it were not padded. This is important for indexing
+ calculations within the kernel. If `shape` is not padded, then this
+ should be equal to `shape`.
+ block_size: The block size of the kernel.
+
+ Returns:
+ A tensor of random bits of shape `shape`.
+ """
+ shape = tuple(shape)
+ if np.prod(shape) > jnp.iinfo(jnp.uint32).max:
+ raise ValueError(
+ f"Shape too large: {np.prod(shape)} > {np.iinfo(jnp.uint32).max}")
+
+ if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
+ raise ValueError(
+ f"Shape dimension {shape[-2:]} must be divisible by {block_size}")
+ grid_dims = shape[:-2] + (
+ shape[-2] // block_size[-2], shape[-1] // block_size[1],)
+
+ def kernel(key_ref, out_ref):
+ counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
+ offset = _compute_scalar_offset(counts_idx, unpadded_shape, block_shape)
+ counts_lo = blocked_iota(block_size, unpadded_shape)
+ counts_lo = counts_lo + offset
+ counts_lo = counts_lo.astype(jnp.uint32)
+ # TODO(justinfu): Support hi bits on count.
+ counts_hi = jnp.zeros_like(counts_lo)
+ k1 = jnp.reshape(key_ref[0, 0], (1, 1))
+ k2 = jnp.reshape(key_ref[0, 1], (1, 1))
+ o1, o2 = prng.threefry2x32_p.bind(
+ k1, k2, counts_hi, counts_lo)
+ out_bits = o1 ^ o2
+ out_ref[...] = out_bits.reshape(out_ref.shape)
+
+ key = key.reshape((1, 2))
+ out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
+ block_shape = (1,) * (len(shape)-2) + block_size
+ result = pl.pallas_call(
+ kernel,
+ in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)],
+ out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs),
+ grid=grid_dims,
+ out_shape=out,
+ )(key)
+ return result
+
+def plthreefry_random_bits(key, bit_width: int, shape: Shape):
+ if bit_width != 32:
+ raise ValueError("Only 32-bit PRNG supported.")
+ if len(shape) == 0:
+ return plthreefry_random_bits(key, bit_width, (1, 1))[0, 0]
+ elif len(shape) == 1:
+ return plthreefry_random_bits(key, bit_width, (1, *shape))[0]
+
+ requires_pad = (
+ shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
+ if requires_pad:
+ padded_shape = tuple(shape[:-2]) + (
+ _round_up(shape[-2], BLOCK_SIZE[-2]),
+ _round_up(shape[-1], BLOCK_SIZE[-1]),
+ )
+ padded_result = threefry_2x32_count(
+ key, padded_shape, shape, block_size=BLOCK_SIZE)
+ return padded_result[..., :shape[-2], :shape[-1]]
+ else:
+ return threefry_2x32_count(key, shape, shape, block_size=BLOCK_SIZE)
+
+
+plthreefry_prng_impl = prng.PRNGImpl(
+ key_shape=(2,),
+ seed=prng.threefry_seed,
+ split=prng.threefry_split,
+ random_bits=plthreefry_random_bits,
+ fold_in=prng.threefry_fold_in,
+ name="pallas_threefry2x32",
+ tag="plfry")
+
+prng.register_prng(plthreefry_prng_impl)
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 9391d7ddf546..07f631f6ec49 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -51,6 +51,7 @@
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
special, control_flow, ann)
+from jax._src.extend import ffi
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
@@ -482,7 +483,8 @@ def _shard_map_staging(
in_tracers = map(trace.to_jaxpr_tracer, in_tracers)
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
- with core.extend_axis_env_nd(list(mesh.shape.items())):
+ with (core.extend_axis_env_nd(list(mesh.shape.items())),
+ pjit.get_abstract_mesh(in_avals_)):
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
@@ -546,6 +548,8 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
assert isinstance(aval, core.ShapedArray)
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
+ # TODO(yashkatariya): Reset the mesh properly based on the input avals if the
+ # mesh of shard_map specifies collective axes.
if config.sharding_in_types.value:
spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
@@ -1290,30 +1294,38 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params):
@register_check(control_flow.conditionals.cond_p)
def _cond_rule(mesh, *in_rep, branches):
_, *args_rep = in_rep
- true_out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep)
- false_out_rep = _check_rep(mesh, branches[1].jaxpr, args_rep)
- if not true_out_rep == false_out_rep:
- raise Exception("The true and false branches of cond produced mismatched "
- f"replication types {true_out_rep} and {false_out_rep}. "
- "Please open an issue at "
- "https://github.com/jax-ml/jax/issues, and as a temporary "
- "workaround pass the check_rep=False argument to shard_map")
- return true_out_rep
+ out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep)
+ for branch in branches[1:]:
+ out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep)
+ if not out_rep_ == out_rep:
+ raise Exception("The branches of cond produced mismatched replication "
+ "types. Please open an issue at "
+ "https://github.com/jax-ml/jax/issues, and as a "
+ "temporary workaround pass the check_rep=False argument "
+ "to shard_map")
+ return out_rep
@register_rewrite(control_flow.conditionals.cond_p)
def _cond_rewrite(mesh, in_rep, *args, branches):
pred_rep, *args_rep = in_rep
- _, true_out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep)
- _, false_out_rep = _replication_rewrite_nomatch(mesh, branches[1], args_rep)
- out_rep = map(op.and_, true_out_rep, false_out_rep)
+ _, out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep)
+ for branch in branches[1:]:
+ _, out_rep_ = _replication_rewrite_nomatch(mesh, branch, args_rep)
+ if out_rep:
+ out_rep = map(op.and_, out_rep, out_rep_)
+ else:
+ out_rep = out_rep_
out_rep = map(partial(op.and_, pred_rep), out_rep)
- branches_ = (
- _replication_rewrite_match(mesh, branches[0], args_rep, out_rep),
- _replication_rewrite_match(mesh, branches[1], args_rep, out_rep),
- )
+ branches_ = tuple(_replication_rewrite_match(mesh, branch, args_rep, out_rep)
+ for branch in branches)
out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_)
return out_vals, out_rep
+@register_check(control_flow.conditionals.platform_index_p)
+def _platform_index_rule(mesh, *_, **__):
+ return set(mesh.axis_names)
+register_norewrite(control_flow.conditionals.platform_index_p)
+
@register_rewrite(core.closed_call_p)
def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs):
new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep)
@@ -1363,20 +1375,17 @@ def fwd_jaxpr_thunk_(*zeros):
def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_):
return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep)
-
-# TODO(mattjj): make standard_check handle multiple outputs, share code
@register_check(control_flow.solves.linear_solve_p)
-def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs):
- in_rep_ = [r for r in in_rep if r is not None]
- assert in_rep
- if not in_rep_[:-1] == in_rep_[1:]:
- msg = ("shard_map check_rep rewrite failed. Please open an issue at "
- "https://github.com/jax-ml/jax/issues and as a workaround pass the "
- "check_rep=False argument to shard_map")
- raise Exception(msg)
- return [in_rep_[0]] * len(jaxprs.solve.out_avals)
+def _linear_solve_check(mesh, *in_rep, jaxprs, **_):
+ out_rep = _standard_check(control_flow.solves.linear_solve_p, mesh, *in_rep)
+ return [out_rep] * len(jaxprs.solve.out_avals)
register_standard_rewrite(control_flow.solves.linear_solve_p)
+@register_check(ffi.ffi_call_p)
+def _ffi_call_check(mesh, *in_rep, result_avals, **_):
+ out_rep = _standard_check(ffi.ffi_call_p, mesh, *in_rep)
+ return [out_rep] * len(result_avals)
+register_standard_rewrite(ffi.ffi_call_p)
del _check_rules[lax.tie_p]
@@ -1541,10 +1550,11 @@ def fun(*res_and_args):
return jaxpr
-def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
+def _unmentioned2(mesh: Mesh, names: AxisNames,
+ auto: frozenset[AxisName]) -> list[AxisName]:
# We use a filtered-down version of unmentioned to avoid defensive-psum over
# more chips than required in the transpose-no-check-rep case.
- name_set = {n for ns in names.values() for n in ns}
+ name_set = {n for ns in names.values() for n in ns} | auto
return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]
@@ -1553,7 +1563,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite or dtypes.dtype(x) == dtypes.float0
- else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
+ else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
@@ -1571,7 +1581,7 @@ def fun_trans(out_cts, args):
)
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite
- else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns)))
+ else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
for ns, x in zip(in_names, out)]
return out
diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py
index aaf3791037d0..cd3696d8838c 100644
--- a/jax/lib/xla_client.py
+++ b/jax/lib/xla_client.py
@@ -18,7 +18,6 @@
get_topology_for_devices = _xc.get_topology_for_devices
heap_profile = _xc.heap_profile
mlir_api_version = _xc.mlir_api_version
-ArrayImpl = _xc.ArrayImpl
Client = _xc.Client
CompileOptions = _xc.CompileOptions
DeviceAssignment = _xc.DeviceAssignment
@@ -95,6 +94,11 @@
"XlaComputation is deprecated; use StableHLO instead.",
_xc.XlaComputation,
),
+ # Added Nov 20 2024
+ "ArrayImpl": (
+ "jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.",
+ _xc.ArrayImpl,
+ ),
}
import typing as _typing
@@ -106,6 +110,7 @@
ops = _xc.ops
register_custom_call_target = _xc.register_custom_call_target
shape_from_pyval = _xc.shape_from_pyval
+ ArrayImpl = _xc.ArrayImpl
Device = _xc.Device
FftType = _FftType
PaddingType = _xc.PaddingType
diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py
index 20ce459685aa..52fe94e231d1 100644
--- a/jax/lib/xla_extension.py
+++ b/jax/lib/xla_extension.py
@@ -24,7 +24,6 @@
pmap_lib = _xe.pmap_lib
profiler = _xe.profiler
pytree = _xe.pytree
-ArrayImpl = _xe.ArrayImpl
Device = _xe.Device
DistributedRuntimeClient = _xe.DistributedRuntimeClient
HloModule = _xe.HloModule
@@ -33,6 +32,28 @@
PjitFunctionCache = _xe.PjitFunctionCache
PjitFunction = _xe.PjitFunction
PmapFunction = _xe.PmapFunction
-XlaRuntimeError = _xe.XlaRuntimeError
+_deprecations = {
+ # Added Nov 20 2024
+ "ArrayImpl": (
+ "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.",
+ _xe.ArrayImpl,
+ ),
+ "XlaRuntimeError": (
+ "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.",
+ _xe.XlaRuntimeError,
+ ),
+}
+
+import typing as _typing
+
+if _typing.TYPE_CHECKING:
+ ArrayImpl = _xe.ArrayImpl
+ XlaRuntimeError = _xe.XlaRuntimeError
+else:
+ from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
+
+ __getattr__ = _deprecation_getattr(__name__, _deprecations)
+ del _deprecation_getattr
+del _typing
del _xe
diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py
index 9be73e96adcf..12736c1cd9b1 100644
--- a/jax/numpy/__init__.py
+++ b/jax/numpy/__init__.py
@@ -202,6 +202,7 @@
printoptions as printoptions,
promote_types as promote_types,
put as put,
+ put_along_axis as put_along_axis,
ravel as ravel,
ravel_multi_index as ravel_multi_index,
repeat as repeat,
@@ -273,6 +274,15 @@
except ImportError:
pass
+# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
+try:
+ from jax._src.numpy.lax_numpy import (
+ float8_e3m4 as float8_e3m4,
+ float8_e4m3 as float8_e4m3,
+ )
+except ImportError:
+ pass
+
from jax._src.numpy.array_api_metadata import (
__array_api_version__ as __array_api_version__,
__array_namespace_info__ as __array_namespace_info__,
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index d391abd46e13..5d357ab1bb03 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -29,6 +29,46 @@ _Device = Device
ComplexWarning: type
+class ufunc:
+ def __init__(self, func: Callable[..., Any], /,
+ nin: int, nout: int, *,
+ name: str | None = None,
+ nargs: int | None = None,
+ identity: Any = None,
+ call: Callable[..., Any] | None = None,
+ reduce: Callable[..., Any] | None = None,
+ accumulate: Callable[..., Any] | None = None,
+ at: Callable[..., Any] | None = None,
+ reduceat: Callable[..., Any] | None = None,
+ ): ...
+ @property
+ def nin(self) -> int: ...
+ @property
+ def nout(self) -> int: ...
+ @property
+ def nargs(self) -> int: ...
+ @property
+ def identity(self) -> builtins.bool | int | float: ...
+ def __call__(self, *args: ArrayLike) -> Any: ...
+ def reduce(self, a: ArrayLike, /, *,
+ axis: int | None = 0,
+ dtype: DTypeLike | None = None,
+ out: None = None,
+ keepdims: builtins.bool = False,
+ initial: ArrayLike | None = None,
+ where: ArrayLike | None = None) -> Array: ...
+ def accumulate(self, a: ArrayLike, /, *,
+ axis: int = 0,
+ dtype: DTypeLike | None = None,
+ out: None = None) -> Array: ...
+ def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *,
+ inplace: builtins.bool = True) -> Array: ...
+ def reduceat(self, a: ArrayLike, indices: Any, *,
+ axis: int = 0,
+ dtype: DTypeLike | None = None,
+ out: None = None) -> Array: ...
+ def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ...
+
class BinaryUfunc(Protocol):
@property
def nin(self) -> int: ...
@@ -39,9 +79,10 @@ class BinaryUfunc(Protocol):
@property
def identity(self) -> builtins.bool | int | float: ...
def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ...
- def reduce(self, arr: ArrayLike, /, *,
+ def reduce(self, a: ArrayLike, /, *,
axis: int | None = 0,
dtype: DTypeLike | None = None,
+ out: None = None,
keepdims: builtins.bool = False,
initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array: ...
@@ -434,6 +475,8 @@ def fromfile(*args, **kwargs): ...
def fromfunction(function: Callable[..., Array], shape: Any,
*, dtype: DTypeLike = ..., **kwargs) -> Array: ...
def fromiter(*args, **kwargs): ...
+def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
+ *, identity: Any = None) -> ufunc: ...
def fromstring(
string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str
) -> Array: ...
@@ -586,8 +629,8 @@ def log(x: ArrayLike, /) -> Array: ...
def log10(x: ArrayLike, /) -> Array: ...
def log1p(x: ArrayLike, /) -> Array: ...
def log2(x: ArrayLike, /) -> Array: ...
-def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ...
-def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
+logaddexp: BinaryUfunc
+logaddexp2: BinaryUfunc
logical_and: BinaryUfunc
def logical_not(x: ArrayLike, /) -> Array: ...
logical_or: BinaryUfunc
@@ -742,6 +785,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ...) -> Array: ...
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
+def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike,
+ axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ...
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
diff --git a/jax/version.py b/jax/version.py
index c27caf979ddb..3e8a8291ec8d 100644
--- a/jax/version.py
+++ b/jax/version.py
@@ -60,7 +60,11 @@ def _version_from_git_tree(base_version: str) -> str | None:
except:
return None
else:
- return f"{base_version}.dev{datestring}+{commit_hash}"
+ version = f"{base_version}.dev{datestring}+{commit_hash}"
+ suffix = os.environ.get("JAX_CUSTOM_VERSION_SUFFIX", None)
+ if suffix:
+ return version + "." + suffix
+ return version
def _get_version_for_build() -> str:
diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index 8c402cfcefe8..987fe24a8008 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -243,7 +243,6 @@ pybind_extension(
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
"@xla//third_party/python_runtime:headers",
- "@xla//xla:status",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc
index 19b82a5ce149..ed815e1b1bd2 100644
--- a/jaxlib/cpu/lapack_kernels.cc
+++ b/jaxlib/cpu/lapack_kernels.cc
@@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric;
template struct EigenvalueDecompositionHermitian;
template struct EigenvalueDecompositionHermitian;
-// LAPACK uses a packed representation to represent a mixture of real
-// eigenvectors and complex conjugate pairs. This helper unpacks the
-// representation into regular complex matrices.
-template <typename T>
-static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag,
-                               const T* packed, std::complex<T>* unpacked) {
- for (int j = 0; j < n;) {
- if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
- // Real values in each row without imaginary part
- // Second row of the imaginary part is not provided
- for (int i = 0; i < n; ++i) {
- unpacked[j * n + i] = {packed[j * n + i], 0.};
- }
- ++j;
- } else {
- // Complex values where the real part is in the jth row
- // and the imaginary part is in the next row (j + 1)
- for (int i = 0; i < n; ++i) {
- const T real_part = packed[j * n + i];
- const T imag_part = packed[(j + 1) * n + i];
- unpacked[j * n + i] = {real_part, imag_part};
- unpacked[(j + 1) * n + i] = {real_part, -imag_part};
- }
- j += 2;
- }
- }
-}
-
// lapack geev
template
diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h
index 7d15e494fffc..cddcb1162120 100644
--- a/jaxlib/cpu/lapack_kernels.h
+++ b/jaxlib/cpu/lapack_kernels.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_
#define JAXLIB_CPU_LAPACK_KERNELS_H_
+#include
#include
#include
#include
@@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian {
// lapack geev
+// LAPACK uses a packed representation to represent a mixture of real
+// eigenvectors and complex conjugate pairs. This helper unpacks the
+// representation into regular complex matrices.
+template <typename T, typename Int>
+static void UnpackEigenvectors(Int n, const T* eigenvals_imag,
+                               const T* packed, std::complex<T>* unpacked) {
+ for (int j = 0; j < n;) {
+ if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
+ // Real values in each row without imaginary part
+ // Second row of the imaginary part is not provided
+ for (int i = 0; i < n; ++i) {
+ unpacked[j * n + i] = {packed[j * n + i], 0.};
+ }
+ ++j;
+ } else {
+ // Complex values where the real part is in the jth row
+ // and the imaginary part is in the next row (j + 1)
+ for (int i = 0; i < n; ++i) {
+ const T real_part = packed[j * n + i];
+ const T imag_part = packed[(j + 1) * n + i];
+ unpacked[j * n + i] = {real_part, imag_part};
+ unpacked[(j + 1) * n + i] = {real_part, -imag_part};
+ }
+ j += 2;
+ }
+ }
+}
+
template
struct RealGeev {
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a,
diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index 34e40d12d5be..afce2c000ecc 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -476,6 +476,55 @@ pybind_extension(
],
)
+cc_library(
+ name = "cuda_hybrid_kernels",
+ srcs = ["//jaxlib/gpu:hybrid_kernels.cc"],
+ hdrs = ["//jaxlib/gpu:hybrid_kernels.h"],
+ deps = [
+ ":cuda_gpu_kernel_helpers",
+ ":cuda_vendor",
+ "//jaxlib:ffi_helpers",
+ "//jaxlib/cpu:lapack_kernels",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
+ "@xla//xla/ffi/api:ffi",
+ ],
+)
+
+pybind_extension(
+ name = "_hybrid",
+ srcs = ["//jaxlib/gpu:hybrid.cc"],
+ copts = [
+ "-fexceptions",
+ "-fno-strict-aliasing",
+ ],
+ features = ["-use_header_modules"],
+ linkopts = select({
+ "@xla//xla/python:use_jax_cuda_pip_rpaths": [
+ "-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib",
+ ],
+ "//conditions:default": [],
+ }),
+ module_name = "_hybrid",
+ deps = [
+ ":cuda_gpu_kernel_helpers",
+ ":cuda_hybrid_kernels",
+ ":cuda_vendor",
+ "//jaxlib:kernel_nanobind_helpers",
+ "//jaxlib/cpu:lapack_kernels",
+ "@local_config_cuda//cuda:cuda_headers",
+ "@nanobind",
+ "@xla//xla/ffi/api:ffi",
+ "@xla//xla/tsl/cuda:cudart",
+ ],
+)
+
cc_library(
name = "cuda_gpu_kernels",
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
@@ -633,6 +682,7 @@ py_library(
name = "cuda_gpu_support",
deps = [
":_blas",
+ ":_hybrid",
":_linalg",
":_prng",
":_rnn",
diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD
index 7d50a91cfcda..e888f6a42a9b 100644
--- a/jaxlib/gpu/BUILD
+++ b/jaxlib/gpu/BUILD
@@ -37,6 +37,9 @@ exports_files(srcs = [
"gpu_kernel_helpers.cc",
"gpu_kernel_helpers.h",
"gpu_kernels.cc",
+ "hybrid.cc",
+ "hybrid_kernels.cc",
+ "hybrid_kernels.h",
"linalg.cc",
"linalg_kernels.cc",
"linalg_kernels.cu.cc",
diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc
new file mode 100644
index 000000000000..afe95a650d29
--- /dev/null
+++ b/jaxlib/gpu/hybrid.cc
@@ -0,0 +1,60 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "nanobind/nanobind.h"
+#include "jaxlib/cpu/lapack_kernels.h"
+#include "jaxlib/gpu/hybrid_kernels.h"
+#include "jaxlib/gpu/vendor.h"
+#include "jaxlib/kernel_nanobind_helpers.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+namespace {
+namespace ffi = xla::ffi;
+namespace nb = nanobind;
+
+void GetLapackKernelsFromScipy() {
+ static bool initialized = false; // Protected by GIL
+ if (initialized) return;
+ nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas");
+ nb::module_ cython_lapack =
+ nb::module_::import_("scipy.linalg.cython_lapack");
+ nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__");
+ auto lapack_ptr = [&](const char* name) {
+ return nb::cast(lapack_capi[name]).data();
+ };
+
+ AssignKernelFn>(lapack_ptr("sgeev"));
+ AssignKernelFn>(lapack_ptr("dgeev"));
+ AssignKernelFn>(lapack_ptr("cgeev"));
+ AssignKernelFn>(
+ lapack_ptr("zgeev"));
+}
+
+NB_MODULE(_hybrid, m) {
+ m.def("initialize", GetLapackKernelsFromScipy);
+ m.def("has_magma", []() { return MagmaLookup().FindMagmaInit().ok(); });
+ m.def("registrations", []() {
+ nb::dict dict;
+ dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal);
+ dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp);
+ return dict;
+ });
+}
+
+} // namespace
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc
new file mode 100644
index 000000000000..1ce2e547b11f
--- /dev/null
+++ b/jaxlib/gpu/hybrid_kernels.cc
@@ -0,0 +1,631 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/gpu/hybrid_kernels.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/algorithm/container.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "jaxlib/cpu/lapack_kernels.h"
+#include "jaxlib/ffi_helpers.h"
+#include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/vendor.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+namespace ffi = ::xla::ffi;
+
+// This helper class is used to define a host buffer that can be copied to and
+// from a device buffer.
+template <typename T>
+class HostBuffer {
+ public:
+ HostBuffer(std::size_t size) : size_(size) {
+    data_ = std::unique_ptr<T[]>(new T[size]);
+ }
+
+ absl::Status CopyFromDevice(gpuStream_t stream, const T* buffer) {
+ return JAX_AS_STATUS(gpuMemcpyAsync(data_.get(), buffer, size_ * sizeof(T),
+ gpuMemcpyDeviceToHost, stream));
+ }
+
+ absl::Status CopyToDevice(gpuStream_t stream, T* buffer) {
+ return JAX_AS_STATUS(gpuMemcpyAsync(buffer, data_.get(), size_ * sizeof(T),
+ gpuMemcpyHostToDevice, stream));
+ }
+
+ T* get() const { return data_.get(); }
+
+ private:
+  std::unique_ptr<T[]> data_;
+ size_t size_;
+};
+
+// Forwarded from MAGMA for use as an input parameter.
+typedef enum {
+ MagmaNoVec = 301,
+ MagmaVec = 302,
+} magma_vec_t;
+
+// Compile time lookup of MAGMA function names depending on the data type.
+template
+struct always_false : std::false_type {};
+
+template
+struct MagmaGeev {
+ static_assert(always_false::value, "unsupported data type");
+};
+template <>
+struct MagmaGeev {
+ static constexpr char name[] = "magma_sgeev";
+};
+template <>
+struct MagmaGeev {
+ static constexpr char name[] = "magma_dgeev";
+};
+template <>
+struct MagmaGeev {
+ static constexpr char name[] = "magma_cgeev";
+};
+template <>
+struct MagmaGeev {
+ static constexpr char name[] = "magma_zgeev";
+};
+
+MagmaLookup::~MagmaLookup() {
+ if (initialized_) {
+ void* magma_finalize = dlsym(handle_, "magma_finalize");
+ if (magma_finalize != nullptr) {
+ reinterpret_cast(magma_finalize)();
+ }
+ }
+ if (handle_ != nullptr) {
+ dlclose(handle_);
+ }
+}
+
+absl::StatusOr MagmaLookup::FindMagmaInit() {
+ void* magma_init = nullptr;
+ std::vector paths;
+ const char* magma_lib_path = std::getenv("JAX_GPU_MAGMA_PATH");
+ if (magma_lib_path != nullptr) {
+ paths.push_back(magma_lib_path);
+ } else {
+ paths.push_back("libmagma.so.2");
+ paths.push_back("libmagma.so");
+ paths.push_back(nullptr);
+ }
+ for (const auto& path : paths) {
+ handle_ = dlopen(path, RTLD_LAZY);
+ if (handle_ != nullptr) {
+ magma_init = dlsym(handle_, "magma_init");
+ if (magma_init != nullptr) {
+ if (path != nullptr) {
+ lib_path_ = std::string(path);
+ }
+ break;
+ }
+ }
+ }
+ if (handle_ == nullptr || magma_init == nullptr) {
+ return absl::InternalError(
+ "Unable to dlopen a MAGMA shared library that defines a magma_init "
+ "symbol. Use the JAX_GPU_MAGMA_PATH environment variable to "
+ "specify an explicit path to the library.");
+ }
+ return magma_init;
+}
+
+absl::Status MagmaLookup::Initialize() {
+ if (failed_) {
+ return absl::InternalError("MAGMA initialization was unsuccessful.");
+ }
+ if (!initialized_) {
+ auto maybe_magma_init = FindMagmaInit();
+ if (!maybe_magma_init.ok()) {
+ failed_ = true;
+ return maybe_magma_init.status();
+ }
+ reinterpret_cast(maybe_magma_init.value())();
+ initialized_ = true;
+ }
+ return absl::OkStatus();
+}
+
+absl::StatusOr MagmaLookup::Find(const char name[]) {
+ if (!initialized_) {
+ return absl::InternalError("MAGMA support has not been initialized.");
+ }
+
+ auto it = symbols_.find(name);
+ if (it != symbols_.end()) return it->second;
+
+ void* symbol = dlsym(handle_, name);
+ if (symbol == nullptr) {
+ if (lib_path_.has_value()) {
+ return absl::InternalError(absl::StrFormat(
+ "Unable to load the symbol '%s' from the MAGMA library at '%s'.",
+ name, lib_path_.value()));
+
+ } else {
+ return absl::InternalError(absl::StrFormat(
+ "Unable to load a globally defined symbol called '%s'. Use the "
+ "JAX_GPU_MAGMA_PATH environment variable to specify an explicit "
+ "path to the library.",
+ name));
+ }
+ }
+
+ symbols_.insert({name, symbol});
+ return symbol;
+}
+
+// Lookup the MAGMA symbol for the given function name. This function only
+// dlopens the MAGMA library once per process.
+absl::StatusOr FindMagmaSymbol(const char name[]) {
+ static absl::Mutex mu;
+ static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu);
+ absl::MutexLock lock(&mu);
+ auto status = lookup.Initialize();
+ if (!status.ok()) {
+ return status;
+ }
+ return lookup.Find(name);
+}
+
+// Real-valued eigendecomposition
+
+template
+class EigRealHost {
+ using Real = ffi::NativeType;
+
+ public:
+ explicit EigRealHost() = default;
+ EigRealHost(EigRealHost&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? 'V' : 'N';
+ jobvr_ = right ? 'V' : 'N';
+ int64_t lwork = EigenvalueDecomposition::GetWorkspaceSize(
+ n, static_cast(jobvl_),
+ static_cast(jobvr_));
+ return MaybeCastNoOverflow(lwork);
+ }
+
+ void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work,
+ int lwork, int* info) {
+ EigenvalueDecomposition::fn(&jobvl_, &jobvr_, &n_, x, &n_, wr, wi,
+ vl, &n_, vr, &n_, work, &lwork, info);
+ }
+
+ private:
+ int n_;
+ char jobvl_, jobvr_;
+};
+
+template
+class EigRealMagma {
+ using Real = ffi::NativeType;
+ using Fn = int(magma_vec_t, magma_vec_t, int, Real*, int, Real*, Real*, Real*,
+ int, Real*, int, Real*, int, int*);
+
+ public:
+ explicit EigRealMagma() = default;
+ EigRealMagma(EigRealMagma&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? MagmaVec : MagmaNoVec;
+ jobvr_ = right ? MagmaVec : MagmaNoVec;
+
+ auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name);
+ if (!maybe_ptr.ok()) return maybe_ptr.status();
+ fn_ = reinterpret_cast(*maybe_ptr);
+
+ int query_info;
+ Real query_host;
+ fn_(jobvl_, jobvr_, n, nullptr, n, nullptr, nullptr, nullptr, n, nullptr, n,
+ &query_host, -1, &query_info);
+ return static_cast(query_host);
+ }
+
+ void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work,
+ int lwork, int* info) {
+ fn_(jobvl_, jobvr_, n_, x, n_, wr, wi, vl, n_, vr, n_, work, lwork, info);
+ }
+
+ private:
+ int n_;
+ magma_vec_t jobvl_, jobvr_;
+ Fn* fn_ = nullptr;
+};
+
+template
+ffi::Error EigReal(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream,
+ bool left, bool right, ffi::AnyBuffer x,
+ ffi::Result wr,
+ ffi::Result wi,
+ ffi::Result vl,
+ ffi::Result vr,
+ ffi::Result> info) {
+ using Real = ffi::NativeType;
+ using Complex = ffi::NativeType;
+
+ auto x_host = HostBuffer(x.element_count());
+ FFI_RETURN_IF_ERROR_STATUS(
+ x_host.CopyFromDevice(stream, x.typed_data()));
+
+ auto wr_host = HostBuffer(batch * cols);
+ auto wi_host = HostBuffer(batch * cols);
+ auto vl_host = HostBuffer(batch * cols * cols);
+ auto vr_host = HostBuffer(batch * cols * cols);
+ auto info_host = HostBuffer(batch);
+
+ FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols));
+ FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right));
+ auto work_host = AllocateScratchMemory(lwork);
+ auto work_left = AllocateScratchMemory(cols * cols);
+ auto work_right = AllocateScratchMemory(cols * cols);
+
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ const auto is_finite = [](auto* data, int64_t size) {
+ return absl::c_all_of(absl::MakeSpan(data, size),
+ [](auto value) { return std::isfinite(value); });
+ };
+
+ for (int64_t i = 0; i < batch; ++i) {
+ if (is_finite(x_host.get() + i * cols * cols, cols * cols)) {
+ impl.compute(x_host.get() + i * cols * cols, wr_host.get() + i * cols,
+ wi_host.get() + i * cols, work_left.get(), work_right.get(),
+ work_host.get(), lwork, info_host.get() + i);
+ if (info_host.get()[i] == 0) {
+ if (left) {
+ UnpackEigenvectors(n, wi_host.get() + i * cols, work_left.get(),
+ vl_host.get() + i * cols * cols);
+ }
+ if (right) {
+ UnpackEigenvectors(n, wi_host.get() + i * cols, work_right.get(),
+ vr_host.get() + i * cols * cols);
+ }
+ }
+ } else {
+ info_host.get()[i] = -4;
+ }
+ }
+
+ FFI_RETURN_IF_ERROR_STATUS(
+ wr_host.CopyToDevice(stream, wr->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(
+ wi_host.CopyToDevice(stream, wi->typed_data()));
+ if (left) {
+ FFI_RETURN_IF_ERROR_STATUS(
+ vl_host.CopyToDevice(stream, vl->typed_data()));
+ }
+ if (right) {
+ FFI_RETURN_IF_ERROR_STATUS(
+ vr_host.CopyToDevice(stream, vr->typed_data()));
+ }
+ FFI_RETURN_IF_ERROR_STATUS(
+ info_host.CopyToDevice(stream, info->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ return ffi::Error::Success();
+}
+
+ffi::Error EigRealDispatch(gpuStream_t stream, std::string_view magma,
+ bool left, bool right, ffi::AnyBuffer x,
+ ffi::Result wr,
+ ffi::Result wi,
+ ffi::Result vl,
+ ffi::Result vr,
+ ffi::Result> info) {
+ auto dataType = x.element_type();
+ if (dataType != wr->element_type() || dataType != wi->element_type() ||
+ ffi::ToComplex(dataType) != vl->element_type() ||
+ ffi::ToComplex(dataType) != vr->element_type()) {
+ return ffi::Error::InvalidArgument(
+ "The inputs and outputs to eig must have the same element type");
+ }
+
+ FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+ SplitBatch2D(x.dimensions()));
+ if (rows != cols) {
+ return ffi::Error::InvalidArgument(
+ "The input matrix to eig must be square");
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(wr->dimensions(), {batch, cols}, "wr", "eig"));
+ FFI_RETURN_IF_ERROR(CheckShape(wi->dimensions(), {batch, cols}, "wi", "eig"));
+ if (left) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig"));
+ }
+ if (right) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig"));
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig"));
+
+ bool use_magma = magma == "on";
+ if (magma == "auto" && cols >= 2048) {
+ use_magma = FindMagmaSymbol("magma_init").ok();
+ }
+
+ switch (dataType) {
+ case ffi::F32:
+ if (use_magma) {
+ return EigReal(EigRealMagma(), batch, cols, stream,
+ left, right, x, wr, wi, vl, vr, info);
+ } else {
+ return EigReal(EigRealHost(), batch, cols, stream,
+ left, right, x, wr, wi, vl, vr, info);
+ }
+ case ffi::F64:
+ if (use_magma) {
+ return EigReal(EigRealMagma(), batch, cols, stream,
+ left, right, x, wr, wi, vl, vr, info);
+ } else {
+ return EigReal(EigRealHost(), batch, cols, stream,
+ left, right, x, wr, wi, vl, vr, info);
+ }
+ default:
+ return ffi::Error::InvalidArgument(absl::StrFormat(
+ "Unsupported dtype %s in eig_real", absl::FormatStreamed(dataType)));
+ }
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigReal, EigRealDispatch,
+ ffi::Ffi::Bind()
+ .Ctx>()
+ .Attr("magma")
+ .Attr("left")
+ .Attr("right")
+ .Arg() // x
+ .Ret() // wr
+ .Ret() // wi
+ .Ret() // vl
+ .Ret() // vr
+ .Ret>() // info
+);
+
+// Complex-valued eigendecomposition
+
+template
+class EigCompHost {
+ using Real = ffi::NativeType;
+ using Complex = ffi::NativeType;
+
+ public:
+ explicit EigCompHost() = default;
+ EigCompHost(EigCompHost&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? 'V' : 'N';
+ jobvr_ = right ? 'V' : 'N';
+ int64_t lwork = EigenvalueDecompositionComplex::GetWorkspaceSize(
+ n, static_cast(jobvl_),
+ static_cast(jobvr_));
+ return MaybeCastNoOverflow(lwork);
+ }
+
+ void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work,
+ int lwork, Real* rwork, int* info) {
+ EigenvalueDecompositionComplex::fn(&jobvl_, &jobvr_, &n_, x, &n_,
+ w, vl, &n_, vr, &n_, work,
+ &lwork, rwork, info);
+ }
+
+ private:
+ int n_;
+ char jobvl_, jobvr_;
+};
+
+template
+class EigCompMagma {
+ using Real = ffi::NativeType;
+ using Complex = ffi::NativeType;
+ using Fn = int(magma_vec_t, magma_vec_t, int, Complex*, int, Complex*,
+ Complex*, int, Complex*, int, Complex*, int, Real*, int*);
+
+ public:
+ explicit EigCompMagma() = default;
+ EigCompMagma(EigCompMagma&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? MagmaVec : MagmaNoVec;
+ jobvr_ = right ? MagmaVec : MagmaNoVec;
+ lda_ = std::max(n_, 1);
+ ldvl_ = left ? n_ : 1;
+ ldvr_ = right ? n_ : 1;
+
+ auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name);
+ if (!maybe_ptr.ok()) return maybe_ptr.status();
+ fn_ = reinterpret_cast(*maybe_ptr);
+
+ int query_info;
+ Complex query_host;
+ fn_(jobvl_, jobvr_, n_, nullptr, lda_, nullptr, nullptr, ldvl_, nullptr,
+ ldvr_, &query_host, -1, nullptr, &query_info);
+ return static_cast(query_host.real());
+ }
+
+ void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work,
+ int lwork, Real* rwork, int* info) {
+ fn_(jobvl_, jobvr_, n_, x, lda_, w, vl, ldvl_, vr, ldvr_, work, lwork,
+ rwork, info);
+ }
+
+ private:
+ int n_, lda_, ldvl_, ldvr_;
+ magma_vec_t jobvl_, jobvr_;
+ Fn* fn_ = nullptr;
+};
+
+template
+ffi::Error EigComp(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream,
+ bool left, bool right, ffi::AnyBuffer x,
+ ffi::Result w,
+ ffi::Result vl,
+ ffi::Result vr,
+ ffi::Result> info) {
+ using Complex = ffi::NativeType;
+
+ auto x_host = HostBuffer(x.element_count());
+ FFI_RETURN_IF_ERROR_STATUS(
+ x_host.CopyFromDevice(stream, x.typed_data()));
+
+ auto w_host = HostBuffer(batch * cols);
+ auto vl_host = HostBuffer(batch * cols * cols);
+ auto vr_host = HostBuffer(batch * cols * cols);
+ auto info_host = HostBuffer(batch);
+
+ FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols));
+ FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right));
+ auto work_host = AllocateScratchMemory(lwork);
+ auto rwork_host =
+ AllocateScratchMemory(2 * cols * cols);
+
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ const auto is_finite = [](auto* data, int64_t size) {
+ return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) {
+ return std::isfinite(z.real()) && std::isfinite(z.imag());
+ });
+ };
+
+ for (int64_t i = 0; i < batch; ++i) {
+ if (is_finite(x_host.get() + i * cols * cols, cols * cols)) {
+ impl.compute(x_host.get() + i * cols * cols, w_host.get() + i * cols,
+ vl_host.get() + i * cols * cols,
+ vr_host.get() + i * cols * cols, work_host.get(), lwork,
+ rwork_host.get(), info_host.get() + i);
+ } else {
+ info_host.get()[i] = -4;
+ }
+ }
+
+ FFI_RETURN_IF_ERROR_STATUS(
+ w_host.CopyToDevice(stream, w->typed_data()));
+ if (left) {
+ FFI_RETURN_IF_ERROR_STATUS(
+ vl_host.CopyToDevice(stream, vl->typed_data()));
+ }
+ if (right) {
+ FFI_RETURN_IF_ERROR_STATUS(
+ vr_host.CopyToDevice(stream, vr->typed_data()));
+ }
+ FFI_RETURN_IF_ERROR_STATUS(
+ info_host.CopyToDevice(stream, info->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ return ffi::Error::Success();
+}
+
+ffi::Error EigCompDispatch(gpuStream_t stream, std::string_view magma,
+ bool left, bool right, ffi::AnyBuffer x,
+ ffi::Result w,
+ ffi::Result vl,
+ ffi::Result vr,
+ ffi::Result> info) {
+ auto dataType = x.element_type();
+ if (dataType != w->element_type() || dataType != vl->element_type() ||
+ dataType != vr->element_type()) {
+ return ffi::Error::InvalidArgument(
+ "The inputs and outputs to eig must have the same element type");
+ }
+
+ FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+ SplitBatch2D(x.dimensions()));
+ if (rows != cols) {
+ return ffi::Error::InvalidArgument(
+ "The input matrix to eig must be square");
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "eig"));
+ if (left) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig"));
+ }
+ if (right) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig"));
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig"));
+
+ bool use_magma = magma == "on";
+ if (magma == "auto" && cols >= 2048) {
+ use_magma = FindMagmaSymbol("magma_init").ok();
+ }
+
+ switch (dataType) {
+ case ffi::C64:
+ if (use_magma) {
+ return EigComp(EigCompMagma(), batch, cols, stream,
+ left, right, x, w, vl, vr, info);
+ } else {
+ return EigComp(EigCompHost(), batch, cols, stream,
+ left, right, x, w, vl, vr, info);
+ }
+ case ffi::C128:
+ if (use_magma) {
+ return EigComp(EigCompMagma(), batch, cols,
+ stream, left, right, x, w, vl, vr, info);
+ } else {
+ return EigComp(EigCompHost(), batch, cols, stream,
+ left, right, x, w, vl, vr, info);
+ }
+ default:
+ return ffi::Error::InvalidArgument(absl::StrFormat(
+ "Unsupported dtype %s in eig_comp", absl::FormatStreamed(dataType)));
+ }
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigComp, EigCompDispatch,
+ ffi::Ffi::Bind()
+ .Ctx>()
+ .Attr("magma")
+ .Attr("left")
+ .Attr("right")
+ .Arg() // x
+ .Ret() // w
+ .Ret() // vl
+ .Ret() // vr
+ .Ret>() // info
+);
+
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h
new file mode 100644
index 000000000000..2890837a2bd5
--- /dev/null
+++ b/jaxlib/gpu/hybrid_kernels.h
@@ -0,0 +1,55 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_GPU_HYBRID_KERNELS_H_
+#define JAXLIB_GPU_HYBRID_KERNELS_H_
+
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "jaxlib/gpu/vendor.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+// The MagmaLookup class is used for dlopening the MAGMA shared library,
+// initializing it, and looking up MAGMA symbols.
+class MagmaLookup {
+ public:
+ explicit MagmaLookup() = default;
+ ~MagmaLookup();
+ absl::StatusOr FindMagmaInit();
+ absl::Status Initialize();
+ absl::StatusOr Find(const char name[]);
+
+ private:
+ bool initialized_ = false;
+ bool failed_ = false;
+ void* handle_ = nullptr;
+ std::optional lib_path_ = std::nullopt;
+ absl::flat_hash_map symbols_;
+};
+
+XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp);
+
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
+
+#endif // JAXLIB_GPU_HYBRID_KERNELS_H_
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 03fd43e9ef89..59819f1fc914 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -56,6 +56,21 @@
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
+for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
+ try:
+ _cuhybrid = importlib.import_module(
+ f"{cuda_module_name}._hybrid", package="jaxlib"
+ )
+ except ImportError:
+ _cuhybrid = None
+ else:
+ break
+
+if _cuhybrid:
+ for _name, _value in _cuhybrid.registrations().items():
+ xla_client.register_custom_call_target(_name, _value, platform="CUDA",
+ api_version=1)
+
try:
from .rocm import _blas as _hipblas # pytype: disable=import-error
except ImportError:
@@ -88,6 +103,34 @@
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
+for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
+ try:
+ _hiphybrid = importlib.import_module(
+ f"{rocm_module_name}._hybrid", package="jaxlib"
+ )
+ except ImportError:
+ _hiphybrid = None
+ else:
+ break
+
+if _hiphybrid:
+ for _name, _value in _hiphybrid.registrations().items():
+ xla_client.register_custom_call_target(_name, _value, platform="ROCM",
+ api_version=1)
+
+def initialize_hybrid_kernels():
+ if _cuhybrid:
+ _cuhybrid.initialize()
+ if _hiphybrid:
+ _hiphybrid.initialize()
+
+def has_magma():
+ if _cuhybrid:
+ return _cuhybrid.has_magma()
+ if _hiphybrid:
+ return _hiphybrid.has_magma()
+ return False
+
def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
return np.finfo(dtype).dtype
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index b5bfe733b992..2bae7ab2a203 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -66,6 +66,7 @@ _py_deps = {
"filelock": ["@pypi_filelock//:pkg"],
"flatbuffers": ["@pypi_flatbuffers//:pkg"],
"hypothesis": ["@pypi_hypothesis//:pkg"],
+ "magma": [],
"matplotlib": ["@pypi_matplotlib//:pkg"],
"mpmath": [],
"opt_einsum": ["@pypi_opt_einsum//:pkg"],
diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD
index 14f3ee13c0f5..da7498ed437d 100644
--- a/jaxlib/mosaic/BUILD
+++ b/jaxlib/mosaic/BUILD
@@ -62,6 +62,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h
index 66217858fa7d..6edad713b17a 100644
--- a/jaxlib/mosaic/dialect/tpu/layout.h
+++ b/jaxlib/mosaic/dialect/tpu/layout.h
@@ -24,7 +24,6 @@ limitations under the License.
#include
#include
-#include "absl/log/check.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
@@ -39,6 +38,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "absl/log/check.h"
namespace mlir::tpu {
@@ -259,18 +259,23 @@ class VectorLayout {
int layout_rank() const { return layout_rank(implicit_dim_); }
bool operator==(const VectorLayout &other) const;
- bool operator!=(const VectorLayout &other) const {
- return !(*this == other);
- }
-
- // How many tiles fit in each vector register.
- int64_t tilesPerVreg(const std::array target_shape) const {
- const int64_t tile_elems = tiling_[0] * tiling_[1];
- const int64_t vreg_capacity = packing() * target_shape[0] * target_shape[1];
+ bool operator!=(const VectorLayout &other) const { return !(*this == other); }
+
+ static int64_t tilesPerVreg(const std::array target_shape,
+ const int8_t bitwidth,
+ const std::array tiling) {
+ CHECK_NE(0, bitwidth) << "bitwidth cannot be 0";
+ const int64_t tile_elems = tiling[0] * tiling[1];
+ const int64_t vreg_capacity =
+ (32 / bitwidth) * target_shape[0] * target_shape[1];
const auto [tiles_per_vreg, rem] = std::div(vreg_capacity, tile_elems);
CHECK_EQ(rem, 0);
return tiles_per_vreg;
}
+ // How many tiles fit in each vector register.
+ int64_t tilesPerVreg(const std::array target_shape) const {
+ return VectorLayout::tilesPerVreg(target_shape, bitwidth_, tiling_);
+ }
int64_t sublanesPerTile(const std::array target_shape) const {
auto [sublanes_per_tile, rem] =
@@ -283,8 +288,16 @@ class VectorLayout {
//
// We never reuse the same vector register to store data of multiple rows,
// so only the minormost dimension can increase.
+ static std::array vregSlice(std::array target_shape,
+ const int8_t bitwidth,
+ const std::array tiling) {
+ return {
+ tiling[0],
+ VectorLayout::tilesPerVreg(target_shape, bitwidth, tiling) * tiling[1]};
+ }
+
std::array vregSlice(std::array target_shape) const {
- return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
+ return VectorLayout::vregSlice(target_shape, bitwidth_, tiling_);
}
template
diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index b312bca7a7d3..8a4f573bce24 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -214,6 +214,22 @@ def TPU_LoadOp : TPU_Op<"load"> {
}];
}
+// TODO(jevinjiang): migrate tpu.strided_store to general vector store op.
+def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> {
+ let arguments = (ins
+ AnyVector:$valueToStore,
+ AnyMemRef:$base,
+ Variadic:$indices,
+ DenseI32ArrayAttr:$strides,
+ Optional:$mask // Elementwise mask.
+ );
+ let results = (outs);
+ let assemblyFormat = [{
+ $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
+ }];
+ let hasVerifier = 1;
+}
+
def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
let arguments = (ins
AnyMemRef:$base,
@@ -637,12 +653,18 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> {
MemRefOf<[TPU_SemaphoreType]>:$semaphore,
I32:$amount,
Optional:$device_id, // For remote DMAs
- Optional:$core_id // For megacore
+ Optional:$core_id, // For megacore
+ OptionalAttr:$core_type
);
- let assemblyFormat = [{
- $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? attr-dict `:` type($semaphore)
+let assemblyFormat = [{
+ $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore)
}];
let hasVerifier = 1;
+ let builders = [
+ // A backward-compatible builder that sets `core_type` to nullptr.
+ OpBuilder<(ins "Value":$semaphore, "Value":$amount,
+ "Value":$device_id, "Value":$core_id)>,
+ ];
}
def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
@@ -738,6 +760,7 @@ def TPU_LogOp : TPU_Op<"log"> {
);
let results = (outs);
let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }];
+ let hasVerifier = 1;
}
def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> {
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
index 10ab154b7c10..92e8953837e3 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"
#include "absl/hash/hash.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc"
#include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc"
#include "xla/layout.h"
@@ -81,6 +82,15 @@ void TPUDialect::initialize() {
return mlir::cast(attr).getValue();
}
+FailureOr> GetCoreTypeOfParentFunc(Operation &op) {
+ mlir::Operation *func_op = op.getParentOfType();
+ if (func_op == nullptr) {
+ return op.emitError() << "Operation " << op.getName()
+ << " is not inside a func.func";
+ }
+ return TPUDialect::GetCoreTypeAttr(func_op);
+}
+
void VectorLayoutAttr::print(AsmPrinter &printer) const {
printer << '<';
printer << getLayout();
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index dbb2ddaa5853..a8569acc6239 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -19,6 +19,7 @@ limitations under the License.
#include
#include
#include
+#include
#include
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -94,6 +95,10 @@ std::unique_ptr> createDebugAssertInsertionPass();
#define GEN_PASS_DECL_MOSAICSERDEPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
+// Determine the core type of the given op based on the `tpu.core_type`
+// annotation of its parent function.
+FailureOr> GetCoreTypeOfParentFunc(Operation &op);
+
// Changes the memory space of the value and propagates it through the program.
LogicalResult specializeMemorySpace(TypedValue value,
MemorySpace memory_space);
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index 6f690f6a0fcb..3271c0874572 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -28,9 +28,12 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
+#include "absl/strings/str_format.h"
+#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/IRMapping.h"
+#include "mlir/include/mlir/IR/OperationSupport.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
@@ -440,6 +443,31 @@ LogicalResult StridedStoreOp::verify() {
getValueToStore().getType());
}
+LogicalResult VectorStoreOp::verify() {
+ if (!getStrides().empty()) {
+ return emitError("Not implemented: general vector store with strides.");
+ }
+ VectorType value_ty = getValueToStore().getType();
+ MemRefType ref_ty = getBase().getType();
+
+ if (value_ty.getElementType() != ref_ty.getElementType()) {
+ return emitOpError(
+ "Expected base and valueToStore element type should match");
+ }
+ if (llvm::size(getIndices()) != ref_ty.getRank()) {
+ return emitOpError("Expected ") << ref_ty.getRank() << " indices";
+ }
+ if (getMask()) {
+ if (value_ty.getElementTypeBitWidth() != 32) {
+ return emitError(
+ "Not implemented: masked store with non-32-bit element type");
+ }
+ if (value_ty.getShape() != getMask().getType().getShape())
+ return emitOpError("Expected valueToStore shape to match mask shape");
+ }
+ return success();
+}
+
LogicalResult ReinterpretCastOp::verify() {
auto source_type = getMemRefType(getInput());
auto target_type = getType();
@@ -468,7 +496,7 @@ LogicalResult verifyRotateOp(Op op) {
}
if (op.getStride().has_value() != op.getStrideDimension().has_value()) {
op.emitOpError(
- "Expected either none or both stride and stride dimension are "
+ "Expected either none or both stride and stride dimension are "
"present");
return failure();
}
@@ -812,11 +840,42 @@ LogicalResult GetBarrierSemaphoreOp::verify() {
return success();
}
+void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state,
+ Value semaphore, Value amount, Value device_id,
+ Value core_id) {
+ build(builder, state, semaphore, amount, device_id, core_id,
+ /*core_type=*/nullptr);
+}
+
LogicalResult SemaphoreSignalOp::verify() {
auto sem_type = getMemRefType(getSemaphore());
if (sem_type.getRank() != 0) {
return emitOpError("Semaphore reference must be rank 0");
}
+
+ FailureOr> issuing_core_type_maybe =
+ GetCoreTypeOfParentFunc(**this);
+ if (failed(issuing_core_type_maybe)) {
+ return issuing_core_type_maybe;
+ }
+ CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc);
+ CoreType target_core_type = getCoreType().value_or(issuing_core_type);
+
+ if (getCoreId() == nullptr && getDeviceId() == nullptr) {
+ if (target_core_type != issuing_core_type) {
+ return emitOpError(
+ absl::StrFormat("Target core type (%s) must match source core type "
+ "(%s) when device_id and core_id are not specified",
+ stringifyCoreType(target_core_type),
+ stringifyCoreType(issuing_core_type)));
+ }
+ }
+ if ((issuing_core_type == CoreType::kTc &&
+ target_core_type == CoreType::kScScalarSubcore) ||
+ (issuing_core_type == CoreType::kScScalarSubcore &&
+ target_core_type == CoreType::kTc)) {
+ return emitOpError("Signalling between TC and SC is not implemented");
+ }
return success();
}
@@ -976,6 +1035,30 @@ LogicalResult ConcatenateOp::verify() {
return success();
}
+LogicalResult LogOp::verify() {
+ FailureOr> logging_core_type_maybe =
+ GetCoreTypeOfParentFunc(**this);
+ if (failed(logging_core_type_maybe)) {
+ return failure();
+ }
+ CoreType logging_core_type = logging_core_type_maybe->value_or(CoreType::kTc);
+ if ((logging_core_type == CoreType::kScScalarSubcore ||
+ logging_core_type == CoreType::kScVectorSubcore) &&
+ getFormattedAttr() != nullptr && getFormattedAttr().getValue()) {
+ return emitOpError("Formatted logging is not supported on SC");
+ }
+ switch (logging_core_type) {
+ case CoreType::kTc:
+ case CoreType::kScScalarSubcore:
+ return success();
+ case CoreType::kScVectorSubcore:
+ return emitOpError("Log op is not supported on the SC vector subcore");
+ }
+ return emitOpError(
+ absl::StrFormat("Unexpected core type: %s",
+ stringifyCoreType(logging_core_type_maybe->value())));
+}
+
} // namespace tpu
} // namespace mlir
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 8792503f4636..8ade7450881a 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -2554,7 +2554,10 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(res_layout.has_value());
auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank();
- if (dimension >= num_untiled_dims) {
+ if (res_ty.getRank() == 1 &&
+ res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) {
+ tiling_dim = 1;
+ } else if (dimension >= num_untiled_dims) {
tiling_dim = dimension - num_untiled_dims;
}
@@ -2576,6 +2579,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
return op.emitOpError("Not implemented: result/input offsets mismatch.");
}
+ if (layout.implicit_dim() != res_layout->implicit_dim()) {
+ return op.emitOpError(
+ "Not implemented: result/input implicit dim mismatch.");
+ }
+
if (i > 1) {
auto curr_offsets = layout.offsets();
auto last_operand_offsets = layouts_in[i - 1]->offsets();
@@ -2611,29 +2619,47 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
if (!tiling_dim.has_value()) {
out_vregs = concatenate(operand_vregs, dimension);
} else {
- if (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+ bool is_rank1_with_no_implicit_dim = res_ty.getRank() == 1 &&
+ res_layout->implicit_dim() ==
+ VectorLayout::ImplicitDim::kNone;
+ if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kMinor ||
+ is_rank1_with_no_implicit_dim) {
return op.emitOpError("Not implemented: implicit dim");
}
+ if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
+ res_layout->bitwidth() != 32) {
+ return op.emitOpError(
+ "Not implemented: only 32-bit bitwidth supported for SecondMinor "
+ "implicit dim");
+ }
if (res_layout->offsets()[tiling_dim.value()] != 0) {
return op.emitOpError("Not implemented: result non-zero offset.");
}
- if (!res_layout->hasNativeTiling(ctx.target_shape)) {
+ if (!res_layout->hasNativeTiling(ctx.target_shape) &&
+ res_ty.getRank() != 1) {
return op.emitOpError("Not implemented: Non native tiling in concat.");
}
int64_t offset_at_dim = 0;
{
for (int i = 0; i < op.getNumOperands(); ++i) {
- auto operand = op.getOperand(i);
- auto const &layout = *layouts_in[i];
-
- auto vty = cast(operand.getType());
- auto shape = vty.getShape();
-
- auto starting_point = offset_at_dim;
- auto offset_amount =
- starting_point % layout.tiling()[tiling_dim.value()];
- if (offset_amount != layout.offsets()[tiling_dim.value()]) {
+ Value operand = op.getOperand(i);
+ const Layout &layout = *layouts_in[i];
+ xla::Array vreg_array = operand_vregs[i];
+ std::array vreg_slice = layout->vregSlice(ctx.target_shape);
+ std::array tiling = layout->tiling();
+
+ VectorType vty = cast(operand.getType());
+ ArrayRef shape = vty.getShape();
+
+ int64_t starting_point = offset_at_dim;
+ int64_t offset_amount =
+ starting_point % vreg_slice[tiling_dim.value()];
+ if (offset_amount >= tiling[tiling_dim.value()]) {
+ return op.emitError(
+ "Not implemented: Input offsets outside of the first tile");
+ }
+ if (offset_amount != layout->offsets()[tiling_dim.value()]) {
return op.emitOpError(
"Not implemented: Relayout not called, unaligned dims "
"concatenated without proper offsets. Ensure that "
@@ -2648,9 +2674,12 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
for (size_t i = 0; i < operand_vregs.size(); ++i) {
auto &vreg = operand_vregs[i];
const auto &layout = layouts_in[i];
+ const int packing = res_layout->packing();
- if (layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
- return op.emitOpError("Not implemented: implicit dim");
+ if (layout->tiling()[0] % packing != 0) {
+ return op.emitOpError(
+ "Illegal tiling: Non-native tiling in concat - this should "
+ "have been caught earlier!");
}
const int64_t operand_offset = *layout->offsets()[tiling_dim.value()];
@@ -2663,7 +2692,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
}
const auto bitwidth = res_ty.getElementTypeBitWidth();
- const int packing = res_layout->packing();
SmallVector out_idx;
vreg.Each([&](absl::Span idx, Value *v) {
out_idx.assign(idx.begin(), idx.end());
@@ -2694,7 +2722,7 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
mask = builder.create(
op.getLoc(), vmask_ty,
ArrayRef{boundIdxConst(0), boundIdxConst(0)},
- ArrayRef{boundIdxConst(layout->tiling()[0]),
+ ArrayRef{boundIdxConst(layout->tiling()[0] / packing),
boundIdxConst(operand_offset)});
}
// Blend the current value with the existing value in the output.
@@ -4172,18 +4200,15 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
shape_cast_op->erase();
return success();
}
-LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
- const ArrayRef layouts_in,
- const ArrayRef layouts_out) {
- TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
- MLIRContext *const mlir_ctx = op.getContext();
- TPU_ASSERT_OP(layouts_in.front().has_value());
- TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(),
- [&](const Layout &l) { return l.has_value(); }));
+
+template
+LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
+ const VectorLayout &to_store_layout,
+ TypedValue store_mask = nullptr) {
+ Operation &op = *(store_op.getOperation());
+ MLIRContext *const mlir_ctx = store_op.getContext();
ImplicitLocOpBuilder builder(op.getLoc(), &op);
- vector::StoreOp store_op = cast(op);
const VectorType ty = store_op.getValueToStore().getType();
- const VectorLayout &to_store_layout = *layouts_in.front();
const auto memref_ty = getMemRefType(store_op.getBase());
if (!ty.getRank()) {
return op.emitOpError("Not implemented: scalar stores to vmem");
@@ -4280,10 +4305,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
} else {
// Convert dynamic store to dynamic slice + static store. This saves us a
// bunch of scalar core work.
- auto slice_result =
- sliceRef(builder, store_op.getBase(),
- store_op.getVectorType().getShape(), store_op.getIndices(),
- ArrayRef(memref_tiling).take_back(tiled_dims));
+ auto slice_result = sliceRef(
+ builder, store_op.getBase(), ty.getShape(), store_op.getIndices(),
+ ArrayRef