forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CI: 11/22/24 upstream sync #148
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PiperOrigin-RevId: 696602883
PiperOrigin-RevId: 696623038
http://github.com/openxla/xla/commit/ecdba3f23b20e684c5e67a5ddb4f004de724f6df. PiperOrigin-RevId: 696642961
…plugin PiperOrigin-RevId: 696677564
PiperOrigin-RevId: 696679346
PiperOrigin-RevId: 696679735
PiperOrigin-RevId: 696681818
PiperOrigin-RevId: 696688602
…-tutorial PiperOrigin-RevId: 696692588
…erally. If any axis is collective, set all dims of aval to unspecified dims in `wrap_with_sharding_op`. Also lower shardings with `Collective` axes correctly to HloSharding. PiperOrigin-RevId: 696703030
…reat 1D (N,) as (1, N) and then tile it as (1, 128) PiperOrigin-RevId: 696870258
On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
PiperOrigin-RevId: 696915844
http://github.com/openxla/xla/commit/195f45b7082930033f6533a160b0f8f7f1cbfb40. PiperOrigin-RevId: 696984108
Co-authored-by: Jake VanderPlas <[email protected]>
…xample PiperOrigin-RevId: 696989966
PiperOrigin-RevId: 696989967
PiperOrigin-RevId: 698577015
PiperOrigin-RevId: 698621325
The implementation exactly matches the one we have in the lowering. PiperOrigin-RevId: 698713343
Updates LLVM usage to match [33fcd6acc755](llvm/llvm-project@33fcd6acc755) PiperOrigin-RevId: 698742870
We are not able to run the TPU workflows because of no active runners (https://github.com/jax-ml/jax/actions/runs/11879479226/job/33101456081). So this adds the new self-hosted runners to the TPU workflow to fix this issue. The v3 type is disabled as we do not have that available yet. PiperOrigin-RevId: 698772505
…g module hash. PiperOrigin-RevId: 698789020
PiperOrigin-RevId: 698798602
In the previous schedule, we were running at every minute at every 2nd hour. PiperOrigin-RevId: 698804124
PiperOrigin-RevId: 698811575
PiperOrigin-RevId: 698818035
… rule. PiperOrigin-RevId: 698818103
PiperOrigin-RevId: 698820458
…y handled in `jax.scipy.stats.gamma` As reported in jax-ml#24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs. Fixes jax-ml#24939 PiperOrigin-RevId: 698833589
…`s in GMEM This is necessary to replace the pipelining logic in the lowering with `emit_pipeline`. PiperOrigin-RevId: 698858380
PiperOrigin-RevId: 698865655
PiperOrigin-RevId: 698872849
This optimization avoids unnecessary retiling when storing to untiled ref but adds at most one extra store op for sublane offset (since sublane offset is limieted to < VregSlice[0]). PiperOrigin-RevId: 698896373
http://github.com/openxla/xla/commit/85360d67ffc0a6d6923605b848de12ec204ca336. PiperOrigin-RevId: 698915433
PiperOrigin-RevId: 698939951
Give the rule the nonzero tangent pattern up-front. This is needed to make a linearization rule for pjit_p. Also make the rules return the nonzero tangents out, an explicit residual, and a closed tangent function. Add a rule for sin_p to test it out. We still need to figure out how to avoid having to precompute `cos(x)`. I think we need to update our backward pass code.
PiperOrigin-RevId: 699007033
…a couple of changes here * Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path) * Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager. * Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them. * scan only allows `xs` where the 0th dim is full replicated i.e. None. PiperOrigin-RevId: 699014167
charleshofer
approved these changes
Nov 22, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Daily sync with upstream