Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update fast_gradient_method.py #1231

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cleverhans/jax/attacks/fast_gradient_method.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax.numpy as np
from jax import grad, vmap
from jax.experimental.stax import logsoftmax
from jax.example_libraries.stax import logsoftmax

from cleverhans.jax.utils import one_hot

Expand Down
3 changes: 2 additions & 1 deletion cleverhans/jax/attacks/projected_gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def projected_gradient_descent(
targeted=False,
rand_init=None,
rand_minmax=0.3,
num_classes=10
):
"""
This class implements either the Basic Iterative Method
Expand Down Expand Up @@ -71,7 +72,7 @@ def projected_gradient_descent(
if y is None:
# Using model predictions as ground truth to avoid label leaking
x_labels = np.argmax(model_fn(x), 1)
y = one_hot(x_labels, 10)
y = one_hot(x_labels, num_classes)

for _ in range(nb_iter):
adv_x = fast_gradient_method(
Expand Down