diff --git a/tutorials/jax/mnist_tutorial.py b/tutorials/jax/mnist_tutorial.py index 6b510c9fe..497434449 100644 --- a/tutorials/jax/mnist_tutorial.py +++ b/tutorials/jax/mnist_tutorial.py @@ -6,9 +6,9 @@ import jax.numpy as np import numpy.random as npr from jax import jit, grad, random -from jax.experimental import optimizers -from jax.experimental import stax -from jax.experimental.stax import logsoftmax +from jax.example_libraries import optimizers +from jax.example_libraries import stax +from jax.example_libraries.stax import logsoftmax from cleverhans.jax.attacks.fast_gradient_method import fast_gradient_method from cleverhans.jax.attacks.projected_gradient_descent import projected_gradient_descent