Merge pull request #1149 from CNOCycle/tf2/pgd_attack
[TF2] PGD attack
alkaet authored Jan 25, 2021
2 parents f5adfdd + ea5c3a3 commit 68d6bab
Showing 4 changed files with 106 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -322,6 +322,7 @@ The following authors contributed 100 lines or more (ordered according to the Gi
* Bogdan Kulynych (EPFL)
* Erfan Noury (UMBC)
* Robert Wagner (Case Western Reserve University)
* Erh-Chung Chen (National Tsing Hua University)

## Copyright

14 changes: 10 additions & 4 deletions cleverhans/future/tf2/attacks/fast_gradient_method.py
@@ -4,14 +4,16 @@
import tensorflow as tf


def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y=None,
def fast_gradient_method(model_fn, x, eps, norm, loss_fn=None, clip_min=None, clip_max=None, y=None,
targeted=False, sanity_checks=False):
"""
Tensorflow 2.0 implementation of the Fast Gradient Method.
:param model_fn: a callable that takes an input tensor and returns the model logits.
:param x: input tensor.
:param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
:param norm: Order of the norm (mimics NumPy). Possible values: np.inf, 1 or 2.
:param loss_fn: (optional) callable. Loss function that takes (labels, logits) as arguments and returns loss.
Defaults to 'tf.nn.sparse_softmax_cross_entropy_with_logits'.
:param clip_min: (optional) float. Minimum float value for adversarial example components.
:param clip_max: (optional) float. Maximum float value for adversarial example components.
:param y: (optional) Tensor with true labels. If targeted is true, then provide the
@@ -29,6 +31,9 @@ def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y
if norm not in [np.inf, 1, 2]:
raise ValueError("Norm order must be either np.inf, 1, or 2.")

if loss_fn is None:
loss_fn = tf.nn.sparse_softmax_cross_entropy_with_logits

asserts = []

# If a data range was specified, check that the input was in that range
@@ -42,7 +47,7 @@ def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y
# Using model predictions as ground truth to avoid label leaking
y = tf.argmax(model_fn(x), 1)

grad = compute_gradient(model_fn, x, y, targeted)
grad = compute_gradient(model_fn, loss_fn, x, y, targeted)

optimal_perturbation = optimize_linear(grad, eps, norm)
# Add perturbation to original example to obtain adversarial example
@@ -63,18 +68,19 @@ def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y
# Not using the decorator here, or letting the user wrap the attack in tf.function is way
# slower on Tensorflow 2.0.0-alpha0.
@tf.function
def compute_gradient(model_fn, x, y, targeted):
def compute_gradient(model_fn, loss_fn, x, y, targeted):
"""
Computes the gradient of the loss with respect to the input tensor.
:param model_fn: a callable that takes an input tensor and returns the model logits.
:param loss_fn: loss function that takes (labels, logits) as arguments and returns loss.
:param x: input tensor
:param y: Tensor with true labels. If targeted is true, then provide the target label.
:param targeted: bool. Is the attack targeted or untargeted? Untargeted, the default, will
try to make the label incorrect. Targeted will instead try to move in the
direction of being more like y.
:return: A tensor containing the gradient of the loss with respect to the input tensor.
"""
loss_fn = tf.nn.sparse_softmax_cross_entropy_with_logits

with tf.GradientTape() as g:
g.watch(x)
# Compute loss
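For reference, a minimal usage sketch of the new loss_fn parameter. The model, data, and my_loss below are illustrative placeholders, not part of this diff:

import numpy as np
import tensorflow as tf
from cleverhans.future.tf2.attacks.fast_gradient_method import fast_gradient_method

# Placeholder classifier and batch; any callable returning logits works as model_fn.
model = tf.keras.Sequential([tf.keras.layers.Flatten(), tf.keras.layers.Dense(10)])
x_batch = tf.random.uniform((8, 28, 28, 1))                  # dummy images in [0, 1]
y_batch = tf.random.uniform((8,), maxval=10, dtype=tf.int64) # dummy labels

# Default loss (sparse softmax cross-entropy), L-inf norm.
x_adv = fast_gradient_method(model, x_batch, eps=0.1, norm=np.inf,
                             clip_min=0., clip_max=1., y=y_batch)

# Same attack with an explicit loss_fn; it must accept (labels, logits).
def my_loss(labels, logits):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

x_adv_custom = fast_gradient_method(model, x_batch, eps=0.1, norm=np.inf,
                                    loss_fn=my_loss,
                                    clip_min=0., clip_max=1., y=y_batch)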
24 changes: 18 additions & 6 deletions cleverhans/future/tf2/attacks/projected_gradient_descent.py
@@ -4,12 +4,12 @@
import tensorflow as tf

from cleverhans.future.tf2.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.future.tf2.utils_tf import clip_eta
from cleverhans.future.tf2.utils_tf import clip_eta, random_lp_vector


def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm,
def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm, loss_fn=None,
clip_min=None, clip_max=None, y=None, targeted=False,
rand_init=None, rand_minmax=0.3, sanity_checks=True):
rand_init=None, rand_minmax=None, sanity_checks=False):
"""
This class implements either the Basic Iterative Method
(Kurakin et al. 2016) when rand_init is set to 0. or the
@@ -22,6 +22,8 @@ def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm,
:param eps_iter: step size for each attack iteration
:param nb_iter: Number of attack iterations.
:param norm: Order of the norm (mimics NumPy). Possible values: np.inf, 1 or 2.
:param loss_fn: (optional) callable. Loss function that takes (labels, logits) as arguments and returns loss.
Defaults to 'tf.nn.sparse_softmax_cross_entropy_with_logits'.
:param clip_min: (optional) float. Minimum float value for adversarial example components.
:param clip_max: (optional) float. Maximum float value for adversarial example components.
:param y: (optional) Tensor with true labels. If targeted is true, then provide the
@@ -32,6 +34,11 @@ def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm,
:param targeted: (optional) bool. Is the attack targeted or untargeted?
Untargeted, the default, will try to make the label incorrect.
Targeted will instead try to move in the direction of being more like y.
:param rand_init: (optional) bool. Whether to start the gradient descent from a point chosen
uniformly at random in the norm ball of radius
rand_minmax.
:param rand_minmax: (optional) float. Size of the norm ball from which
the initial starting point is chosen. Defaults to eps.
:param sanity_checks: bool, if True, include asserts (Turn them off to use less runtime /
memory or for unit tests that intentionally pass strange input)
:return: a tensor for the adversarial example
@@ -47,6 +54,9 @@ def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm,
if norm not in [np.inf, 2]:
raise ValueError("Norm order must be either np.inf or 2.")

if loss_fn is None:
loss_fn = tf.nn.sparse_softmax_cross_entropy_with_logits

asserts = []

# If a data range was specified, check that the input was in that range
@@ -57,9 +67,11 @@ def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm,
asserts.append(tf.math.less_equal(x, clip_max))

# Initialize loop variables
if rand_minmax is None:
rand_minmax = eps

if rand_init:
rand_minmax = eps
eta = tf.random.uniform(x.shape, -rand_minmax, rand_minmax)
eta = random_lp_vector(tf.shape(x), norm, tf.cast(rand_minmax, x.dtype), dtype=x.dtype)
else:
eta = tf.zeros_like(x)

@@ -75,7 +87,7 @@ def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm,

i = 0
while i < nb_iter:
adv_x = fast_gradient_method(model_fn, adv_x, eps_iter, norm, clip_min=clip_min,
adv_x = fast_gradient_method(model_fn, adv_x, eps_iter, norm, loss_fn, clip_min=clip_min,
clip_max=clip_max, y=y, targeted=targeted)

# Clipping perturbation eta to norm norm ball
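A similar hedged sketch for the updated PGD call; the model and data are again placeholders. With rand_init=True, the starting point is drawn from the norm ball via the new random_lp_vector helper, and rand_minmax defaults to eps when left as None:

import numpy as np
import tensorflow as tf
from cleverhans.future.tf2.attacks.projected_gradient_descent import projected_gradient_descent

model = tf.keras.Sequential([tf.keras.layers.Flatten(), tf.keras.layers.Dense(10)])  # placeholder classifier
x_batch = tf.random.uniform((8, 28, 28, 1))                                          # dummy images in [0, 1]
y_batch = tf.random.uniform((8,), maxval=10, dtype=tf.int64)                         # dummy labels

# 40-step L-inf PGD with a random start inside the eps ball.
x_adv_pgd = projected_gradient_descent(model, x_batch, eps=0.3, eps_iter=0.01,
                                       nb_iter=40, norm=np.inf,
                                       clip_min=0., clip_max=1., y=y_batch,
                                       rand_init=True)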
77 changes: 77 additions & 0 deletions cleverhans/future/tf2/utils_tf.py
@@ -31,3 +31,80 @@ def clip_eta(eta, norm, eps):
factor = tf.minimum(1., tf.math.divide(eps, norm))
eta = eta * factor
return eta


def random_exponential(shape, rate=1.0, dtype=tf.float32, seed=None):
"""
Helper function to sample from the exponential distribution, which is not
included in core TensorFlow.
:shape: shape of the sampled tensor.
:rate: (optional) rate parameter of the exponential distribution, defaults to 1.0.
:dtype: (optional) data type of the sampled tensor, defaults to tf.float32.
:seed: (optional) custom seed to be used for sampling.
"""
return tf.random.gamma(shape, alpha=1, beta=1. / rate, dtype=dtype, seed=seed)


def random_laplace(shape, loc=0.0, scale=1.0, dtype=tf.float32, seed=None):
"""
Helper function to sample from the Laplace distribution, which is not
included in core TensorFlow.
:shape: shape of the sampled tensor.
:loc: (optional) mean of the Laplace distribution, defaults to 0.0.
:scale: (optional) scale parameter of the Laplace distribution, defaults to 1.0.
:dtype: (optional) data type of the sampled tensor, defaults to tf.float32.
:seed: (optional) custom seed to be used for sampling.
"""
z1 = random_exponential(shape, 1. / scale, dtype=dtype, seed=seed)
z2 = random_exponential(shape, 1. / scale, dtype=dtype, seed=seed)
return z1 - z2 + loc

def random_lp_vector(shape, ord, eps, dtype=tf.float32, seed=None):
"""
Helper function to generate uniformly random vectors from a norm ball of
radius epsilon.
:param shape: Output shape of the random sample. The shape is expected to be
of the form `(n, d1, d2, ..., dn)` where `n` is the number of
i.i.d. samples that will be drawn from a norm ball of dimension
`d1*d2*...*dn`.
:param ord: Order of the norm (mimics Numpy).
Possible values: np.inf, 1 or 2.
:param eps: Epsilon, radius of the norm ball.
:param dtype: (optional) type of the tensor.
:param seed: (optional) integer.
"""
if ord not in [np.inf, 1, 2]:
raise ValueError('ord must be np.inf, 1, or 2.')

if ord == np.inf:
r = tf.random.uniform(shape, -eps, eps, dtype=dtype, seed=seed)
else:

# For ord=1 and ord=2, we use the generic technique from
# (Calafiore et al. 1998) to sample uniformly from a norm ball.
# Paper link (Calafiore et al. 1998):
# https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=758215&tag=1
# We first sample from the surface of the norm ball, and then scale by
# a factor `w^(1/d)` where `w~U[0,1]` is a standard uniform random variable
# and `d` is the dimension of the ball. In high dimensions, this is roughly
# equivalent to sampling from the surface of the ball.

dim = tf.reduce_prod(shape[1:])

if ord == 1:
x = random_laplace((shape[0], dim), loc=1.0, scale=1.0, dtype=dtype,
seed=seed)
norm = tf.reduce_sum(tf.abs(x), axis=-1, keepdims=True)
elif ord == 2:
x = tf.random.normal((shape[0], dim), dtype=dtype, seed=seed)
norm = tf.sqrt(tf.reduce_sum(tf.square(x), axis=-1, keepdims=True))
else:
raise ValueError('ord must be np.inf, 1, or 2.')

w = tf.pow(tf.random.uniform((shape[0], 1), dtype=dtype, seed=seed),
1.0 / tf.cast(dim, dtype))
r = eps * tf.reshape(w * x / norm, shape)

return r
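
A short sketch of how the new helper might be exercised on its own; the shape and eps values are illustrative assumptions. Because each sample is scaled by a factor in [0, 1], every perturbation should lie inside the ball of radius eps:

import numpy as np
import tensorflow as tf
from cleverhans.future.tf2.utils_tf import random_lp_vector

eps = 0.5
samples = random_lp_vector((16, 3, 32, 32), ord=2, eps=eps)  # 16 perturbations for 3x32x32 inputs

# Check that all L2 norms stay within the radius.
flat = tf.reshape(samples, (16, -1))
norms = tf.norm(flat, ord=2, axis=-1)
print(float(tf.reduce_max(norms)) <= eps + 1e-6)  # expected: True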
