CI: 11/22/24 upstream sync #148

Merged · 155 commits · Nov 22, 2024

Commits
6842848
Add a link to Intel plugin for JAX
mini-goel Oct 28, 2024
78da9fa
Add float8_e4m3 and float8_e3m4 types support
superbobry Oct 7, 2024
e6f6a8a
Move Control Flow text from Sharp Bits into its own tutorial.
emilyfertig Nov 6, 2024
d823f17
jnp.logaddexp2: simplify implementation
jakevdp Nov 12, 2024
d0f3666
Update array-api-tests commit
jakevdp Nov 14, 2024
303b792
Merge pull request #24864 from jakevdp:logaddexp2
Google-ML-Automation Nov 14, 2024
1f114b1
Add numpy.put_along_axis.
carlosgmartin Nov 14, 2024
04d339d
Merge pull request #24904 from jakevdp:array-api
Google-ML-Automation Nov 14, 2024
4a3e115
cleanup: delete unused argument from internal reduction helper
jakevdp Nov 14, 2024
c40d405
Update XLA dependency to use revision
Google-ML-Automation Nov 14, 2024
41a0493
Add shard map replication rule for ffi_call.
dfm Nov 13, 2024
8e29212
Merge pull request #24567 from Intel-tensorflow:minigoel/intel-plugin
Google-ML-Automation Nov 14, 2024
5764afb
Merge pull request #24905 from jakevdp:old-arg
Google-ML-Automation Nov 14, 2024
4fe9164
Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
Google-ML-Automation Nov 15, 2024
a115b2c
Update array-api-tests commit
jakevdp Nov 15, 2024
c6051b3
Merge pull request #24881 from dfm:ffi-call-rep-rule
Google-ML-Automation Nov 15, 2024
1c31860
Merge pull request #24907 from jakevdp:array-api
Google-ML-Automation Nov 15, 2024
4511f0c
Merge pull request #24862 from emilyfertig:emilyaf-control-flow-tutorial
Google-ML-Automation Nov 15, 2024
9a0e9e5
[sharding_in_types] Handle collective axes in lowering rules more gen…
yashk2810 Nov 15, 2024
f652b6a
Set __module__ attribute for objects in jax.numpy
jakevdp Nov 15, 2024
1471702
[Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to t…
Google-ML-Automation Nov 15, 2024
23e9142
Lower threefry as an out-of-line MLIR function on TPU.
hawkinsp Nov 15, 2024
d808500
Merge pull request #24913 from hawkinsp:threefry
Google-ML-Automation Nov 15, 2024
5f94284
Add missing functions to jax.numpy type interface
jakevdp Nov 15, 2024
5f1e3f5
Add an example on logical operators to the tutorial.
emilyfertig Nov 15, 2024
1780ff2
Update XLA dependency to use revision
Google-ML-Automation Nov 15, 2024
81cdc88
DOC: update main landing page style
barnesjoseph Nov 15, 2024
605c605
Merge pull request #24918 from emilyfertig:emilyaf-logical-op-example
Google-ML-Automation Nov 15, 2024
1aa5de6
Merge pull request #24914 from jakevdp:fix-pyi
Google-ML-Automation Nov 15, 2024
225a2a5
Consolidate material on PRNGs and add a short summary to Key Concepts.
emilyfertig Nov 15, 2024
efd2327
Merge pull request #24917 from emilyfertig:emilyaf-sharp-bits
Google-ML-Automation Nov 15, 2024
8525ef2
[sharding_in_types] Don't emit a wsc under full manual mode to avoid …
yashk2810 Nov 16, 2024
609dfac
Adds a flag to control proxy env checking.
yliu120 Nov 12, 2024
626aea0
Deduplicate constants in StableHLO lowering.
hawkinsp Nov 16, 2024
1d519f4
Return a ndarray in shape_as_value if the shape is known to be constant.
hawkinsp Nov 16, 2024
7b9914d
Update XLA dependency to use revision
Google-ML-Automation Nov 16, 2024
8a6c560
Use a direct StableHLO lowering for pow.
hawkinsp Nov 16, 2024
27bf80a
Adds an env that can let users provide a custom version suffix for ja…
yliu120 Nov 12, 2024
742cabc
Update XLA dependency to use revision
Google-ML-Automation Nov 17, 2024
ed250b8
[AutoPGLE] Temporary disable pgle_test in the OSS.
Google-ML-Automation Nov 18, 2024
f7ae0f9
Merge pull request #24930 from hawkinsp:dedup
Google-ML-Automation Nov 18, 2024
afdc792
Merge pull request #24933 from hawkinsp:pow
Google-ML-Automation Nov 18, 2024
ccb3317
Add a GPU implementation of `lax.linalg.eig`.
dfm Nov 18, 2024
65f9c78
Merge pull request #24932 from hawkinsp:gather
Google-ML-Automation Nov 18, 2024
1418739
Add new CI script for running Bazel GPU presubmits
nitins17 Nov 18, 2024
05d66d7
Merge pull request #24912 from jakevdp:jnp-module
Google-ML-Automation Nov 18, 2024
2de40e7
Merge pull request #24916 from jakevdp:update-lp
Google-ML-Automation Nov 18, 2024
e9864c6
Make logaddexp and logaddexp2 into ufuncs
jakevdp Nov 15, 2024
297a4e5
Merge pull request #24903 from jakevdp:logsumexp
Google-ML-Automation Nov 18, 2024
6fe7b17
Return SingleDeviceSharding instead of GSPMDShardings when there is o…
yashk2810 Nov 18, 2024
5bebd0f
fix typo in numpy/__init__.pyi
jakevdp Nov 18, 2024
70b05f6
Merge pull request #24952 from jakevdp:fix-pyi
Google-ML-Automation Nov 18, 2024
0ed6eae
[SDY] fix JAX layouts tests for Shardy.
Varcho Nov 18, 2024
461a250
Disable some complex function accuracy tests that fail on Mac ARM.
hawkinsp Nov 18, 2024
f325051
Filter custom dtypes by supported_dtypes in `_LazyDtypes`.
hawkinsp Nov 18, 2024
a60ef6e
[Pallas] Increase test coverage of pl.dot.
WindQAQ Nov 18, 2024
16ed283
Merge pull request #24957 from hawkinsp:arm
Google-ML-Automation Nov 18, 2024
b3ca6c4
Update XLA dependency to use revision
Google-ML-Automation Nov 18, 2024
91891cb
Merge pull request #23585 from apivovarov:float8_e4m3
Google-ML-Automation Nov 18, 2024
d4316b5
Adds font fallbacks
barnesjoseph Nov 18, 2024
6952ddf
Merge pull request #24958 from barnesjoseph:add-font-fallback
Google-ML-Automation Nov 18, 2024
e904c17
Delete _normalized_spec from NamedSharding
yashk2810 Nov 18, 2024
2c68569
Fix a bug where mesh checking was not correct
yashk2810 Nov 19, 2024
45c9c0a
[pallas] Minor simplifications to Pallas interpreter.
chr1sj0nes Nov 19, 2024
c5e8ae8
Update jax.scipy.special.gamma and gammasgn to return NaN for negativ…
hawkinsp Nov 18, 2024
4a9346e
Merge pull request #24945 from hawkinsp:gamma
Google-ML-Automation Nov 19, 2024
58103e5
Merge pull request #24861 from yliu120:add_versions
Google-ML-Automation Nov 19, 2024
0fe77bc
[Mosaic TPU] Support relayout for mask vector
bythew3i Nov 19, 2024
12a43f1
Merge pull request #24853 from yliu120:check_proxy_envs
Google-ML-Automation Nov 19, 2024
d397dd9
Implement lax.pad in Pallas.
Google-ML-Automation Nov 19, 2024
da50ad7
[AutoPGLE] Use compile options to override debug options instead of X…
Google-ML-Automation Nov 19, 2024
1458d3d
Fix some typos
nireekshak Nov 19, 2024
d912034
fix(docs): typos in macro name
jeertmans Nov 19, 2024
9d3eda1
Merge pull request #24942 from jeertmans:patch-1
Google-ML-Automation Nov 19, 2024
6929a97
Merge pull request #24968 from nireekshak:testingbranch
Google-ML-Automation Nov 19, 2024
3556a83
Add missing version guard in GPU tests for jnp.poly.
dfm Nov 19, 2024
6c31efa
[Mosaic TPU] Add general tpu.vector_store and support masked store.
bythew3i Nov 19, 2024
c44f11d
Add alternate implementation of threefry as a pallas kernel.
justinjfu Nov 19, 2024
a59bbb7
Add test utility for accessing jaxlib version tuple.
dfm Nov 19, 2024
2c80d1a
Add a new API jax.lax.split.
hawkinsp Nov 19, 2024
2075b09
Merge pull request #24970 from hawkinsp:split
Google-ML-Automation Nov 19, 2024
1bf70fb
[pallas:mosaic_gpu] `copy_gmem_to_smem` no longer requires `barrier` …
superbobry Nov 19, 2024
0d36b0b
[Mosaic] Add target core type parameter to tpu.sem_signal
naummo Nov 19, 2024
3161a28
Update XLA dependency to use revision
Google-ML-Automation Nov 19, 2024
42fbd30
Move JAX example to public XLA:CPU API
changm Nov 19, 2024
525b646
Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
hawkinsp Nov 19, 2024
c04aec9
[Mosaic] Extend tpu.sem_signal with subcore_id
naummo Nov 19, 2024
8c71d1a
Make deprecated jax.experimental.array_api module visibility internal…
Nov 20, 2024
867a361
Fix a bug where constant deduplication used an inappropriate inequality.
hawkinsp Nov 20, 2024
6c291d6
[Mosaic] Add `tpu.log` verification on SC
naummo Nov 20, 2024
4bb8107
represent `random.key_impl` of builtin RNGs by canonical string name
froystig Oct 29, 2024
4d60db1
Add test_compute_on_host_shared_sharding in memories_test
Google-ML-Automation Nov 20, 2024
ae46b75
Merge pull request #24593 from froystig:random-dtypes
Google-ML-Automation Nov 20, 2024
1afb05e
[mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise.
petebu Nov 20, 2024
14da7eb
[pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcas…
petebu Nov 20, 2024
c76e5fe
[pallas:mosaic_gpu] `copy_smem_to_gmem` now supports `wait_read_only`
superbobry Nov 20, 2024
f442d40
[mosaic_gpu] Fixed `FragmentedArray` comparisons with literals
superbobry Nov 20, 2024
04e4c69
[mosaic_gpu] Handle older `jaxlib`s in the profiler module
superbobry Nov 20, 2024
1df4b5f
[pallas] Do not skip vmap tests on GPU when x64 is enabled
superbobry Nov 20, 2024
a582df0
Update XLA dependency to use revision
Google-ML-Automation Nov 20, 2024
a4266b5
Mention python 3.13 in docs & package metadata
jakevdp Nov 20, 2024
1e9e85a
Simplify handling of `DotAlgorithmPreset` output types.
chr1sj0nes Nov 20, 2024
85e2969
Deprecate several private APIs in jax.lib
jakevdp Nov 20, 2024
800add2
Merge pull request #25007 from jakevdp:deps
Google-ML-Automation Nov 20, 2024
439d34d
Merge pull request #25005 from jakevdp:py313
Google-ML-Automation Nov 20, 2024
6222592
Fix KeyError recently introduced in cloud_tpu_init.py
skye Nov 20, 2024
8d84f28
[pallas mgpu] Lowering for while loops as long as they are secretly f…
cperivol Nov 20, 2024
d0f17c0
Make a direct linearize trace.
dougalm Nov 4, 2024
eab9026
Merge pull request #25004 from jax-ml:linearize-trace
Google-ML-Automation Nov 20, 2024
fee272e
Remove internal KeyArray alias
jakevdp Nov 20, 2024
2c9b917
Don't psum over auto mesh dims in _unmentioned2.
pschuh Nov 20, 2024
9584ee3
[pallas:mosaic_gpu] Avoid using multiple indexers in the parallel gri…
superbobry Nov 20, 2024
621e39d
Set __module__ attribute of jax.numpy.linalg APIs
jakevdp Nov 20, 2024
1a3e693
Merge pull request #25008 from skye:barrier
Google-ML-Automation Nov 20, 2024
dfe27a1
Mention stackless in the release notes.
hawkinsp Nov 20, 2024
19b4996
Merge pull request #25013 from hawkinsp:relnotes
Google-ML-Automation Nov 20, 2024
40fc659
[sharding_in_types] Make flash_attention forward pass in TPU pallas w…
yashk2810 Nov 20, 2024
9d2f62f
[Pallas TPU] Support masked store
bythew3i Nov 20, 2024
d219439
Merge pull request #25011 from jakevdp:jnp-linalg-module
Google-ML-Automation Nov 20, 2024
9b94180
[sharding_in_types] Add slice_p and squeeze_p sharding rule to make f…
yashk2810 Nov 20, 2024
6fe7804
Update XLA dependency to use revision
Google-ML-Automation Nov 20, 2024
f749fca
[array api] use most recent version of array_api_tests
jakevdp Nov 20, 2024
2699e95
DOC: add examples for jax.lax.pad
jakevdp Nov 20, 2024
334bd4d
Merge pull request #25019 from jakevdp:lax-pad-doc
Google-ML-Automation Nov 21, 2024
1782588
jax.lax.pad: improve input validation
jakevdp Nov 21, 2024
bf7f9aa
Adds Google Sans font
barnesjoseph Nov 21, 2024
f39392e
Merge pull request #25020 from jakevdp:lax-pad-validation
Google-ML-Automation Nov 21, 2024
1f6152d
[Pallas] Use Pallas cost estimator for flash attention.
justinjfu Nov 21, 2024
840cf3f
[sharding_in_types] Add `pad_p` support to sharding_in_types to handl…
yashk2810 Nov 21, 2024
869a533
[Mosaic TPU] Add bound check for general vector store op.
bythew3i Nov 21, 2024
6568713
[sharding_in_types] Add `concatenate_p` support
yashk2810 Nov 21, 2024
e72b449
Reverts c04aec9d525dd2e767495e41b98e82dd79315f37
naummo Nov 21, 2024
f18df8f
[pallas:mosaic_gpu] Pulled `delay_release` into `emit_pipeline`
superbobry Nov 21, 2024
1bc9df4
Integrate LLVM at llvm/llvm-project@33fcd6acc755
metaflow Nov 21, 2024
0831e2e
[shape_poly] Adding shape polymorphism support for the state primitives.
gnecula Nov 21, 2024
7d7a0fa
Run the TPU workflow on new self-hosted runners
nitins17 Nov 21, 2024
bf0150b
[JAX] Ignore xla_gpu_experimental_autotune_cache_mode when calculatin…
Google-ML-Automation Nov 21, 2024
7335267
Merge pull request #25015 from barnesjoseph:add-google-sans
Google-ML-Automation Nov 21, 2024
1e6654a
Fix cron schedule to run past minute 0 every 2nd hour
nitins17 Nov 21, 2024
b1b1ad6
Merge pull request #25018 from jakevdp:update-array-api
Google-ML-Automation Nov 21, 2024
1d2dc17
[mgpu] Pointwise op can handle LHS splats.
cperivol Nov 21, 2024
2178ed2
[pallas] Add more test cases for Triton bitcast_convert_type lowering…
petebu Nov 21, 2024
e707ede
Merge pull request #25034 from gnecula:poly_state
Google-ML-Automation Nov 21, 2024
96c0129
Fix false positive `debug_nans` error caused by NaNs that are properl…
dfm Nov 21, 2024
1efef6b
[pallas:mosaic_gpu] `emit_pipeline` now correctly supports `BlockSpec…
superbobry Nov 21, 2024
c1ae13b
Merge pull request #25009 from jakevdp:keyarray
Google-ML-Automation Nov 21, 2024
f3e7e68
Remove unneeded dependency from rocm_plugin_extension.
klucke Nov 21, 2024
f899d51
[Mosaic TPU] Fold sublane offset to indices when storing to untiled ref.
bythew3i Nov 21, 2024
26443bb
Update XLA dependency to use revision
Google-ML-Automation Nov 21, 2024
344d0d9
[Pallas] Add readme page for debugging tips.
justinjfu Nov 21, 2024
170718c
Change signature of linearization rules.
dougalm Nov 22, 2024
3d79df2
Merge pull request #25048 from jax-ml:linearization-rule-signature
Google-ML-Automation Nov 22, 2024
355589f
[sharding_in_types] Add scan support to sharding_in_types. There are …
yashk2810 Nov 22, 2024
a450bb0
Merge branch 'rocm-main' into ci-upstream-sync-34_1
charleshofer Nov 22, 2024
846697f
Longer timeout for doc render
charleshofer Nov 22, 2024
39 changes: 39 additions & 0 deletions .github/workflows/bazel_gpu_rbe.yml
@@ -0,0 +1,39 @@
name: CI - Bazel GPU tests (RBE)

on:
workflow_dispatch:
inputs:
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: choice
required: true
default: 'no'
options:
- 'yes'
- 'no'

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
run_tests:
if: github.event.repository.fork == false
strategy:
matrix:
runner: ["linux-x86-n2-16"]

runs-on: ${{ matrix.runner }}
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'

env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"

steps:
- uses: actions/checkout@v3
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Bazel GPU Tests with RBE
run: ./ci/run_bazel_test_gpu_rbe.sh
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
@@ -139,7 +139,7 @@ jobs:
documentation_render:
name: Documentation - render documentation
runs-on: ubuntu-latest
timeout-minutes: 10
timeout-minutes: 20
strategy:
matrix:
python-version: ['3.10']
49 changes: 26 additions & 23 deletions .github/workflows/cloud-tpu-ci-nightly.yml
@@ -13,7 +13,7 @@
name: CI - Cloud TPU (nightly)
on:
schedule:
- cron: "0 14 * * *" # daily at 7am PST
- cron: "0 */2 * * *" # Run every 2 hours
workflow_dispatch: # allows triggering the workflow run manually
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
@@ -26,15 +26,18 @@ jobs:
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
{type: "v3-8", cores: "4"},
{type: "v4-8", cores: "4"},
{type: "v5e-8", cores: "8"}
# {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
# {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
python-version: ["3.10"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240722
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
PYTHON: python${{ matrix.python-version }}
runs-on: ${{ matrix.tpu.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
timeout-minutes: 120
defaults:
run:
Expand All @@ -46,52 +49,52 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install JAX test requirements
run: |
pip install -U -r build/test-requirements.txt
pip install -U -r build/collect-profile-requirements.txt
$PYTHON -m pip install -U -r build/test-requirements.txt
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
pip uninstall -y jax jaxlib libtpu
$PYTHON -m pip uninstall -y jax jaxlib libtpu
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
pip install .[tpu] \
$PYTHON -m pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html

elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --pre libtpu \
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
$PYTHON -m pip install requests

elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
$PYTHON -m pip install requests
else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1
fi

python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
python3 -c 'import jax; print("libtpu version:",
$PYTHON -c 'import sys; print("python version:", sys.version)'
$PYTHON -c 'import jax; print("jax version:", jax.__version__)'
$PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
$PYTHON -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
$PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
2 changes: 1 addition & 1 deletion .github/workflows/jax-array-api.yml
@@ -28,7 +28,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻
ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
28 changes: 28 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.36

* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels, `post_process_call`,
`new_base_main`, `custom_bind`, and so on. The change should only affect
users that use JAX internals.

If you do use JAX internals then you may need to
update your code (see
https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks your
non-JAX-internals-using code then try the
`config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
you need help updating your code then please file a bug.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
@@ -43,6 +58,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
return NaN for negative integer inputs, to match the behavior of SciPy from
https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.

* New Features
@@ -52,12 +70,22 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
* Added {func}`jax.numpy.put_along_axis`.
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
supported on GPU. See {jax-issue}`#24663` for more details.

* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}`#24843` for more details.

* Deprecations
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
use `jax.Array` instead.
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
instead.

## jax 0.4.35 (Oct 22, 2024)

* Breaking Changes
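To make the changelog entries above concrete, here is a short illustrative snippet (not part of this PR's diff). The batched `toeplitz` output shape and the commented-out fallback flag call are assumptions inferred from the changelog wording rather than verified against this revision.

```python
import jax.numpy as jnp
from jax.scipy.linalg import toeplitz
from jax.scipy.special import gamma, gammasgn

# gamma and gammasgn now return NaN for negative integer inputs, matching SciPy.
print(gamma(-2.0))     # nan
print(gammasgn(-2.0))  # nan

# toeplitz now batches implicitly over leading dimensions of a multi-dim input.
c = jnp.arange(6.0).reshape(2, 3)
batched = toeplitz(c)              # assumed shape (2, 3, 3): one Toeplitz matrix per row
previous = toeplitz(jnp.ravel(c))  # shape (6, 6): recovers the old flattening behavior

# Workaround flag named in the "stackless" note above, for code that depends on
# JAX internals (assumed usage; enable only if needed):
# jax.config.update("jax_data_dependent_tracing_fallback", True)
```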
9 changes: 5 additions & 4 deletions README.md
@@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like.

Using `jit` puts constraints on the kind of Python control flow
the function can use; see
the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
for more.

### Auto-vectorization with `vmap`
@@ -349,7 +348,7 @@ Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
they're in the `jax.lax` package.
@@ -369,7 +368,7 @@ Some standouts:
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
np.float32)).dtype` is `float64` rather than `float32`.
1. Some transformations, like `jit`, [constrain how you can use Python control
flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
flow](https://jax.readthedocs.io/en/latest/control-flow.html).
You'll always get loud errors if something goes wrong. You might have to use
[`jit`'s `static_argnums`
parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
@@ -390,6 +389,7 @@ Some standouts:
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | yes | no | experimental | n/a | no | no |
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
| Intel GPU | experimental | n/a | n/a | n/a | no | no |


### Instructions
@@ -401,6 +401,7 @@ Some standouts:
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |

See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
for information on alternative installation strategies. These include compiling
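The README excerpt above points to two recurring gotchas: in-place updates and Python control flow under `jit`. Below is a minimal illustrative sketch of both patterns using standard JAX APIs; it is not taken from this diff.

```python
import jax
import jax.numpy as jnp

# In-place updates: x[i] += y is not supported on immutable JAX arrays, but the
# functional .at[] API returns an updated array (buffers are reused under jit).
x = jnp.zeros(5)
x = x.at[2].add(3.0)

# Python control flow on traced values fails under jit; marking the argument
# static makes jit retrace for each concrete value instead.
def scale_if_positive(x, n):
    return x * 2.0 if n > 0 else x

scale = jax.jit(scale_if_positive, static_argnums=1)
y = scale(x, 1)  # works: n is a concrete Python int here
# jax.jit(scale_if_positive)(x, 1) would raise a tracer-concretization error.
```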
3 changes: 1 addition & 2 deletions benchmarks/shape_poly_benchmark.py
@@ -17,7 +17,6 @@

import jax
from jax import core
from jax._src.numpy import lax_numpy
from jax import export

jax.config.parse_flags_with_absl()
@@ -76,7 +75,7 @@ def inequalities_slice(state):
while state:
for _ in range(30):
a.scope._clear_caches()
start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
_ = 0 <= slice_size <= b
_ = start >= 0
_ = start + slice_size <= b
51 changes: 51 additions & 0 deletions ci/run_bazel_test_gpu_rbe.sh
@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one
# GPU apiece on RBE.
#
# -e: abort script if one command fails
# -u: error if undefined variable used
# -x: log all commands
# -o history: record shell history
# -o allexport: export all functions and variables to be available to subscripts
set -exu -o history -o allexport

# Source default JAXCI environment variables.
source ci/envs/default.env

# Clone XLA at HEAD if path to local XLA is not provided
if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
export JAXCI_CLONE_MAIN_XLA=1
fi

# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"

# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece).
echo "Running RBE GPU tests..."

bazel test --config=rbe_linux_x86_64_cuda \
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
--test_output=errors \
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
--test_tag_filters=-multiaccelerator \
--test_env=JAX_SKIP_SLOW_TESTS=true \
--action_env=JAX_ENABLE_X64=0 \
--color=yes \
//tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
6 changes: 3 additions & 3 deletions docs/Custom_Operation_for_GPUs.md
@@ -623,16 +623,16 @@ be used with the custom_partitioning registration and for the
gradient. (And if you implement the interface to support vmat, it will
also be on the outer primitive).

JAX custom_partitioning implementation are callbacks from XLA to Python during XLA sharding logic.
JAX custom_partitioning implementations are callbacks from XLA to Python during XLA sharding logic.
XLA sharding goes in two phases: a sharding propagation phase and a partition phase.
The propagation phase is when XLA plan the sharding to be created. It is the partition phase that create the sharded graph.
The propagation phase is when XLA plan the sharding to be created. It is the partition phase that creates the sharded graph.
For XLA to be able to shard our custom operations, it needs us to define 2 extra functions:
infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively.

The infer_sharding_from_operands() function must do what its name say: infer the output sharding from the input sharding.

The partition() function will do a few things:
- tell which input sharding will be expected. XLA will reshad if needed.
- tell which input sharding will be expected. XLA will reshard if needed.
- tell the final version of the output sharding.
- give a function that will create the new instruction from the sharded inputs.

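As a companion to the custom_partitioning prose above, here is a heavily simplified sketch of how the two callbacks are wired up. It is illustrative only: `double` is a placeholder operation, and the exact callback signatures and return tuple should be checked against the `jax.experimental.custom_partitioning` documentation for the JAX version in use.

```python
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def double(x):  # placeholder standing in for the custom operation
    return 2.0 * x

def infer_sharding_from_operands(mesh, arg_infos, result_infos):
    # Propagation phase: give the output the same sharding as the operand.
    return arg_infos[0].sharding

def partition(mesh, arg_infos, result_infos):
    # Partition phase: accept the operand with its current sharding, keep that
    # sharding for the result, and return the per-shard implementation.
    arg_shardings = (arg_infos[0].sharding,)
    out_sharding = result_infos.sharding

    def per_shard(x_shard):
        return 2.0 * x_shard

    return mesh, per_shard, out_sharding, arg_shardings

double.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
)
```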