diff --git a/cleverhans/jax/attacks/fast_gradient_method.py b/cleverhans/jax/attacks/fast_gradient_method.py
index 4fbc05cc5..11f479a42 100644
--- a/cleverhans/jax/attacks/fast_gradient_method.py
+++ b/cleverhans/jax/attacks/fast_gradient_method.py
@@ -1,6 +1,6 @@
 import jax.numpy as np
 from jax import grad, vmap
-from jax.experimental.stax import logsoftmax
+from jax.example_libraries.stax import logsoftmax
 
 from cleverhans.jax.utils import one_hot
 
diff --git a/cleverhans/jax/attacks/projected_gradient_descent.py b/cleverhans/jax/attacks/projected_gradient_descent.py
index e5dbdde4f..a35091303 100644
--- a/cleverhans/jax/attacks/projected_gradient_descent.py
+++ b/cleverhans/jax/attacks/projected_gradient_descent.py
@@ -17,6 +17,7 @@ def projected_gradient_descent(
     targeted=False,
     rand_init=None,
     rand_minmax=0.3,
+    num_classes=10
 ):
     """
     This class implements either the Basic Iterative Method
@@ -71,7 +72,7 @@ def projected_gradient_descent(
     if y is None:
         # Using model predictions as ground truth to avoid label leaking
         x_labels = np.argmax(model_fn(x), 1)
-        y = one_hot(x_labels, 10)
+        y = one_hot(x_labels, num_classes)
 
     for _ in range(nb_iter):
         adv_x = fast_gradient_method(
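
For reference, a minimal usage sketch of the new keyword. Only the `num_classes` parameter and the `projected_gradient_descent` call signature come from the diff above; the random linear `model_fn`, the input shapes, and the attack hyperparameters are illustrative placeholders, not part of the patch. Previously the `y=None` branch always built 10-class one-hot labels, so any model with a different output width broke; threading `num_classes` through fixes that.

```python
import jax.numpy as np
from jax import random

from cleverhans.jax.attacks.projected_gradient_descent import projected_gradient_descent

num_classes = 100  # hypothetical model with more than the previously hard-coded 10 classes

# Stand-in classifier: a fixed random linear layer mapping flat inputs to logits.
w = random.normal(random.PRNGKey(0), (784, num_classes))

def model_fn(x):
    return np.dot(x, w)

x = random.uniform(random.PRNGKey(1), (8, 784))  # illustrative input batch

# With y=None the attack derives labels from model_fn's own predictions, so
# num_classes must match the model's logit width instead of being fixed at 10.
x_adv = projected_gradient_descent(
    model_fn, x, eps=0.3, eps_iter=0.01, nb_iter=40, norm=np.inf,
    clip_min=0.0, clip_max=1.0, y=None, targeted=False,
    num_classes=num_classes,
)
```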