Merge pull request #1141 from michaelshiyu/cwl2
[MRG] CWL2 in PyTorch
Showing 5 changed files with 359 additions and 44 deletions.
@@ -0,0 +1,209 @@
"""The CarliniWagnerL2 attack.""" | ||
import torch | ||
|
||
|
||
INF = float('inf') | ||
def carlini_wagner_l2(model_fn, x, n_classes, | ||
y=None, | ||
targeted=False, | ||
lr=5e-3, | ||
confidence=0, | ||
clip_min=0, | ||
clip_max=1, | ||
initial_const=1e-2, | ||
binary_search_steps=5, | ||
max_iterations=1000, | ||
): | ||
""" | ||
This attack was originally proposed by Carlini and Wagner. It is an | ||
iterative attack that finds adversarial examples on many defenses that | ||
are robust to other attacks. | ||
Paper link: https://arxiv.org/abs/1608.04644 | ||
At a high level, this attack is an iterative attack using Adam and | ||
a specially-chosen loss function to find adversarial examples with | ||
lower distortion than other attacks. This comes at the cost of speed, | ||
as this attack is often much slower than others. | ||
:param model_fn: a callable that takes an input tensor and returns | ||
the model logits. The logits should be a tensor of shape | ||
(n_examples, n_classes). | ||
:param x: input tensor of shape (n_examples, ...), where ... can | ||
be any arbitrary dimension that is compatible with | ||
model_fn. | ||
:param n_classes: the number of classes. | ||
:param y: (optional) Tensor with true labels. If targeted is true, | ||
then provide the target label. Otherwise, only provide | ||
this parameter if you'd like to use true labels when | ||
crafting adversarial samples. Otherwise, model predictions | ||
are used as labels to avoid the "label leaking" effect | ||
(explained in this paper: | ||
https://arxiv.org/abs/1611.01236). If provide y, it | ||
should be a 1D tensor of shape (n_examples, ). | ||
Default is None. | ||
:param targeted: (optional) bool. Is the attack targeted or | ||
untargeted? Untargeted, the default, will try to make the | ||
label incorrect. Targeted will instead try to move in the | ||
direction of being more like y. | ||
:param lr: (optional) float. The learning rate for the attack | ||
algorithm. Default is 5e-3. | ||
:param confidence: (optional) float. Confidence of adversarial | ||
examples: higher produces examples with larger l2 | ||
distortion, but more strongly classified as adversarial. | ||
Default is 0. | ||
:param clip_min: (optional) float. Minimum float value for | ||
adversarial example components. Default is 0. | ||
:param clip_max: (optional) float. Maximum float value for | ||
adversarial example components. Default is 1. | ||
:param initial_const: The initial tradeoff-constant to use to tune the | ||
relative importance of size of the perturbation and | ||
confidence of classification. If binary_search_steps is | ||
large, the initial constant is not important. A smaller | ||
value of this constant gives lower distortion results. | ||
Default is 1e-2. | ||
:param binary_search_steps: (optional) int. The number of times we | ||
perform binary search to find the optimal tradeoff-constant | ||
between norm of the perturbation and confidence of the | ||
classification. Default is 5. | ||
:param max_iterations: (optional) int. The maximum number of | ||
iterations. Setting this to a larger value will produce | ||
lower distortion results. Using only a few iterations | ||
requires a larger learning rate, and will produce larger | ||
distortion results. Default is 1000. | ||
""" | ||
    def compare(pred, label, is_logits=False):
        """
        A helper function to compare prediction against a label.
        Returns true if the attack is considered successful.

        :param pred: can be either a 1D tensor of logits or a predicted
                class (int).
        :param label: int. A label to compare against.
        :param is_logits: (optional) bool. If True, treat pred as an
                array of logits. Default is False.
        """

        # Convert logits to predicted class if necessary
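        # The logit of `label` is handicapped by `confidence`: for a targeted
        # attack it is reduced (the target must win by at least that margin),
        # for an untargeted attack it is boosted (some other class must beat
        # it by at least that margin).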
        if is_logits:
            pred_copy = pred.clone().detach()
            pred_copy[label] += (-confidence if targeted else confidence)
            pred = torch.argmax(pred_copy)

        return pred == label if targeted else pred != label

    if y is None:
        # Using model predictions as ground truth to avoid label leaking
        pred = model_fn(x)
        y = torch.argmax(pred, 1)

    # Initialize some values needed for binary search on const
    lower_bound = [0.] * len(x)
    upper_bound = [1e10] * len(x)
    const = x.new_ones(len(x), 1) * initial_const

    o_bestl2 = [INF] * len(x)
    o_bestscore = [-1.] * len(x)
    x = torch.clamp(x, clip_min, clip_max)
    ox = x.clone().detach()  # save the original x
    o_bestattack = x.clone().detach()

    # Map images into the tanh-space
    # TODO as of 01/06/2020, PyTorch does not natively support
    #  arctanh (see, e.g.,
    #  https://github.com/pytorch/pytorch/issues/10324).
    #  The implementation here is not numerically stable and should be
    #  replaced by PyTorch's native arctanh once it is released.
    arctanh = lambda x: 0.5 * torch.log((1 + x) / (1 - x))
    x = (x - clip_min) / (clip_max - clip_min)
    x = torch.clamp(x, 0, 1)
    x = x * 2 - 1
    x = arctanh(x * .999999)
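    # Optimizing over `modifier` in arctanh-space means that
    # (tanh(modifier + x) + 1) / 2 always lies in (0, 1), so the adversarial
    # example stays inside [clip_min, clip_max] without any explicit clipping.
    # The .999999 factor keeps the arctanh argument strictly inside (-1, 1).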

    # Prepare some variables
    modifier = torch.zeros_like(x, requires_grad=True)
    y_onehot = torch.nn.functional.one_hot(y, n_classes).to(torch.float)

    # Define loss functions and optimizer
    f_fn = lambda real, other, targeted: torch.max(
        ((other - real) if targeted else (real - other)) + confidence,
        torch.tensor(0.).to(real.device)
    )
    l2dist_fn = lambda x, y: torch.pow(x - y, 2).sum(list(range(len(x.size())))[1:])
    optimizer = torch.optim.Adam([modifier], lr=lr)
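    # f_fn corresponds (up to a constant offset) to the f_6 objective from the
    # paper: for a targeted attack it is max(other - real + confidence, 0),
    # i.e. zero once the target logit beats every other logit by `confidence`;
    # untargeted attacks swap the sign. l2dist_fn is the squared L2 distance
    # summed over all non-batch dimensions, giving one distance per example.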

    # Outer loop performing binary search on const
    for outer_step in range(binary_search_steps):
        # Initialize some values needed for the inner loop
        bestl2 = [INF] * len(x)
        bestscore = [-1.] * len(x)

        # Inner loop performing attack iterations
        for i in range(max_iterations):
            # One attack step
            new_x = (torch.tanh(modifier + x) + 1) / 2
            new_x = new_x * (clip_max - clip_min) + clip_min
            logits = model_fn(new_x)

            real = torch.sum(y_onehot * logits, 1)
            other, _ = torch.max((1 - y_onehot) * logits - y_onehot * 1e4, 1)
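            # `real` is the logit of the label class; `other` is the largest
            # logit among the remaining classes (the -1e4 term pushes the
            # label class out of the running for the max).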

            optimizer.zero_grad()
            f = f_fn(real, other, targeted)
            l2 = l2dist_fn(new_x, ox)
            # const has shape (n_examples, 1) while f and l2 have shape
            # (n_examples,); flatten const so each example's f term is scaled
            # by its own constant instead of broadcasting to a matrix
            loss = (const.squeeze(1) * f + l2).sum()
            loss.backward()
            optimizer.step()

            # Update best results
            for n, (l2_n, logits_n, new_x_n) in enumerate(zip(l2, logits, new_x)):
                y_n = y[n]
                succeeded = compare(logits_n, y_n, is_logits=True)
                if l2_n < o_bestl2[n] and succeeded:
                    pred_n = torch.argmax(logits_n)
                    o_bestl2[n] = l2_n
                    o_bestscore[n] = pred_n
                    o_bestattack[n] = new_x_n
                    # l2_n < o_bestl2[n] implies l2_n < bestl2[n],
                    # so we modify the inner loop variables too
                    bestl2[n] = l2_n
                    bestscore[n] = pred_n
                elif l2_n < bestl2[n] and succeeded:
                    bestl2[n] = l2_n
                    bestscore[n] = torch.argmax(logits_n)

        # Binary search step
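        # (Exponential search until a working const is found, then bisection:
        # a successful outer step records the current const as an upper bound
        # and moves toward smaller values, i.e. toward lower distortion; a
        # failed step raises the lower bound, multiplying const by 10 until
        # some upper bound is known.)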
        for n in range(len(x)):
            y_n = y[n]

            if compare(bestscore[n], y_n) and bestscore[n] != -1:
                # Success: tighten the upper bound and bisect toward a smaller
                # const (equivalent to halving while lower_bound is still 0)
                upper_bound[n] = min(upper_bound[n], const[n])
                if upper_bound[n] < 1e9:
                    const[n] = (lower_bound[n] + upper_bound[n]) / 2
            else:
                # Failure: either multiply const by 10 if no upper bound has
                # been found yet, or bisect using the known upper bound
                lower_bound[n] = max(lower_bound[n], const[n])
                if upper_bound[n] < 1e9:
                    const[n] = (lower_bound[n] + upper_bound[n]) / 2
                else:
                    const[n] *= 10

    return o_bestattack.detach()


if __name__ == '__main__':
    x = torch.clamp(torch.randn(5, 10), 0, 1)
    y = torch.randint(0, 9, (5,))
    model_fn = lambda x: x

    # targeted
    new_x = carlini_wagner_l2(model_fn, x, 10, targeted=True, y=y)
    new_pred = model_fn(new_x)
    new_pred = torch.argmax(new_pred, 1)

    # untargeted
    new_x_untargeted = carlini_wagner_l2(model_fn, x, 10, targeted=False, y=y)
    new_pred_untargeted = model_fn(new_x_untargeted)
    new_pred_untargeted = torch.argmax(new_pred_untargeted, 1)
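
# ----------------------------------------------------------------------------
# A minimal usage sketch against a (toy) classifier. This is illustrative
# only and not part of the diff above; the model, shapes, and hyperparameters
# below are assumptions, not values taken from the repository's tests.
#
#   import torch.nn as nn
#
#   model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # untrained toy net
#   images = torch.rand(8, 1, 28, 28)        # batch of inputs already in [0, 1]
#   labels = torch.randint(0, 10, (8,))
#
#   adv = carlini_wagner_l2(model, images, n_classes=10, y=labels,
#                           binary_search_steps=3, max_iterations=200)
#   fooled = (torch.argmax(model(adv), 1) != labels).float().mean()
#   print("untargeted success rate:", fooled.item())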