Merge pull request #1141 from michaelshiyu/cwl2
[MRG] CWL2 in PyTorch
jonasguan authored Jan 25, 2021
2 parents 68d6bab + d6417be commit 50a579b
Showing 5 changed files with 359 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -35,7 +35,7 @@ install:

- time pip install -q -e ".[test]"
- if [[ "$PYTORCH" == True ]]; then
-      pip install torch==1.0.1.post2 torchvision==0.2.2 -q;
+      pip install torch==1.1.0 torchvision==0.3.0 -q;
fi
# workaround for version incompatibility between the scipy version in conda
# and the system-provided /usr/lib/x86_64-linux-gnu/libstdc++.so.6
1 change: 1 addition & 0 deletions cleverhans/future/torch/attacks/__init__.py
@@ -5,3 +5,4 @@
from cleverhans.future.torch.attacks.semantic import semantic
from cleverhans.future.torch.attacks.spsa import spsa
from cleverhans.future.torch.attacks.hop_skip_jump_attack import hop_skip_jump_attack
+from cleverhans.future.torch.attacks.carlini_wagner_l2 import carlini_wagner_l2
209 changes: 209 additions & 0 deletions cleverhans/future/torch/attacks/carlini_wagner_l2.py
@@ -0,0 +1,209 @@
"""The CarliniWagnerL2 attack."""
import torch


INF = float('inf')


def carlini_wagner_l2(model_fn, x, n_classes, y=None, targeted=False,
                      lr=5e-3, confidence=0, clip_min=0, clip_max=1,
                      initial_const=1e-2, binary_search_steps=5,
                      max_iterations=1000):
"""
This attack was originally proposed by Carlini and Wagner. It is an
iterative attack that finds adversarial examples on many defenses that
are robust to other attacks.
Paper link: https://arxiv.org/abs/1608.04644
At a high level, this attack is an iterative attack using Adam and
a specially-chosen loss function to find adversarial examples with
lower distortion than other attacks. This comes at the cost of speed,
as this attack is often much slower than others.
:param model_fn: a callable that takes an input tensor and returns
the model logits. The logits should be a tensor of shape
(n_examples, n_classes).
:param x: input tensor of shape (n_examples, ...), where ... can
be any arbitrary dimension that is compatible with
model_fn.
:param n_classes: the number of classes.
:param y: (optional) Tensor with true labels. If targeted is true,
then provide the target label. Otherwise, only provide
this parameter if you'd like to use true labels when
crafting adversarial samples. Otherwise, model predictions
are used as labels to avoid the "label leaking" effect
(explained in this paper:
https://arxiv.org/abs/1611.01236). If provide y, it
should be a 1D tensor of shape (n_examples, ).
Default is None.
:param targeted: (optional) bool. Is the attack targeted or
untargeted? Untargeted, the default, will try to make the
label incorrect. Targeted will instead try to move in the
direction of being more like y.
:param lr: (optional) float. The learning rate for the attack
algorithm. Default is 5e-3.
:param confidence: (optional) float. Confidence of adversarial
examples: higher produces examples with larger l2
distortion, but more strongly classified as adversarial.
Default is 0.
:param clip_min: (optional) float. Minimum float value for
adversarial example components. Default is 0.
:param clip_max: (optional) float. Maximum float value for
adversarial example components. Default is 1.
:param initial_const: The initial tradeoff-constant to use to tune the
relative importance of size of the perturbation and
confidence of classification. If binary_search_steps is
large, the initial constant is not important. A smaller
value of this constant gives lower distortion results.
Default is 1e-2.
:param binary_search_steps: (optional) int. The number of times we
perform binary search to find the optimal tradeoff-constant
between norm of the perturbation and confidence of the
classification. Default is 5.
:param max_iterations: (optional) int. The maximum number of
iterations. Setting this to a larger value will produce
lower distortion results. Using only a few iterations
requires a larger learning rate, and will produce larger
distortion results. Default is 1000.
"""
    def compare(pred, label, is_logits=False):
        """
        A helper function to compare prediction against a label.
        Returns true if the attack is considered successful.

        :param pred: can be either a 1D tensor of logits or a predicted
                  class (int).
        :param label: int. A label to compare against.
        :param is_logits: (optional) bool. If True, treat pred as an
                  array of logits. Default is False.
        """

        # Convert logits to predicted class if necessary
        if is_logits:
            pred_copy = pred.clone().detach()
            pred_copy[label] += (-confidence if targeted else confidence)
            pred = torch.argmax(pred_copy)
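            # Shifting the attacked class's logit by `confidence` before the
            # argmax means a prediction only counts as adversarial when it
            # wins (targeted) or loses (untargeted) by at least that margin.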

        return pred == label if targeted else pred != label


    if y is None:
        # Using model predictions as ground truth to avoid label leaking
        pred = model_fn(x)
        y = torch.argmax(pred, 1)

    # Initialize some values needed for binary search on const
    lower_bound = [0.] * len(x)
    upper_bound = [1e10] * len(x)
    # const must have shape (n_examples,) so that const * f in the loss
    # below pairs each example with its own tradeoff-constant (a trailing
    # dimension of 1 would broadcast const * f to shape (n, n))
    const = x.new_ones(len(x)) * initial_const
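    # lower_bound and upper_bound bracket each example's const during the
    # outer binary search; an upper bound of 1e10 means no successful
    # attack has been found for that example yet.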

    o_bestl2 = [INF] * len(x)
    o_bestscore = [-1.] * len(x)
    x = torch.clamp(x, clip_min, clip_max)
    ox = x.clone().detach()  # save the original x
    o_bestattack = x.clone().detach()

    # Map images into the tanh-space
    # TODO as of 01/06/2020, PyTorch does not natively support
    # arctanh (see, e.g.,
    # https://github.com/pytorch/pytorch/issues/10324).
    # This particular implementation here is not numerically
    # stable and should be substituted w/ PyTorch's native
    # implementation when it comes out in the future.
    arctanh = lambda x: 0.5 * torch.log((1 + x) / (1 - x))
    x = (x - clip_min) / (clip_max - clip_min)
    x = torch.clamp(x, 0, 1)
    x = x * 2 - 1
    x = arctanh(x * .999999)
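    # From here on, the optimizer works on an unconstrained tanh-space
    # variable: the (tanh + 1) / 2 mapping in the attack loop below folds
    # every iterate back into [clip_min, clip_max], so the box constraint
    # holds without any projection or clipping step.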

    # Prepare some variables
    modifier = torch.zeros_like(x, requires_grad=True)
    y_onehot = torch.nn.functional.one_hot(y, n_classes).to(torch.float)

    # Define loss functions and optimizer
    f_fn = lambda real, other, targeted: torch.max(
        ((other - real) if targeted else (real - other)) + confidence,
        torch.tensor(0.).to(real.device)
    )
    l2dist_fn = lambda x, y: torch.pow(x - y, 2).sum(list(range(len(x.size())))[1:])
    optimizer = torch.optim.Adam([modifier], lr=lr)
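    # The per-example objective is l2dist(new_x, ox) + const * f, where f is
    # the hinge on the logit margin from the paper: it is zero exactly when
    # the attack already succeeds with margin `confidence`, after which the
    # optimizer only shrinks the distortion.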

    # Outer loop performing binary search on const
    for outer_step in range(binary_search_steps):
        # Initialize some values needed for the inner loop
        bestl2 = [INF] * len(x)
        bestscore = [-1.] * len(x)

        # Inner loop performing attack iterations
        for i in range(max_iterations):
            # One attack step
            new_x = (torch.tanh(modifier + x) + 1) / 2
            new_x = new_x * (clip_max - clip_min) + clip_min
            logits = model_fn(new_x)

            real = torch.sum(y_onehot * logits, 1)
            other, _ = torch.max((1 - y_onehot) * logits - y_onehot * 1e4, 1)

            optimizer.zero_grad()
            f = f_fn(real, other, targeted)
            l2 = l2dist_fn(new_x, ox)
            loss = (const * f + l2).sum()
            loss.backward()
            optimizer.step()

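            # Two tiers of bookkeeping: the o_best* variables track the best
            # attack found across all binary-search steps, while bestl2 and
            # bestscore are reset per const value and drive its update below.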
            # Update best results
            for n, (l2_n, logits_n, new_x_n) in enumerate(zip(l2, logits, new_x)):
                y_n = y[n]
                succeeded = compare(logits_n, y_n, is_logits=True)
                if l2_n < o_bestl2[n] and succeeded:
                    pred_n = torch.argmax(logits_n)
                    o_bestl2[n] = l2_n
                    o_bestscore[n] = pred_n
                    o_bestattack[n] = new_x_n
                    # l2_n < o_bestl2[n] implies l2_n < bestl2[n], so we
                    # update the inner loop variables too
                    bestl2[n] = l2_n
                    bestscore[n] = pred_n
                elif l2_n < bestl2[n] and succeeded:
                    bestl2[n] = l2_n
                    bestscore[n] = torch.argmax(logits_n)

        # Binary search step
        for n in range(len(x)):
            y_n = y[n]

            if compare(bestscore[n], y_n) and bestscore[n] != -1:
                # Success: tighten the upper bound and search downward
                upper_bound[n] = min(upper_bound[n], const[n])
                if upper_bound[n] < 1e9:
                    const[n] = (lower_bound[n] + upper_bound[n]) / 2
            else:
                # Failure: either multiply const by 10 if no solution has
                # been found yet, or binary search with the known upper bound
                lower_bound[n] = max(lower_bound[n], const[n])
                if upper_bound[n] < 1e9:
                    const[n] = (lower_bound[n] + upper_bound[n]) / 2
                else:
                    const[n] *= 10

    return o_bestattack.detach()


if __name__ == '__main__':
    x = torch.clamp(torch.randn(5, 10), 0, 1)
    y = torch.randint(0, 9, (5,))
    model_fn = lambda x: x

    # targeted
    new_x = carlini_wagner_l2(model_fn, x, 10, targeted=True, y=y)
    new_pred = model_fn(new_x)
    new_pred = torch.argmax(new_pred, 1)

    # untargeted
    new_x_untargeted = carlini_wagner_l2(model_fn, x, 10, targeted=False, y=y)
    new_pred_untargeted = model_fn(new_x_untargeted)
    new_pred_untargeted = torch.argmax(new_pred_untargeted, 1)
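For context, here is a minimal sketch of how the new attack might be driven from user code once this commit is installed. It is not part of the commit; the toy linear classifier, input shapes, and accuracy check are illustrative assumptions.

import torch

from cleverhans.future.torch.attacks.carlini_wagner_l2 import carlini_wagner_l2

# Hypothetical classifier: any callable returning logits of shape
# (n_examples, n_classes) can serve as model_fn.
model = torch.nn.Linear(784, 10)
model.eval()

x = torch.rand(8, 784)          # batch of flattened images in [0, 1]
y = torch.randint(0, 10, (8,))  # true labels (optional, avoids label leaking)

# Untargeted attack with the default binary search schedule
x_adv = carlini_wagner_l2(model, x, 10, y=y, targeted=False)

with torch.no_grad():
    acc = (torch.argmax(model(x_adv), 1) == y).float().mean().item()
print('accuracy on adversarial examples: {:.2%}'.format(acc))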