Skip to content

Commit

Permalink
Increase test coverage (#680)
Browse files Browse the repository at this point in the history
* Fix incorrect short name of adam pgd
* Increase test coverage of FMN
  • Loading branch information
zimmerrol authored Apr 2, 2022
1 parent 3536742 commit 1c55ee4
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 7 deletions.
8 changes: 4 additions & 4 deletions foolbox/attacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
L2ProjectedGradientDescentAttack,
LinfProjectedGradientDescentAttack,
L1AdamProjectedGradientDescentAttack,
L2PAdamProjectedGradientDescentAttack,
L2AdamProjectedGradientDescentAttack,
LinfAdamProjectedGradientDescentAttack,
)
from .basic_iterative_method import ( # noqa: F401
Expand Down Expand Up @@ -94,7 +94,7 @@
LinfPGD = LinfProjectedGradientDescentAttack
PGD = LinfPGD

L1AdamPGD = L1ProjectedGradientDescentAttack
L2AdamPGD = L2ProjectedGradientDescentAttack
LinfAdamPGD = LinfProjectedGradientDescentAttack
L1AdamPGD = L1AdamProjectedGradientDescentAttack
L2AdamPGD = L2AdamProjectedGradientDescentAttack
LinfAdamPGD = LinfAdamProjectedGradientDescentAttack
AdamPGD = LinfAdamPGD
1 change: 0 additions & 1 deletion foolbox/attacks/dataset_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def run(
for i in range(batch_size):
indices = list(range(batch_size))
indices.remove(i)
indices = list(indices)
np.random.shuffle(indices)
index_pools.append(indices)

Expand Down
2 changes: 1 addition & 1 deletion foolbox/attacks/projected_gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_optimizer(self, x: ep.Tensor, stepsize: float) -> Optimizer:
)


class L2PAdamProjectedGradientDescentAttack(L2ProjectedGradientDescentAttack):
class L2AdamProjectedGradientDescentAttack(L2ProjectedGradientDescentAttack):
"""L2 Projected Gradient Descent with Adam optimizer
Args:
Expand Down
47 changes: 46 additions & 1 deletion tests/test_fast_minimum_norm_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from foolbox.devutils import flatten
from foolbox.attacks.fast_minimum_norm import FMNAttackLp
import pytest

import numpy as np
from conftest import ModeAndDataAndDescription


Expand All @@ -19,6 +19,7 @@ def get_attack_id(x: Tuple[FMNAttackLp, Union[int, float]]) -> str:
(fa.L1FMNAttack(steps=20), 1),
(fa.L2FMNAttack(steps=20), 2),
(fa.LInfFMNAttack(steps=20), ep.inf),
(fa.LInfFMNAttack(steps=20, min_stepsize=1.0 / 100), ep.inf),
]


Expand Down Expand Up @@ -51,3 +52,47 @@ def test_fast_minimum_norm_untargeted_attack(
assert fbn.accuracy(fmodel, advs, y) < fbn.accuracy(fmodel, x, y)
assert fbn.accuracy(fmodel, advs, y) <= fbn.accuracy(fmodel, init_advs, y)
assert is_smaller.any()


@pytest.mark.parametrize("attack_and_p", attacks, ids=get_attack_id)
def test_fast_minimum_norm_targeted_attack(
fmodel_and_data_ext_for_attacks: ModeAndDataAndDescription,
attack_and_p: Tuple[FMNAttackLp, Union[int, float]],
) -> None:

(fmodel, x, y), real, low_dimensional_input = fmodel_and_data_ext_for_attacks

if isinstance(x, ep.NumPyTensor):
pytest.skip()

x = (x - fmodel.bounds.lower) / (fmodel.bounds.upper - fmodel.bounds.lower)
fmodel = fmodel.transform_bounds((0, 1))

unique_preds = np.unique(fmodel(x).argmax(-1).numpy())
target_classes = ep.from_numpy(
y,
np.array(
[
unique_preds[(np.argmax(y_it == unique_preds) + 1) % len(unique_preds)]
for y_it in y.numpy()
]
),
)
criterion = fbn.TargetedMisclassification(target_classes)
adv_before_attack = criterion(x, fmodel(x))
assert not adv_before_attack.all()

init_attack = fa.DatasetAttack()
init_attack.feed(fmodel, x)
init_advs = init_attack.run(fmodel, x, criterion)

attack, p = attack_and_p
advs = attack.run(fmodel, x, criterion, starting_points=init_advs)

init_norms = ep.norms.lp(flatten(init_advs - x), p=p, axis=-1)
norms = ep.norms.lp(flatten(advs - x), p=p, axis=-1)

is_smaller = norms < init_norms

assert fbn.accuracy(fmodel, advs, target_classes) == 1.0
assert is_smaller.any()

0 comments on commit 1c55ee4

Please sign in to comment.