Skip to content

Commit

Permalink
trax: avoid referencing deprecated jax.random.threefry_2x32
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575922997
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Oct 23, 2023
1 parent 2a4356a commit 38adb83
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion trax/layers/research/position_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import logging
import jax
import jax.extend as jex
import numpy as np
import trax
from trax import fastmath
Expand Down Expand Up @@ -291,7 +292,7 @@ def threefry_2x32_prf(key, x: jnp.ndarray) -> jnp.ndarray:
raise TypeError('x must be uint32[..., 2]', x)
# Threefry-2x32 expects this weird format:
x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten()
y_3f = jax.random.threefry_2x32(key, x_3f)
y_3f = jex.random.threefry_2x32(key, x_3f)
y = jnp.moveaxis(
jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1)
return y
Expand Down

0 comments on commit 38adb83

Please sign in to comment.