Skip to content

Commit

Permalink
Make FGM work for exotic dtypes. Some black-ing, pep8-ing. Fix broken…
Browse files Browse the repository at this point in the history
… test.
  • Loading branch information
kylematoba committed Dec 26, 2021
1 parent e5d00e5 commit acfd87e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
4 changes: 3 additions & 1 deletion cleverhans/torch/attacks/fast_gradient_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def fast_gradient_method(

# x needs to be a leaf variable, of floating point type and have requires_grad being True for
# its grad to be computed and stored properly in a backward call
x = x.clone().detach().to(torch.float).requires_grad_(True)
# kylematoba: probably don't need this, but I'll add it here to respect the intention of the earlier cast
assert torch.is_floating_point(x)
x = x.clone().detach().to(x.dtype).requires_grad_(True)
if y is None:
# Using model predictions as ground truth to avoid label leaking
_, y = torch.max(model_fn(x), 1)
Expand Down
10 changes: 5 additions & 5 deletions cleverhans/torch/attacks/hop_skip_jump_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def decision_function(images):
images = torch.clamp(images, clip_min, clip_max)
prob = []
for i in range(0, len(images), batch_size):
batch = images[i : i + batch_size]
batch = images[i: i + batch_size]
prob_i = model_fn(batch)
prob.append(prob_i)
prob = torch.cat(prob, dim=0)
Expand Down Expand Up @@ -214,7 +214,7 @@ def decision_function(images):


def compute_distance(x_ori, x_pert, constraint=2):
""" Compute the distance between two images. """
"""Compute the distance between two images."""
if constraint == 2:
dist = torch.norm(x_ori - x_pert, p=2)
elif constraint == np.inf:
Expand All @@ -225,7 +225,7 @@ def compute_distance(x_ori, x_pert, constraint=2):
def approximate_gradient(
decision_function, sample, num_evals, delta, constraint, shape, clip_min, clip_max
):
""" Gradient direction estimation """
"""Gradient direction estimation"""
# Generate random vectors.
noise_shape = [num_evals] + list(shape)
if constraint == 2:
Expand Down Expand Up @@ -260,7 +260,7 @@ def approximate_gradient(


def project(original_image, perturbed_images, alphas, shape, constraint):
""" Projection onto given l2 / linf balls in a batch. """
"""Projection onto given l2 / linf balls in a batch."""
alphas = alphas.view((alphas.shape[0],) + (1,) * (len(shape) - 1))
if constraint == 2:
projected = (1 - alphas) * original_image + alphas * perturbed_images
Expand All @@ -274,7 +274,7 @@ def project(original_image, perturbed_images, alphas, shape, constraint):
def binary_search_batch(
original_image, perturbed_images, decision_function, shape, constraint, theta
):
""" Binary search to approach the boundary. """
"""Binary search to approach the boundary."""

# Compute distance between each of perturbed image and original image.
dists_post_update = torch.stack(
Expand Down
2 changes: 2 additions & 0 deletions cleverhans/torch/tests/test_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def test_clips(self):
if norm == 1:
self.assertRaises(
NotImplementedError,
self.attack,
model_fn=self.model,
x=self.normalized_x,
eps=0.3,
Expand Down Expand Up @@ -1045,3 +1046,4 @@ def test_grad_sparsity_checks(self):
with self.assertRaises(ValueError) as context:
gs = torch.empty(101).uniform_(90, 99)
self.generate_adversarial_examples(sanity_checks=False, grad_sparsity=gs)

1 change: 0 additions & 1 deletion cleverhans/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def optimize_linear(grad, eps, norm=np.inf):
# Take sign of gradient
optimal_perturbation = torch.sign(grad)
elif norm == 1:
abs_grad = torch.abs(grad)
sign = torch.sign(grad)
red_ind = list(range(1, len(grad.size())))
abs_grad = torch.abs(grad)
Expand Down

0 comments on commit acfd87e

Please sign in to comment.