Fix bugs in SPSA early stopping (#906)
* tweak doc

* add to documentation of early_stop_loss_threshold

* more detail to doc

* fix bugs in early stopping

* Update cleverhans/attacks_tf.py

Co-Authored-By: goodfeli <[email protected]>
goodfeli authored Nov 27, 2018
1 parent bfae77b commit 4429d37
Showing 2 changed files with 31 additions and 15 deletions.
42 changes: 29 additions & 13 deletions cleverhans/attacks/__init__.py
@@ -2593,9 +2593,17 @@ def projected_optimization(loss_fn,
     :param project_perturbation: A function, which will be used to enforce
                                  some constraint. It should have the same
                                  signature as `_project_perturbation`.
-    :param early_stop_loss_threshold: A float or None. If specified, the
-                                      attack will end if the loss is below
-                                      `early_stop_loss_threshold`.
+    :param early_stop_loss_threshold: A float or None. If specified, the attack will end if the loss is below
+      `early_stop_loss_threshold`.
+      Enabling this option can have several different effects:
+        - Setting the threshold to 0. guarantees that if a successful attack is found, it is returned.
+          This increases the attack success rate, because without early stopping the optimizer can accidentally
+          bounce back to a point where the attack fails.
+        - Early stopping can make the attack run faster because it may run for fewer steps.
+        - Early stopping can make the attack run slower because the loss must be calculated at each step.
+          The loss is not calculated as part of the normal SPSA optimization procedure.
+          For most reasonable choices of hyperparameters, early stopping makes the attack much faster because
+          it decreases the number of steps dramatically.
     :param is_debug: A bool. If True, print debug info for attack progress.

   Returns:
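For context, here is a minimal usage sketch of the threshold through the SPSA attack wrapper (TF1-era API). The `model` object and the exact `generate` parameter names other than `early_stop_loss_threshold` are assumptions for illustration, not taken from this diff:

```python
import tensorflow as tf
from cleverhans.attacks import SPSA

# `model` is assumed to be a cleverhans.model.Model wrapping the classifier.
sess = tf.Session()
attack = SPSA(model, sess=sess)

x = tf.placeholder(tf.float32, shape=[1, 28, 28, 1])
y = tf.placeholder(tf.int64, shape=[1])

x_adv = attack.generate(
    x, y=y,
    epsilon=0.05,
    num_steps=100,
    # A threshold of 0. stops as soon as any successful attack is found,
    # before the optimizer can bounce back to a point where it fails.
    early_stop_loss_threshold=0.,
    is_debug=True)
```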
@@ -2635,20 +2643,28 @@ def wrapped_loss_fn(x):

     new_perturbation_list, new_optim_state = optimizer.minimize(
         wrapped_loss_fn, [perturbation], optim_state)
-    loss = reduce_mean(wrapped_loss_fn(perturbation), axis=0)
-    if is_debug:
-      with tf.device("/cpu:0"):
-        loss = tf.Print(loss, [loss], "Total batch loss")
     projected_perturbation = project_perturbation(new_perturbation_list[0],
                                                   epsilon, input_image,
                                                   clip_min=clip_min,
                                                   clip_max=clip_max)
-    with tf.control_dependencies([loss]):
-      i = tf.identity(i)
-      if early_stop_loss_threshold:
-        i = tf.cond(
-            tf.less(loss, early_stop_loss_threshold),
-            lambda: float(num_steps), lambda: i)
+
+    # Be careful with this bool. A value of 0. is a valid threshold but evaluates to False, so we must explicitly
+    # check whether the value is None.
+    early_stop = early_stop_loss_threshold is not None
+    compute_loss = is_debug or early_stop
+    # Don't waste time building the loss graph if we're not going to use it
+    if compute_loss:
+      # NOTE: this step is not actually redundant with the optimizer step.
+      # SPSA calculates the loss at randomly perturbed points but doesn't calculate the loss at the current point.
+      loss = reduce_mean(wrapped_loss_fn(projected_perturbation), axis=0)
+
+      if is_debug:
+        with tf.device("/cpu:0"):
+          loss = tf.Print(loss, [loss], "Total batch loss")
+
+      if early_stop:
+        i = tf.cond(tf.less(loss, early_stop_loss_threshold), lambda: float(num_steps), lambda: i)
+
     return i + 1, projected_perturbation, nest.flatten(new_optim_state)

   def cond(i, *_):
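The `is not None` guard is the heart of the fix. A standalone sketch of why truthiness is the wrong test for this parameter (plain Python, no TensorFlow required):

```python
# 0. is a legitimate threshold but is falsy, so the old guard
# `if early_stop_loss_threshold:` silently disabled early stopping
# exactly when the caller asked for the strictest setting.
for threshold in (None, 0., 0.5):
    old_guard = bool(threshold)        # old, buggy check
    new_guard = threshold is not None  # fixed check
    print(threshold, old_guard, new_guard)
# None -> False False   (early stopping off, as intended)
# 0.0  -> False True    (old code wrongly skipped early stopping)
# 0.5  -> True  True
```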
4 changes: 2 additions & 2 deletions cleverhans/attacks_tf.py
@@ -1485,8 +1485,8 @@ class TensorOptimizer(object):
   behaviors when being assigned multiple times within a single sess.run()
   call, particularly in Distributed TF, so this avoids thinking about those
   issues. These are helper classes for the `projected_optimization`
-  method. Apart from not using Variables, they follow the same interface as
-  tf.Optimizer.
+  method. Apart from not using Variables, they follow an interface very
+  similar to tf.Optimizer.
   """

   def _compute_gradients(self, loss_fn, x, unused_optim_state):
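To illustrate the interface this docstring describes, here is a hypothetical TensorOptimizer-style class; only the `minimize(loss_fn, x_list, optim_state)` signature is taken from the diff above, and the rest is a sketch of the stateless-optimizer pattern, not the library's implementation:

```python
import tensorflow as tf

class TensorGradientDescentSketch(object):
  """State is threaded through minimize() as tensors rather than stored
  in tf.Variables, so the optimizer can live inside a tf.while_loop like
  the one in projected_optimization."""

  def __init__(self, lr):
    self._lr = lr

  def init_state(self, x_list):
    # Plain gradient descent keeps no running state.
    return []

  def minimize(self, loss_fn, x_list, optim_state):
    loss = tf.reduce_mean(loss_fn(x_list[0]))
    grads = tf.gradients(loss, x_list)
    new_x_list = [x - self._lr * g for x, g in zip(x_list, grads)]
    return new_x_list, optim_state
```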
