From 26a2b393412f39fa41dd6976aec2222135da1095 Mon Sep 17 00:00:00 2001 From: Peter Luo Date: Sun, 9 Oct 2022 20:53:55 -0400 Subject: [PATCH 1/2] Update fast_gradient_method.py --- cleverhans/jax/attacks/fast_gradient_method.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleverhans/jax/attacks/fast_gradient_method.py b/cleverhans/jax/attacks/fast_gradient_method.py index 4fbc05cc5..11f479a42 100644 --- a/cleverhans/jax/attacks/fast_gradient_method.py +++ b/cleverhans/jax/attacks/fast_gradient_method.py @@ -1,6 +1,6 @@ import jax.numpy as np from jax import grad, vmap -from jax.experimental.stax import logsoftmax +from jax.example_libraries.stax import logsoftmax from cleverhans.jax.utils import one_hot From b82336d30a5aef6ad67e4bae488b0db23af2a4fc Mon Sep 17 00:00:00 2001 From: Peter Luo Date: Sun, 23 Oct 2022 02:06:21 -0400 Subject: [PATCH 2/2] Update projected_gradient_descent.py --- cleverhans/jax/attacks/projected_gradient_descent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cleverhans/jax/attacks/projected_gradient_descent.py b/cleverhans/jax/attacks/projected_gradient_descent.py index e5dbdde4f..a35091303 100644 --- a/cleverhans/jax/attacks/projected_gradient_descent.py +++ b/cleverhans/jax/attacks/projected_gradient_descent.py @@ -17,6 +17,7 @@ def projected_gradient_descent( targeted=False, rand_init=None, rand_minmax=0.3, + num_classes=10 ): """ This class implements either the Basic Iterative Method @@ -71,7 +72,7 @@ def projected_gradient_descent( if y is None: # Using model predictions as ground truth to avoid label leaking x_labels = np.argmax(model_fn(x), 1) - y = one_hot(x_labels, 10) + y = one_hot(x_labels, num_classes) for _ in range(nb_iter): adv_x = fast_gradient_method(