diff --git a/cleverhans/utils.py b/cleverhans/utils.py
index 934472a05..9185615d1 100644
--- a/cleverhans/utils.py
+++ b/cleverhans/utils.py
@@ -73,4 +73,4 @@ def batch_indices(batch_nb, data_length, batch_size):
         start -= shift
         end -= shift
 
-    return start, end
\ No newline at end of file
+    return start, end
diff --git a/cleverhans/utils_tf.py b/cleverhans/utils_tf.py
index c977180bd..8f9a49ad2 100644
--- a/cleverhans/utils_tf.py
+++ b/cleverhans/utils_tf.py
@@ -40,8 +40,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
     :param X_train: numpy array with training inputs
     :param Y_train: numpy array with training outputs
     :param save: Boolean controling the save operation
-    :param predictions_adv: if set with the adversarial example tensor, 
-                            will run adversarial training 
+    :param predictions_adv: if set with the adversarial example tensor,
+                            will run adversarial training
     :return: True if model trained
     """
     print "Starting model training using TensorFlow."
@@ -63,7 +63,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
             print("Epoch " + str(epoch))
 
             # Compute number of batches
-            nb_batches = int(math.ceil(len(X_train) / FLAGS.batch_size))
+            nb_batches = int(math.ceil(float(len(X_train)) / FLAGS.batch_size))
+            assert nb_batches * FLAGS.batch_size >= len(X_train)
 
             prev = time.time()
             for batch in range(nb_batches):
@@ -80,6 +81,7 @@
                 train_step.run(feed_dict={x: X_train[start:end],
                                           y: Y_train[start:end],
                                           keras.backend.learning_phase(): 1})
+            assert end >= len(X_train)  # Check that all examples were used
 
 
         if save:
@@ -112,21 +114,29 @@ def tf_model_eval(sess, x, y, model, X_test, Y_test):
 
     with sess.as_default():
         # Compute number of batches
-        nb_batches = int(math.ceil(len(X_test) / FLAGS.batch_size))
+        nb_batches = int(math.ceil(float(len(X_test)) / FLAGS.batch_size))
+        assert nb_batches * FLAGS.batch_size >= len(X_test)
 
         for batch in range(nb_batches):
             if batch % 100 == 0 and batch > 0:
                 print("Batch " + str(batch))
 
-            # Compute batch start and end indices
-            start, end = batch_indices(batch, len(X_test), FLAGS.batch_size)
+            # Must not use the `batch_indices` function here, because it
+            # repeats some examples.
+            # It's acceptable to repeat during training, but not eval.
+            start = batch * FLAGS.batch_size
+            end = min(len(X_test), start + FLAGS.batch_size)
+            cur_batch_size = end - start
 
-            accuracy += acc_value.eval(feed_dict={x: X_test[start:end],
+            # The last batch may be smaller than all others, so we need to
+            # account for variable batch size here
+            accuracy += cur_batch_size * acc_value.eval(feed_dict={x: X_test[start:end],
                                                   y: Y_test[start:end],
                                                   keras.backend.learning_phase(): 0})
+        assert end >= len(X_test)
 
-        # Divide by number of batches to get final value
-        accuracy /= nb_batches
+        # Divide by number of examples to get final value
+        accuracy /= len(X_test)
 
     return accuracy
 
diff --git a/tests/test_mnist_accuracy.py b/tests/test_mnist_accuracy.py
index b8024a4d3..52d4f1f6c 100644
--- a/tests/test_mnist_accuracy.py
+++ b/tests/test_mnist_accuracy.py
@@ -51,11 +51,11 @@ def main(argv=None):
 
     # Train an MNIST model
     tf_model_train(sess, x, y, predictions, X_train, Y_train)
-    
+
     # Evaluate the accuracy of the MNIST model on legitimate test examples
     accuracy = tf_model_eval(sess, x, y, predictions, X_test, Y_test)
-    assert float(accuracy) >= 0.97
-    
-    
+    assert float(accuracy) >= 0.97, accuracy
+
+
 if __name__ == '__main__':
     app.run()
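
Note on the two fixes in this patch. The code targets Python 2 (see the `print` statement in `tf_model_train`), where `/` between two ints floor-divides, so `math.ceil(len(X_test) / FLAGS.batch_size)` was applied to an already-floored value and could undercount batches, silently dropping the tail of the dataset; casting to float first restores the intended ceiling. A minimal sketch of the pitfall, assuming 10,000 test examples and a batch size of 128 (illustrative values, not taken from the repo):

    # Python 2 semantics: int / int floors, so ceil() never rounds up
    import math

    n_examples = 10000  # assumed dataset size (MNIST test set)
    batch_size = 128    # assumed FLAGS.batch_size

    buggy = int(math.ceil(n_examples / batch_size))         # 10000 / 128 == 78 in Python 2
    fixed = int(math.ceil(float(n_examples) / batch_size))  # ceil(78.125) == 79

    print buggy, fixed  # 78 79: the buggy count skips the final 16 examples

The second fix follows from the same arithmetic: with 79 batches of 128, the last batch holds only 10000 - 78 * 128 = 16 examples, so averaging per-batch accuracies uniformly over-weights those 16. Weighting each batch's accuracy by its size and dividing by len(X_test) recovers the exact test-set accuracy.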