Skip to content

Commit

Permalink
add feature space version
Browse files Browse the repository at this point in the history
  • Loading branch information
fra31 committed Feb 24, 2022
1 parent 60d3a3b commit 1bd91bd
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ CUDA_VISIBLE_DEVICES=0 python eval.py --norm=L0 \
```
and for targeted attacks please use `--targeted --n_queries=100000 --alpha_init=0.1`. The target class is randomly chosen for each point.

To use an attack in the *feature* space please add `--use_feature_space` (in this case `k` indicates the number of features to modify).

As additional options the flag `--constant_schedule` uses a constant schedule for `alpha` instead of the piecewise constant decreasing one, while with `--seed=N` it is possible to set a custom random seed.

### Image-specific patches and frames
Expand Down
24 changes: 21 additions & 3 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import time
from datetime import datetime

from utils import SingleChannelModel

model_class_dict = {'pt_vgg': torch_models.vgg16_bn,
'pt_resnet': torch_models.resnet50,
}
Expand Down Expand Up @@ -65,6 +67,7 @@ def random_target_classes(y_pred, n_classes):
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--constant_schedule', action='store_true')
parser.add_argument('--save_dir', type=str, default='./results')
parser.add_argument('--use_feature_space', action='store_true')

# Sparse-RS parameter
parser.add_argument('--alpha_init', type=float, default=.3)
Expand All @@ -74,7 +77,7 @@ def random_target_classes(y_pred, n_classes):
args = parser.parse_args()

if args.data_path is None:
args.data_path = "/home/scratch/datasets/imagenet/val"
args.data_path = "/scratch/datasets/imagenet/val"

args.eps = args.k + 0
args.bs = args.n_ex + 0
Expand Down Expand Up @@ -116,12 +119,23 @@ def random_target_classes(y_pred, n_classes):
if args.targeted or 'universal' in args.norm:
args.loss = 'ce'
data_loader = testiter if 'universal' in args.norm else None
if args.use_feature_space:
# reshape images to single color channel to perturb them individually
assert args.norm == 'L0'
bs, c, h, w = x_test.shape
x_test = x_test.view(bs, 1, h, w * c)
model = SingleChannelModel(model)
str_space = 'feature space'
else:
str_space = 'pixel space'

param_run = '{}_{}_{}_1_{}_nqueries_{:.0f}_pinit_{:.2f}_loss_{}_eps_{:.0f}_targeted_{}_targetclass_{}_seed_{:.0f}'.format(
args.attack, args.norm, args.model, args.n_ex, args.n_queries, args.p_init,
args.loss, args.eps, args.targeted, args.target_class, args.seed)
if args.constant_schedule:
param_run += '_constantpinit'
if args.use_feature_space:
param_run += '_featurespace'

from rs_attacks import RSAttack
adversary = RSAttack(model, norm=args.norm, eps=int(args.eps), verbose=True, n_queries=args.n_queries,
Expand Down Expand Up @@ -198,8 +212,8 @@ def random_target_classes(y_pred, n_classes):
adversary.logger.log('robust accuracy {:.2%}'.format(acc / args.n_ex))

res = (adv_complete - x_test != 0.).max(dim=1)[0].sum(dim=(1, 2))
adversary.logger.log('max L0 perturbation {:.0f} - nan in img {} - max img {:.5f} - min img {:.5f}'.format(
res.max(), (adv_complete != adv_complete).sum(), adv_complete.max(), adv_complete.min()))
adversary.logger.log('max L0 perturbation ({}) {:.0f} - nan in img {} - max img {:.5f} - min img {:.5f}'.format(
str_space, res.max(), (adv_complete != adv_complete).sum(), adv_complete.max(), adv_complete.min()))

ind_corrcl = pred == 1.
ind_succ = (pred_adv == 0.) * (pred == 1.)
Expand All @@ -215,6 +229,10 @@ def random_target_classes(y_pred, n_classes):

# save results depending on the threat model
if args.norm in ['L0', 'patches', 'frames']:
if args.use_feature_space:
# reshape perturbed images to original rgb format
bs, _, h, w = adv_complete.shape
adv_complete = adv_complete.view(bs, 3, h, w // 3)
torch.save({'adv': adv_complete, 'qr': qr_complete},
'{}/{}.pth'.format(savedir, param_run))

Expand Down
5 changes: 3 additions & 2 deletions rs_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
self.resample_loc = n_queries // 10 if resample_loc is None else resample_loc
self.data_loader = data_loader
self.update_loc_period = update_loc_period if not update_loc_period is None else 4 if not targeted else 10


def margin_and_loss(self, x, y):
"""
Expand Down Expand Up @@ -315,10 +316,10 @@ def attack_single_run(self, x, y):
else:
# if update is 1x1 make sure the sampled color is different from the current one
old_clr = x_new[img, :, np_set // w, np_set % w].clone()
assert old_clr.shape == (3, 1), print(old_clr)
assert old_clr.shape == (c, 1), print(old_clr)
new_clr = old_clr.clone()
while (new_clr == old_clr).all().item():
new_clr = self.random_choice([3, 1]).clone().clamp(0., 1.)
new_clr = self.random_choice([c, 1]).clone().clamp(0., 1.)
x_new[img, :, np_set // w, np_set % w] = new_clr.clone()

# compute loss of the new candidates
Expand Down
17 changes: 17 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import torch
import torch.nn as nn


class Logger():
def __init__(self, log_path):
self.log_path = log_path
Expand All @@ -9,3 +13,16 @@ def log(self, str_to_log):
f.write(str_to_log + '\n')
f.flush()


class SingleChannelModel():
""" reshapes images to rgb before classification
i.e. [N, 1, H, W x 3] -> [N, 3, H, W]
"""
def __init__(self, model):
if isinstance(model, nn.Module):
assert not model.training
self.model = model

def __call__(self, x):
return self.model(x.view(x.shape[0], 3, x.shape[2], x.shape[3] // 3))

0 comments on commit 1bd91bd

Please sign in to comment.