From 4429d372d0586808dd651fc336b29cc405cfb09c Mon Sep 17 00:00:00 2001
From: Ian Goodfellow
Date: Tue, 27 Nov 2018 06:29:05 -0800
Subject: [PATCH] Fix bugs in SPSA early stopping (#906)

* tweak doc

* add to documentation of early_stop_loss_threshold

* more detail to doc

* fix bugs in early stopping

* Update cleverhans/attacks_tf.py

Co-Authored-By: goodfeli
---
 cleverhans/attacks/__init__.py | 42 +++++++++++++++++++++++-----------
 cleverhans/attacks_tf.py       |  4 ++--
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/cleverhans/attacks/__init__.py b/cleverhans/attacks/__init__.py
index 11b53c724..cd7faa055 100644
--- a/cleverhans/attacks/__init__.py
+++ b/cleverhans/attacks/__init__.py
@@ -2593,9 +2593,17 @@ def projected_optimization(loss_fn,
   :param project_perturbation: A function, which will be used to enforce
                                some constraint. It should have the same
                                signature as `_project_perturbation`.
-  :param early_stop_loss_threshold: A float or None. If specified, the
-                                    attack will end if the loss is below
-                                    `early_stop_loss_threshold`.
+  :param early_stop_loss_threshold: A float or None. If specified, the attack will end if the loss is below
+                                    `early_stop_loss_threshold`.
+                                    Enabling this option can have several different effects:
+                                      - Setting the threshold to 0. guarantees that if a successful attack is found, it is returned.
+                                        This increases the attack success rate, because without early stopping the optimizer can accidentally
+                                        bounce back to a point where the attack fails.
+                                      - Early stopping can make the attack run faster because it may run for fewer steps.
+                                      - Early stopping can make the attack run slower because the loss must be calculated at each step.
+                                        The loss is not calculated as part of the normal SPSA optimization procedure.
+                                        For most reasonable choices of hyperparameters, early stopping makes the attack much faster because
+                                        it decreases the number of steps dramatically.
   :param is_debug: A bool. If True, print debug info for attack progress.
 
   Returns:
@@ -2635,20 +2643,28 @@ def wrapped_loss_fn(x):
     new_perturbation_list, new_optim_state = optimizer.minimize(
         wrapped_loss_fn, [perturbation], optim_state)
-    loss = reduce_mean(wrapped_loss_fn(perturbation), axis=0)
-    if is_debug:
-      with tf.device("/cpu:0"):
-        loss = tf.Print(loss, [loss], "Total batch loss")
     projected_perturbation = project_perturbation(new_perturbation_list[0],
                                                    epsilon, input_image,
                                                    clip_min=clip_min,
                                                    clip_max=clip_max)
-    with tf.control_dependencies([loss]):
-      i = tf.identity(i)
-      if early_stop_loss_threshold:
-        i = tf.cond(
-            tf.less(loss, early_stop_loss_threshold),
-            lambda: float(num_steps), lambda: i)
+
+    # Be careful with this bool. A value of 0. is a valid threshold but evaluates to False, so we must explicitly
+    # check whether the value is None.
+    early_stop = early_stop_loss_threshold is not None
+    compute_loss = is_debug or early_stop
+    # Don't waste time building the loss graph if we're not going to use it
+    if compute_loss:
+      # NOTE: this step is not actually redundant with the optimizer step.
+      # SPSA calculates the loss at randomly perturbed points but doesn't calculate the loss at the current point.
+      loss = reduce_mean(wrapped_loss_fn(projected_perturbation), axis=0)
+
+      if is_debug:
+        with tf.device("/cpu:0"):
+          loss = tf.Print(loss, [loss], "Total batch loss")
+
+      if early_stop:
+        i = tf.cond(tf.less(loss, early_stop_loss_threshold), lambda: float(num_steps), lambda: i)
+
 
     return i + 1, projected_perturbation, nest.flatten(new_optim_state)
 
   def cond(i, *_):
diff --git a/cleverhans/attacks_tf.py b/cleverhans/attacks_tf.py
index b41c90550..a7a7b5aed 100644
--- a/cleverhans/attacks_tf.py
+++ b/cleverhans/attacks_tf.py
@@ -1485,8 +1485,8 @@ class TensorOptimizer(object):
   behaviors when being assigned multiple times within a single sess.run()
   call, particularly in Distributed TF, so this avoids thinking about those
   issues. These are helper classes for the `projected_optimization`
-  method. Apart from not using Variables, they follow the same interface as
-  tf.Optimizer.
+  method. Apart from not using Variables, they follow an interface very
+  similar to tf.Optimizer.
   """
 
   def _compute_gradients(self, loss_fn, x, unused_optim_state):
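
Note on the fix: the early-stopping test has two subtleties. A threshold of 0. is a legitimate value but is falsy in Python, so the old `if early_stop_loss_threshold:` silently disabled early stopping for that setting; the patched code tests `early_stop_loss_threshold is not None` instead. In addition, the loss used for the test must be evaluated at the current projected perturbation, which the SPSA optimizer step never computes on its own. Below is a minimal, framework-free sketch of that logic for a generic iterative attack; the helper names `run_attack_loop`, `toy_loss`, and `toy_step` are hypothetical and not part of the CleverHans API.

def run_attack_loop(loss_fn, step_fn, x0, num_steps,
                    early_stop_loss_threshold=None, is_debug=False):
  """Run num_steps of step_fn, optionally stopping early.

  A threshold of 0. is a valid value, so "early stopping enabled" is
  detected with an explicit None check rather than by truthiness.
  """
  early_stop = early_stop_loss_threshold is not None
  compute_loss = is_debug or early_stop

  x = x0
  for i in range(num_steps):
    x = step_fn(x)
    if compute_loss:
      # In the real attack the stopping test needs the loss at the *current*
      # (projected) perturbation; SPSA's optimizer only evaluates the loss at
      # randomly perturbed points, so this evaluation is not redundant with
      # the optimizer step.
      loss = loss_fn(x)
      if is_debug:
        print("step %d loss %f" % (i, loss))
      if early_stop and loss < early_stop_loss_threshold:
        break
  return x


def toy_loss(x):
  """Toy stand-in for the adversarial loss: f(x) = x ** 2."""
  return x * x


def toy_step(x):
  """One gradient-descent step on toy_loss with learning rate 0.1."""
  return x - 0.1 * 2. * x


if __name__ == "__main__":
  x_final = run_attack_loop(toy_loss, toy_step, x0=3., num_steps=100,
                            early_stop_loss_threshold=1e-3)
  print("final iterate: %f" % x_final)

Inside tf.while_loop the body cannot simply break, so the patched graph code achieves the same effect with tf.cond, jumping the counter `i` to `num_steps` so that the loop condition fails on the next evaluation.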