From 61de650af66a9e0f7268899148e8543ff89c5dbe Mon Sep 17 00:00:00 2001 From: Cassidy Laidlaw Date: Tue, 16 Jul 2019 21:56:34 -0700 Subject: [PATCH 1/7] Implement Deepfool with tests --- cleverhans/future/torch/attacks/__init__.py | 1 + cleverhans/future/torch/attacks/deepfool.py | 161 +++++++++++ cleverhans/future/torch/tests/test_attacks.py | 252 ++++++++++++++++++ cleverhans/future/torch/utils.py | 17 +- 4 files changed, 415 insertions(+), 16 deletions(-) create mode 100644 cleverhans/future/torch/attacks/deepfool.py diff --git a/cleverhans/future/torch/attacks/__init__.py b/cleverhans/future/torch/attacks/__init__.py index f411958ac..e10e47751 100644 --- a/cleverhans/future/torch/attacks/__init__.py +++ b/cleverhans/future/torch/attacks/__init__.py @@ -3,3 +3,4 @@ from cleverhans.future.torch.attacks.projected_gradient_descent import projected_gradient_descent from cleverhans.future.torch.attacks.noise import noise from cleverhans.future.torch.attacks.semantic import semantic +from cleverhans.future.torch.attacks.deepfool import deepfool diff --git a/cleverhans/future/torch/attacks/deepfool.py b/cleverhans/future/torch/attacks/deepfool.py new file mode 100644 index 000000000..f31e04411 --- /dev/null +++ b/cleverhans/future/torch/attacks/deepfool.py @@ -0,0 +1,161 @@ +"""The Deepfool attack.""" +import numpy as np +import torch +from torch.autograd.gradcheck import zero_gradients +from cleverhans.future.torch.utils import clip_eta + + +def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf, + y=None, targeted=False, eps=None, norm=None, + num_classes=10, overshoot=0.02, max_iter=50, + is_debug=True, sanity_checks=False): + """ + PyTorch implementation of DeepFool (https://arxiv.org/pdf/1511.04599.pdf). + :param model_fn: A callable that takes an input tensor and returns the model logits. + :param x: Input tensor. + :param clip_min: If specified, the minimum input value. + :param clip_max: If specified, the maximum input value. + :param y: (optional) Tensor with true labels. If targeted is true, then provide the + target label. Otherwise, only provide this parameter if you'd like to use true + labels when crafting adversarial samples. Otherwise, model predictions are used + as labels to avoid the "label leaking" effect (explained in this paper: + https://arxiv.org/abs/1611.01236). Default is None. + :param targeted: (optional) bool. Is the attack targeted or untargeted? Untargeted, the + default, will try to make the label incorrect. Targeted will instead try to + move in the direction of being more like y. + :param eps: The size of the maximum perturbation, or None if the perturbation + should not be constrained. + :param norm: Order of the norm used for eps (mimics NumPy). Possible values: np.inf, 1 or 2. + :param num_classes: the attack targets this many of the closest classes in the untargeted + version. + :param overshoot: used as a termination criterion to prevent vanishing updates. + :param max_iter: maximum number of iterations for DeepFool. + :param is_debug: If True, print the success rate after each iteration. 
+ :param sanity_checks: bool, if True, include asserts (Turn them off to use less runtime / + memory or for unit tests that intentionally pass strange input) + :return: a tensor for the adversarial example + """ + + if y is not None and len(x) != len(y): + raise ValueError('number of inputs {} is different from number of labels {}' + .format(len(x), len(y))) + if y is None: + if targeted: + raise ValueError('cannot perform a targeted attack without specifying targets y') + y = torch.argmax(model_fn(x), dim=1) + + if eps is not None: + if eps < 0: + raise ValueError( + "eps must be greater than or equal to 0, got {} instead".format(eps)) + if norm not in [np.inf, 1, 2]: + raise ValueError('invalid norm') + if eps == 0: + return x + + if clip_min is not None and clip_max is not None: + if clip_min > clip_max: + raise ValueError( + "clip_min must be less than or equal to clip_max, got clip_min={} and clip_max={}" + .format(clip_min, clip_max)) + + asserts = [] + + # If a data range was specified, check that the input was in that range + + asserts.append(torch.all(x >= clip_min)) + asserts.append(torch.all(x <= clip_max)) + + # Determine classes to target + if targeted: + target_classes = y[:, None] + y = torch.argmax(model_fn(x), dim=1) + else: + logits = model_fn(x) + logit_indices = torch.arange( + logits.size()[1], + dtype=y.dtype, + device=y.device, + )[None, :].expand(y.size()[0], -1) + # Number of target classes should be at most number of classes minus 1 + num_classes = min(num_classes, logits.size()[1] - 1) + incorrect_logits = torch.where( + logit_indices == y[:, None], + torch.full_like(logits, -np.inf), + logits, + ) + target_classes = incorrect_logits.argsort( + dim=1, descending=True)[:, :num_classes] + + x = x.clone().detach().to(torch.float) + perturbations = torch.zeros_like(x) + + if is_debug: + print("Starting DeepFool attack") + + for i in range(max_iter): + x_adv = x + (1 + overshoot) * perturbations + x_adv.requires_grad_(True) + zero_gradients(x_adv) + logits = model_fn(x_adv) + + # "Live" inputs are still being attacked; others have already achieved misclassification + if targeted: + live = torch.argmax(logits, dim=1) != target_classes[:, 0] + else: + live = torch.argmax(logits, dim=1) == y + if is_debug: + print('Iteration {}: {:.1f}% success'.format( + i, 100 * (1 - live.sum().float() / len(live)).item())) + if torch.all(~live): + # Stop early if all inputs are already misclassified + break + + smallest_magnitudes = torch.full((int(live.sum()),), np.inf, + dtype=torch.float, device=perturbations.device) + smallest_perturbation_updates = torch.zeros_like(perturbations[live]) + + logits[live, y[live]].sum().backward(retain_graph=True) + grads_correct = x_adv.grad.data[live].clone().detach() + + for k in range(target_classes.size()[1]): + zero_gradients(x_adv) + + logits_target = logits[live, target_classes[live, k]] + logits_target.sum().backward() + grads_target = x_adv.grad.data[live].clone().detach() + + grads_diff = (grads_target - grads_correct).detach() + logits_margin = (logits_target - logits[live, y[live]]).detach() + + grads_norm = grads_diff.norm(p=2, dim=list(range(1, len(grads_diff.size())))) + magnitudes = logits_margin.abs() / grads_norm + + magnitudes_expanded = magnitudes + for _ in range(len(grads_diff.size()) - 1): + grads_norm = grads_norm.unsqueeze(-1) + magnitudes_expanded = magnitudes_expanded.unsqueeze(-1) + perturbation_updates = ((magnitudes_expanded + 1e-4) * grads_diff / + grads_norm) + + smaller = magnitudes < smallest_magnitudes + 
smallest_perturbation_updates[smaller] = perturbation_updates[smaller] + smallest_magnitudes[smaller] = magnitudes[smaller] + + all_perturbation_updates = torch.zeros_like(perturbations) + all_perturbation_updates[live] = smallest_perturbation_updates + perturbations.add_(all_perturbation_updates) + + perturbations *= (1 + overshoot) + if eps is not None: + perturbations = clip_eta(perturbations, norm, eps) + + x_adv = torch.clamp(x + perturbations, clip_min, clip_max) + + asserts.append(torch.all(x_adv >= clip_min)) + asserts.append(torch.all(x_adv <= clip_max)) + + if sanity_checks: + assert np.all(asserts) + + return x_adv diff --git a/cleverhans/future/torch/tests/test_attacks.py b/cleverhans/future/torch/tests/test_attacks.py index 1f1722eee..47a0b9b36 100644 --- a/cleverhans/future/torch/tests/test_attacks.py +++ b/cleverhans/future/torch/tests/test_attacks.py @@ -5,12 +5,16 @@ from __future__ import unicode_literals import numpy as np +import copy from nose.plugins.skip import SkipTest import torch +from torch.autograd import Variable +from torch.autograd.gradcheck import zero_gradients from cleverhans.devtools.checks import CleverHansTest from cleverhans.future.torch.attacks.fast_gradient_method import fast_gradient_method from cleverhans.future.torch.attacks.projected_gradient_descent import projected_gradient_descent +from cleverhans.future.torch.attacks.deepfool import deepfool class SimpleModel(torch.nn.Module): @@ -25,6 +29,25 @@ def forward(self, x): x = torch.matmul(x, self.w2) return x + +class SimpleImageModel(torch.nn.Module): + + def __init__(self): + super(SimpleImageModel, self).__init__() + self.w1 = torch.tensor([[1.5, .3], [-2, .3]]) + self.w2 = torch.tensor([[-2.4, 1.2], [.5, -2.3]]) + + def forward(self, x): + if len(x.size()) == 4: + x = x[:, 0, 0] + elif len(x.size()) == 3: + x = x[None, 0, 0] + x = torch.matmul(x, self.w1) + x = torch.sigmoid(x) + x = torch.matmul(x, self.w2) + return x + + class CommonAttackProperties(CleverHansTest): def setUp(self): @@ -349,3 +372,232 @@ def test_multiple_initial_random_step(self): ori_label.eq(new_label_multi).sum().to(torch.float) / self.normalized_x.size(0)) self.assertLess(failed_attack, .5) + + +class TestDeepFool(CommonAttackProperties): + + def setUp(self): + super(TestDeepFool, self).setUp() + self.attack = deepfool + self.attack_param = { + 'eps' : .5, + 'clip_min' : -5, + 'clip_max' : 5, + } + + def test_invalid_input(self): + x = torch.tensor([[-2., 3.]]) + for norm in self.ord_list: + self.assertRaises( + AssertionError, self.attack, model_fn=self.model, x=x, eps=.1, + norm=norm, clip_min=-1., clip_max=1., sanity_checks=True) + + def test_invalid_eps(self): + for norm in self.ord_list: + self.assertRaises( + ValueError, self.attack, model_fn=self.model, + x=self.x, eps=-.1, norm=norm) + + def test_eps_equals_zero(self): + for norm in self.ord_list: + self.assertClose( + self.attack(model_fn=self.model, x=self.x, eps=0, norm=norm), + self.x) + + def test_max_iter_equals_zero(self): + for norm in self.ord_list: + self.assertClose( + self.attack( + model_fn=self.model, x=self.x, eps=.5, norm=norm, max_iter=0), + self.x) + + def test_invalid_clips(self): + clip_min = .5 + clip_max = -.5 + for norm in self.ord_list: + self.assertRaises( + ValueError, self.attack, model_fn=self.model, x=self.x, eps=.1, + norm=norm, clip_min=clip_min, clip_max=clip_max) + + def test_adv_example_success_rate_linf(self): + self.help_adv_examples_success_rate( + norm=np.inf, **self.attack_param) + + def 
test_targeted_adv_example_success_rate_linf(self): + self.help_targeted_adv_examples_success_rate( + norm=np.inf, **self.attack_param) + + def test_adv_example_success_rate_l1(self): + self.help_adv_examples_success_rate( + norm=1, **self.attack_param) + + def test_targeted_adv_example_success_rate_l1(self): + self.help_targeted_adv_examples_success_rate( + norm=1, **self.attack_param) + + def test_adv_example_success_rate_l2(self): + self.help_adv_examples_success_rate( + norm=2, **self.attack_param) + + def test_targeted_adv_example_success_rate_l2(self): + self.help_targeted_adv_examples_success_rate( + norm=2, **self.attack_param) + + def test_do_not_reach_lp_boundary(self): + for norm in self.ord_list: + x_adv = self.attack( + model_fn=self.model, x=self.normalized_x, eps=.5, norm=norm) + + if norm == np.inf: + delta, _ = torch.abs(x_adv - self.normalized_x).max(dim=1) + elif norm == 1: + delta = torch.abs(x_adv - self.normalized_x).sum(dim=1) + elif norm == 2: + delta = torch.pow(x_adv - self.normalized_x, 2).sum(dim=1).pow(.5) + diff = torch.max(.5 - delta) + self.assertTrue(diff > .25) + + def test_attack_strength(self): + x_adv = self.attack( + model_fn=self.model, x=self.normalized_x, + clip_min=0., clip_max=1., + sanity_checks=False) + _, ori_label = self.model(self.normalized_x).max(1) + _, adv_label = self.model(x_adv).max(1) + adv_acc = ( + adv_label.eq(ori_label).sum().to(torch.float) + / self.normalized_x.size(0)) + self.assertLess(adv_acc, .1) + + def test_eps(self): + # test if the attack respects the norm constraint + # NOTE clip_eta makes sure that at each step, adv_x respects the eps + # norm constraint. Therefore, this is essentially a test on clip_eta, + # which is implemented in a separate test_clip_eta + raise SkipTest() + + def test_clip_eta(self): + # NOTE: this has been tested with test_clip_eta in test_utils + raise SkipTest() + + def test_clips(self): + clip_min = -1. + clip_max = 1. + for norm in self.ord_list: + x_adv = self.attack( + model_fn=self.model, x=self.normalized_x, eps=.3, + norm=norm, clip_min=clip_min, clip_max=clip_max) + self.assertTrue(torch.all(x_adv <= clip_max)) + self.assertTrue(torch.all(x_adv >= clip_min)) + + def test_multiple_initial_random_step(self): + _, ori_label = self.model(self.normalized_x).max(1) + new_label_multi = ori_label.clone().detach() + + for _ in range(10): + x_adv = self.attack( + model_fn=self.model, x=self.normalized_x, eps=.5, + norm=np.inf, clip_min=.5, clip_max=.7, sanity_checks=False) + _, new_label = self.model(x_adv).max(1) + + # examples for which we have not found adversarial examples + i = ori_label.eq(new_label_multi) + new_label_multi[i] = new_label[i] + + failed_attack = ( + ori_label.eq(new_label_multi).sum().to(torch.float) + / self.normalized_x.size(0)) + self.assertLess(failed_attack, .5) + + def test_matches_reference(self): + model = SimpleImageModel() + for image in self.x: + image = image[None, None, :] + _, _, _, _, pert_image = TestDeepFool.reference_deepfool(image, model, num_classes=2) + self.assertClose( + self.attack(model_fn=model, x=image[None])[0], + pert_image) + + @staticmethod + def reference_deepfool(image, net, num_classes=10, overshoot=0.02, max_iter=50): + """ + Reference implementation of DeepFool from original authors at + https://github.com/LTS4/DeepFool. + :param image: Image of size HxWx3 + :param net: network (input: images, output: values of activation **BEFORE** softmax). 
+ :param num_classes: num_classes (limits the number of classes to test against, by default = 10) + :param overshoot: used as a termination criterion to prevent vanishing updates (default = 0.02). + :param max_iter: maximum number of iterations for deepfool (default = 50) + :return: minimal perturbation that fools the classifier, number of iterations that it required, new estimated_label and perturbed image + """ + is_cuda = torch.cuda.is_available() + + if is_cuda: + print("Using GPU") + image = image.cuda() + net = net.cuda() + else: + print("Using CPU") + + f_image = net.forward(Variable( + image[None, :, :, :], requires_grad=True)).data.cpu().numpy().flatten() + I = (np.array(f_image)).flatten().argsort()[::-1] + + I = I[0:num_classes] + label = I[0] + + input_shape = image.cpu().numpy().shape + pert_image = copy.deepcopy(image) + w = np.zeros(input_shape) + r_tot = np.zeros(input_shape) + + loop_i = 0 + + x = Variable(pert_image[None, :], requires_grad=True) + fs = net.forward(x) + fs_list = [fs[0, I[k]] for k in range(num_classes)] + k_i = label + + while k_i == label and loop_i < max_iter: + + pert = np.inf + fs[0, I[0]].backward(retain_graph=True) + grad_orig = x.grad.data.cpu().numpy().copy() + + for k in range(1, num_classes): + zero_gradients(x) + + fs[0, I[k]].backward(retain_graph=True) + cur_grad = x.grad.data.cpu().numpy().copy() + + # set new w_k and new f_k + w_k = cur_grad - grad_orig + f_k = (fs[0, I[k]] - fs[0, I[0]]).data.cpu().numpy() + + pert_k = abs(f_k) / np.linalg.norm(w_k.flatten()) + + # determine which w_k to use + if pert_k < pert: + pert = pert_k + w = w_k + + # compute r_i and r_tot + # Added 1e-4 for numerical stability + r_i = (pert + 1e-4) * w / np.linalg.norm(w) + r_tot = np.float32(r_tot + r_i) + + if is_cuda: + pert_image = image + (1 + overshoot) * \ + torch.from_numpy(r_tot).cuda() + else: + pert_image = image + (1 + overshoot) * torch.from_numpy(r_tot) + + x = Variable(pert_image, requires_grad=True) + fs = net.forward(x) + k_i = np.argmax(fs.data.cpu().numpy().flatten()) + + loop_i += 1 + + r_tot = (1 + overshoot) * r_tot + + return r_tot, loop_i, label, k_i, pert_image diff --git a/cleverhans/future/torch/utils.py b/cleverhans/future/torch/utils.py index 9e33abea5..dff5923ed 100644 --- a/cleverhans/future/torch/utils.py +++ b/cleverhans/future/torch/utils.py @@ -21,22 +21,7 @@ def clip_eta(eta, norm, eps): if norm == np.inf: eta = torch.clamp(eta, -eps, eps) else: - if norm == 1: - raise NotImplementedError("L1 clip is not implemented.") - norm = torch.max( - avoid_zero_div, - torch.sum(torch.abs(eta), dim=reduc_ind, keepdim=True) - ) - elif norm == 2: - norm = torch.sqrt(torch.max( - avoid_zero_div, - torch.sum(eta ** 2, dim=reduc_ind, keepdim=True) - )) - factor = torch.min( - torch.tensor(1., dtype=eta.dtype, device=eta.device), - eps / norm - ) - eta *= factor + eta = torch.renorm(eta, p=norm, dim=0, maxnorm=eps) return eta def get_or_guess_labels(model, x, **kwargs): From b69a983711397d1e5e67c5493335e29e46a6763d Mon Sep 17 00:00:00 2001 From: Cassidy Laidlaw Date: Tue, 16 Jul 2019 22:36:14 -0700 Subject: [PATCH 2/7] Do further testing and fix issues with DeepFool implementation --- cleverhans/future/torch/attacks/deepfool.py | 4 +- cleverhans/future/torch/tests/test_attacks.py | 53 ++++++++++--------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/cleverhans/future/torch/attacks/deepfool.py b/cleverhans/future/torch/attacks/deepfool.py index f31e04411..f0742b156 100644 --- a/cleverhans/future/torch/attacks/deepfool.py +++ 
b/cleverhans/future/torch/attacks/deepfool.py
@@ -8,7 +8,7 @@
 def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf,
              y=None, targeted=False, eps=None, norm=None,
              num_classes=10, overshoot=0.02, max_iter=50,
-             is_debug=True, sanity_checks=False):
+             is_debug=False, sanity_checks=False):
   """
   PyTorch implementation of DeepFool (https://arxiv.org/pdf/1511.04599.pdf).
   :param model_fn: A callable that takes an input tensor and returns the model logits.
@@ -122,7 +122,7 @@ def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf,
       zero_gradients(x_adv)
 
       logits_target = logits[live, target_classes[live, k]]
-      logits_target.sum().backward()
+      logits_target.sum().backward(retain_graph=True)
       grads_target = x_adv.grad.data[live].clone().detach()
 
       grads_diff = (grads_target - grads_correct).detach()
diff --git a/cleverhans/future/torch/tests/test_attacks.py b/cleverhans/future/torch/tests/test_attacks.py
index 47a0b9b36..6dfea9719 100644
--- a/cleverhans/future/torch/tests/test_attacks.py
+++ b/cleverhans/future/torch/tests/test_attacks.py
@@ -10,6 +10,8 @@
 import torch
 from torch.autograd import Variable
 from torch.autograd.gradcheck import zero_gradients
+import torch.nn.functional as F
+from torch import nn
 
 from cleverhans.devtools.checks import CleverHansTest
 from cleverhans.future.torch.attacks.fast_gradient_method import fast_gradient_method
@@ -30,24 +32,6 @@ def forward(self, x):
     return x
 
 
-class SimpleImageModel(torch.nn.Module):
-
-  def __init__(self):
-    super(SimpleImageModel, self).__init__()
-    self.w1 = torch.tensor([[1.5, .3], [-2, .3]])
-    self.w2 = torch.tensor([[-2.4, 1.2], [.5, -2.3]])
-
-  def forward(self, x):
-    if len(x.size()) == 4:
-      x = x[:, 0, 0]
-    elif len(x.size()) == 3:
-      x = x[None, 0, 0]
-    x = torch.matmul(x, self.w1)
-    x = torch.sigmoid(x)
-    x = torch.matmul(x, self.w2)
-    return x
-
-
 class CommonAttackProperties(CleverHansTest):
 
   def setUp(self):
@@ -374,13 +358,35 @@ def test_multiple_initial_random_step(self):
     self.assertLess(failed_attack, .5)
 
 
+class SimpleImageModel(torch.nn.Module):
+  """
+  This slightly more complex model is useful for testing DeepFool. It has
+  two fully-connected layers (one hidden) with a ReLU activation in between,
+  and outputs five classes.
+ """ + + def __init__(self): + super(SimpleImageModel, self).__init__() + self.l1 = nn.Linear(2, 10) + self.l2 = nn.Linear(10, 5) + + def forward(self, x): + if len(x.size()) == 4: + x = x[:, 0, 0] + elif len(x.size()) == 3: + x = x[None, 0, 0] + x = self.l1(x) + x = F.relu(x) + x = self.l2(x) + return x + + class TestDeepFool(CommonAttackProperties): def setUp(self): super(TestDeepFool, self).setUp() self.attack = deepfool self.attack_param = { - 'eps' : .5, 'clip_min' : -5, 'clip_max' : 5, } @@ -511,12 +517,11 @@ def test_multiple_initial_random_step(self): def test_matches_reference(self): model = SimpleImageModel() - for image in self.x: + x_adv = self.attack(model_fn=model, x=self.x[:, None, None, :]) + for image, adv_image in zip(self.x, x_adv): image = image[None, None, :] - _, _, _, _, pert_image = TestDeepFool.reference_deepfool(image, model, num_classes=2) - self.assertClose( - self.attack(model_fn=model, x=image[None])[0], - pert_image) + _, _, _, _, pert_image = TestDeepFool.reference_deepfool(image, model, num_classes=5) + assert torch.norm(adv_image - pert_image) < 1e-4, (adv_image, pert_image) @staticmethod def reference_deepfool(image, net, num_classes=10, overshoot=0.02, max_iter=50): From f1969b1e0fb1d90e41fb54563c00e40867d3878c Mon Sep 17 00:00:00 2001 From: Cassidy Laidlaw Date: Tue, 16 Jul 2019 23:16:28 -0700 Subject: [PATCH 3/7] Apply norm constraint and clipping after each iteration in DeepFool --- cleverhans/future/torch/attacks/deepfool.py | 16 +++++++++++----- cleverhans/future/torch/tests/test_attacks.py | 4 ++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/cleverhans/future/torch/attacks/deepfool.py b/cleverhans/future/torch/attacks/deepfool.py index f0742b156..0d5cd9ee4 100644 --- a/cleverhans/future/torch/attacks/deepfool.py +++ b/cleverhans/future/torch/attacks/deepfool.py @@ -146,11 +146,17 @@ def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf, all_perturbation_updates[live] = smallest_perturbation_updates perturbations.add_(all_perturbation_updates) - perturbations *= (1 + overshoot) - if eps is not None: - perturbations = clip_eta(perturbations, norm, eps) - - x_adv = torch.clamp(x + perturbations, clip_min, clip_max) + perturbations *= (1 + overshoot) + if eps is not None: + perturbations = clip_eta(perturbations, norm, eps) + perturbations = torch.clamp(x + perturbations, clip_min, clip_max) - x + perturbations /= (1 + overshoot) + +# perturbations *= (1 + overshoot) +# if eps is not None: +# perturbations = clip_eta(perturbations, norm, eps) +# + x_adv = x + perturbations * (1 + overshoot) asserts.append(torch.all(x_adv >= clip_min)) asserts.append(torch.all(x_adv <= clip_max)) diff --git a/cleverhans/future/torch/tests/test_attacks.py b/cleverhans/future/torch/tests/test_attacks.py index 6dfea9719..5244ac7c3 100644 --- a/cleverhans/future/torch/tests/test_attacks.py +++ b/cleverhans/future/torch/tests/test_attacks.py @@ -465,8 +465,8 @@ def test_do_not_reach_lp_boundary(self): def test_attack_strength(self): x_adv = self.attack( - model_fn=self.model, x=self.normalized_x, - clip_min=0., clip_max=1., + model_fn=self.model, x=self.normalized_x, eps=1., + norm=np.inf, clip_min=.5, clip_max=.7, sanity_checks=False) _, ori_label = self.model(self.normalized_x).max(1) _, adv_label = self.model(x_adv).max(1) From 6273641f0341c8d7df8619a54422340bcd30327b Mon Sep 17 00:00:00 2001 From: Cassidy Laidlaw Date: Tue, 16 Jul 2019 23:31:53 -0700 Subject: [PATCH 4/7] Added L_inf-specific perturbation updates from paper --- 
cleverhans/future/torch/attacks/deepfool.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cleverhans/future/torch/attacks/deepfool.py b/cleverhans/future/torch/attacks/deepfool.py index 0d5cd9ee4..d16fa5d28 100644 --- a/cleverhans/future/torch/attacks/deepfool.py +++ b/cleverhans/future/torch/attacks/deepfool.py @@ -128,15 +128,21 @@ def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf, grads_diff = (grads_target - grads_correct).detach() logits_margin = (logits_target - logits[live, y[live]]).detach() - grads_norm = grads_diff.norm(p=2, dim=list(range(1, len(grads_diff.size())))) + grads_norm = grads_diff.norm(p=1 if norm == np.inf else 2, + dim=list(range(1, len(grads_diff.size())))) magnitudes = logits_margin.abs() / grads_norm magnitudes_expanded = magnitudes for _ in range(len(grads_diff.size()) - 1): grads_norm = grads_norm.unsqueeze(-1) magnitudes_expanded = magnitudes_expanded.unsqueeze(-1) - perturbation_updates = ((magnitudes_expanded + 1e-4) * grads_diff / - grads_norm) + + if norm == np.inf: + perturbation_updates = ((magnitudes_expanded + 1e-4) * + torch.sign(grads_diff)) + else: + perturbation_updates = ((magnitudes_expanded + 1e-4) * grads_diff / + grads_norm) smaller = magnitudes < smallest_magnitudes smallest_perturbation_updates[smaller] = perturbation_updates[smaller] From f977af8f97ae6a3a579eab95d5060152928bd7f7 Mon Sep 17 00:00:00 2001 From: Cassidy Laidlaw Date: Tue, 16 Jul 2019 23:49:53 -0700 Subject: [PATCH 5/7] Remove dependency on newer PyTorch feature to hopefully fix Travis errors --- cleverhans/future/torch/attacks/deepfool.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cleverhans/future/torch/attacks/deepfool.py b/cleverhans/future/torch/attacks/deepfool.py index d16fa5d28..6186549bd 100644 --- a/cleverhans/future/torch/attacks/deepfool.py +++ b/cleverhans/future/torch/attacks/deepfool.py @@ -128,8 +128,10 @@ def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf, grads_diff = (grads_target - grads_correct).detach() logits_margin = (logits_target - logits[live, y[live]]).detach() - grads_norm = grads_diff.norm(p=1 if norm == np.inf else 2, - dim=list(range(1, len(grads_diff.size())))) + p = 1 if norm == np.inf else 2 + + grads_norm = (grads_diff ** p).abs().sum(dim=list(range(1, len(grads_diff.size())))) \ + ** (1. 
/ p) magnitudes = logits_margin.abs() / grads_norm magnitudes_expanded = magnitudes From 11bfdfd3ca0d2eb351a45bd4bd6cdbb8b781ac55 Mon Sep 17 00:00:00 2001 From: Cassidy Laidlaw Date: Wed, 17 Jul 2019 01:15:32 -0700 Subject: [PATCH 6/7] Fix test_utils now that L_1 norm is implemented for clip_eta --- cleverhans/future/torch/tests/test_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/cleverhans/future/torch/tests/test_utils.py b/cleverhans/future/torch/tests/test_utils.py index a4ee73150..f42b4cf43 100644 --- a/cleverhans/future/torch/tests/test_utils.py +++ b/cleverhans/future/torch/tests/test_utils.py @@ -94,13 +94,9 @@ def test_clip_eta_linf(self): self.assertTrue(torch.all(clipped >= -.5)) def test_clip_eta_l1(self): - self.assertRaises( - NotImplementedError, self.clip_eta, eta=self.rand_eta, norm=1, eps=.5) - - # TODO uncomment the actual test below after we have implemented the L1 attack - # clipped = self.clip_eta(eta=self.rand_eta, norm=1, eps=.5) - # norm = clipped.abs().sum(dim=self.red_ind) - # self.assertTrue(torch.all(norm <= .5001)) + clipped = self.clip_eta(eta=self.rand_eta, norm=1, eps=.5) + norm = clipped.abs().sum(dim=self.red_ind) + self.assertTrue(torch.all(norm <= .5001)) def test_clip_eta_l2(self): clipped = self.clip_eta(eta=self.rand_eta, norm=2, eps=.5) From 973871a77939b5276b774459beb6e58f42665a8c Mon Sep 17 00:00:00 2001 From: Cassidy Laidlaw Date: Thu, 18 Jul 2019 19:35:04 -0700 Subject: [PATCH 7/7] Remove unnecessary code in DeepFool implementation --- cleverhans/future/torch/attacks/deepfool.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cleverhans/future/torch/attacks/deepfool.py b/cleverhans/future/torch/attacks/deepfool.py index 6186549bd..abcceb2b9 100644 --- a/cleverhans/future/torch/attacks/deepfool.py +++ b/cleverhans/future/torch/attacks/deepfool.py @@ -160,10 +160,6 @@ def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf, perturbations = torch.clamp(x + perturbations, clip_min, clip_max) - x perturbations /= (1 + overshoot) -# perturbations *= (1 + overshoot) -# if eps is not None: -# perturbations = clip_eta(perturbations, norm, eps) -# x_adv = x + perturbations * (1 + overshoot) asserts.append(torch.all(x_adv >= clip_min))
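Note for readers who want to try the attack added in PATCH 1/7: below is a minimal usage sketch. The classifier `net` and input batch `x` are illustrative placeholders invented for this example; only the `deepfool` call itself follows the signature introduced by the patch.

import torch
import torch.nn as nn

from cleverhans.future.torch.attacks.deepfool import deepfool

# Hypothetical stand-in classifier; any callable mapping inputs to logits works.
net = nn.Sequential(nn.Linear(2, 10))
x = torch.rand(8, 2)  # illustrative batch of 8 two-dimensional inputs in [0, 1]

# Untargeted DeepFool with an L2 budget of 0.5; adversarial examples are
# clipped back into the valid [0, 1] input range.
x_adv = deepfool(model_fn=net, x=x, eps=0.5, norm=2, clip_min=0., clip_max=1.)

labels = net(x).argmax(dim=1)
adv_labels = net(x_adv).argmax(dim=1)
print('fraction of labels flipped:', (labels != adv_labels).float().mean().item())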
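Note on the `clip_eta` rewrite in PATCH 1/7: replacing the manual per-norm arithmetic with `torch.renorm` both shortens the L2 case and adds the previously unimplemented L1 case (which the test update in PATCH 6/7 relies on). A small self-contained check of the property the rewrite depends on, namely that `renorm` with `dim=0` clips each example's sub-tensor norm independently:

import torch

eta = 10 * torch.randn(4, 3, 8, 8)  # perturbations for a batch of 4 examples
eps = 0.5

for p in (1, 2):
  # torch.renorm treats each slice along dim=0 as a sub-tensor and rescales
  # any slice whose p-norm exceeds maxnorm, i.e. a per-example norm clip.
  clipped = torch.renorm(eta, p=p, dim=0, maxnorm=eps)
  norms = clipped.view(clipped.size(0), -1).norm(p=p, dim=1)
  assert torch.all(norms <= eps + 1e-6)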
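Note on PATCH 4/7: for `norm=np.inf` the per-class update becomes `(|f_k| / ||w_k||_1 + 1e-4) * sign(w_k)` rather than the L2 update `(|f_k| / ||w_k||_2 + 1e-4) * w_k / ||w_k||_2`. This matches the L_p extension in the DeepFool paper: since the dual of the L_inf norm is the L1 norm, the smallest L_inf step that reaches the hyperplane w_k . r = |f_k| has magnitude |f_k| / ||w_k||_1 and moves along sign(w_k). A tiny numeric check with hypothetical values:

import torch

w = torch.tensor([1.5, -0.5, 2.0])  # hypothetical boundary normal (grads_diff)
f = torch.tensor(0.9)               # hypothetical logit margin (logits_margin)

r = (f.abs() / w.norm(p=1)) * torch.sign(w)  # L_inf-minimal step onto the boundary

assert torch.isclose(w @ r, f.abs())                        # reaches the hyperplane
assert torch.isclose(r.abs().max(), f.abs() / w.norm(p=1))  # with minimal L_inf norm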
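Note on PATCH 5/7: `Tensor.norm(p=..., dim=[...])` with a list of dims is replaced by an explicit power-sum because, per the commit message, the dim-list form needs a newer PyTorch than the CI had at the time. The two forms agree for the values of p used here:

import torch

g = torch.randn(4, 3, 5, 5)
dims = list(range(1, g.dim()))

for p in (1, 2):
  explicit = (g ** p).abs().sum(dim=dims) ** (1. / p)
  # On PyTorch versions that do accept a dim list, this matches Tensor.norm:
  assert torch.allclose(explicit, g.norm(p=p, dim=dims))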