Skip to content

Commit

Permalink
Commit old stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
dedeswim committed May 4, 2022
1 parent 49eab48 commit ef96532
Show file tree
Hide file tree
Showing 5 changed files with 767 additions and 119 deletions.
115 changes: 0 additions & 115 deletions configs/resnet18-adv-training.yaml

This file was deleted.

13 changes: 9 additions & 4 deletions datasets/imagenet_perturbations/imagenet_perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_CITATION = """
"""

MODELS_TO_NORMALIZE = {"resnet50", "deit_small_patch16_224", "xcit_small_12_p16_224_nonrobust"}
MODELS_TO_NORMALIZE = {"resnet50_nonrobust", "resnet50", "resnet50_fgsm", "xcit_small_12_p16_224_nonrobust"}


def load_model_from_gcs(checkpoint_path: str, model_name: str, **kwargs):
Expand Down Expand Up @@ -106,10 +106,15 @@ class ImagenetPerturbations(tfds.core.GeneratorBasedBuilder):
checkpoint_path="gs://robust-vits/xcit/best.pth.tar",
steps=1),
ImagenetPerturbationsConfig(name="resnet50",
model="resnet50",
model="adv_resnet50",
checkpoint_path="gs://robust-vits/external-checkpoints/advres50_gelu.pth",
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5)),
ImagenetPerturbationsConfig(name="resnet50_nonrobust",
model="resnet50",
pretrained=True,
mean=constants.IMAGENET_DEFAULT_MEAN,
std=constants.IMAGENET_DEFAULT_STD),
ImagenetPerturbationsConfig(name="xcit_small_12_p16_224",
model="xcit_small_12_p16_224",
checkpoint_path="gs://robust-vits/xcit/best.pth.tar"),
Expand Down Expand Up @@ -144,7 +149,7 @@ def _split_generators(self, _: tfds.download.DownloadManager):
"""Returns SplitGenerators."""
dev_env = initialize_device()

if self.builder_config.model == "resnet50":
if self.builder_config.name in {"adv_resnet50", "adv_resnet50_fgsm"}:
model = load_state_dict_from_gcs(resnet50(norm_layer=EightBN),
self.builder_config.checkpoint_path)
elif self.builder_config.checkpoint_path:
Expand All @@ -155,7 +160,7 @@ def _split_generators(self, _: tfds.download.DownloadManager):
raise ValueError(f"For {self.builder_config.name}, either the checkpoint"
"should be specified, or pretrained should be `True`")

if self.builder_config.model in MODELS_TO_NORMALIZE:
if self.builder_config.name in MODELS_TO_NORMALIZE:
model = normalize_model(model, self.builder_config.mean, self.builder_config.std)

model.to(dev_env.device)
Expand Down
372 changes: 372 additions & 0 deletions notebooks/advex-example.ipynb

Large diffs are not rendered by default.

Binary file added notebooks/castle.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
386 changes: 386 additions & 0 deletions notebooks/data-augmentations.ipynb

Large diffs are not rendered by default.

0 comments on commit ef96532

Please sign in to comment.