Skip to content

Commit

Permalink
Add results for attack convergence for all the models
Browse files Browse the repository at this point in the history
  • Loading branch information
dedeswim committed Aug 17, 2022
1 parent 919dfdc commit 785e5b1
Show file tree
Hide file tree
Showing 4 changed files with 647,045 additions and 331,165 deletions.
8 changes: 4 additions & 4 deletions attack_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def main():
csv_writer = utils.GCSSummaryCsv(output_path.parent, filename=output_path.name)

model = timm.create_model(args.model, pretrained=not args.checkpoint, checkpoint_path=args.checkpoint)
model = dev_env.to_device(model)
model.eval()

criterion = dev_env.to_device(nn.CrossEntropyLoss(reduction='none'))

Expand All @@ -57,8 +55,7 @@ def main():

dataset = create_dataset(root=args.data,
name=args.dataset,
split=args.split,
download=args.dataset_download)
split=args.split)

data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
data_config['normalize'] = not (args.no_normalize or args.normalize_model)
Expand All @@ -67,6 +64,9 @@ def main():
mean = args.mean or data_config["mean"]
std = args.std or data_config["std"]
model = utils.normalize_model(model, mean=mean, std=std)

model = dev_env.to_device(model)
model.eval()

test_time_pool = False
if args.test_pool:
Expand Down
Loading

0 comments on commit 785e5b1

Please sign in to comment.