Skip to content

Commit

Permalink
Add sqrt/square primitives to the inverse registry after openxla/stab…
Browse files Browse the repository at this point in the history
…lehlo#2623

PiperOrigin-RevId: 699910407
  • Loading branch information
hbq1 authored and DistraxDev committed Nov 25, 2024
1 parent 527d698 commit 0d922bd
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions distrax/_src/utils/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@
jax.lax.integer_pow_p: lambda x, y: jax.lax.pow_p.bind(x, 1.0/y)
}

if jax.version.jax_version > "0.4.35":
_inverse_registry.update({
jax.lax.square_p: jax.lax.sqrt_p,
jax.lax.sqrt_p: jax.lax.square_p,
jax.lax.rsqrt_p: lambda x: 1.0 / jax.lax.square_p.bind(x),
})

_potentially_unstable_primitives = {
jax.lax.tanh_p: "distrax.Tanh or distrax.Inverse(distrax.Tanh)",
Expand Down

0 comments on commit 0d922bd

Please sign in to comment.