diff --git a/trax/layers/research/position_encodings.py b/trax/layers/research/position_encodings.py index fedbc3683..e9de3c04b 100644 --- a/trax/layers/research/position_encodings.py +++ b/trax/layers/research/position_encodings.py @@ -17,6 +17,7 @@ import logging import jax +import jax.extend as jex import numpy as np import trax from trax import fastmath @@ -291,7 +292,7 @@ def threefry_2x32_prf(key, x: jnp.ndarray) -> jnp.ndarray: raise TypeError('x must be uint32[..., 2]', x) # Threefry-2x32 expects this weird format: x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten() - y_3f = jax.random.threefry_2x32(key, x_3f) + y_3f = jex.random.threefry_2x32(key, x_3f) y = jnp.moveaxis( jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1) return y