Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: 11/22/24 upstream sync #148

Merged
merged 155 commits into from
Nov 22, 2024
Merged

CI: 11/22/24 upstream sync #148

merged 155 commits into from
Nov 22, 2024

Conversation

github-actions[bot]
Copy link

Daily sync with upstream

mini-goel and others added 30 commits October 28, 2024 10:47
…erally. If any axis is collective, set all dims of aval to unspecified dims in `wrap_with_sharding_op`.

Also lower shardings with `Collective` axes correctly to HloSharding.

PiperOrigin-RevId: 696703030
…reat 1D (N,) as (1, N) and then tile it as (1, 128)

PiperOrigin-RevId: 696870258
On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
bythew3i and others added 24 commits November 20, 2024 17:28
PiperOrigin-RevId: 698654038
The implementation exactly matches the one we have in the lowering.

PiperOrigin-RevId: 698713343
Updates LLVM usage to match
[33fcd6acc755](llvm/llvm-project@33fcd6acc755)

PiperOrigin-RevId: 698742870
We are not able to run the TPU workflows because of no active runners (https://github.com/jax-ml/jax/actions/runs/11879479226/job/33101456081). So this adds the new self-hosted runners to the TPU workflow to fix this issue. The v3 type is disabled as we do not have that available yet.

PiperOrigin-RevId: 698772505
In the previous schedule, we were running at every minute at every 2nd hour.

PiperOrigin-RevId: 698804124
PiperOrigin-RevId: 698818035
…y handled in `jax.scipy.stats.gamma`

As reported in jax-ml#24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs.

Fixes jax-ml#24939

PiperOrigin-RevId: 698833589
…`s in GMEM

This is necessary to replace the pipelining logic in the lowering with
`emit_pipeline`.

PiperOrigin-RevId: 698858380
This optimization avoids unnecessary retiling when storing to untiled ref but adds at most one extra store op for sublane offset (since sublane offset is limieted to < VregSlice[0]).

PiperOrigin-RevId: 698896373
Give the rule the nonzero tangent pattern up-front. This is needed to make a
linearization rule for pjit_p. Also make the rules return the nonzero tangents
out, an explicit residual, and a closed tangent function. Add a rule for sin_p
to test it out. We still need to figure out how to avoid having to precompute
`cos(x)`. I think we need to update our backward pass code.
…a couple of changes here

* Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path)

* Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager.

* Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them.

* scan only allows `xs` where the 0th dim is full replicated i.e. None.

PiperOrigin-RevId: 699014167
@charleshofer charleshofer self-requested a review November 22, 2024 16:12
@charleshofer charleshofer merged commit 3be7c1e into rocm-main Nov 22, 2024
7 checks passed
@charleshofer charleshofer deleted the ci-upstream-sync-34_1 branch November 22, 2024 16:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.