diff --git a/src/arg_parser.py b/src/arg_parser.py index 17a5fd7..23e0f99 100644 --- a/src/arg_parser.py +++ b/src/arg_parser.py @@ -546,7 +546,7 @@ help='Uses torch_xla\'s MP loader') -def parse_args(): +def parse_args(additional_args=None): # Do we have a config file to parse? args_config, remaining = config_parser.parse_known_args() if args_config.config: @@ -555,9 +555,10 @@ def parse_args(): parser.set_defaults(**cfg) # The main arg parser parses the rest of the args, the usual - # defaults will have been overridden if config file specified. + # defaults will have been overridden if config file specified. + if additional_args is not None: + remaining += additional_args args = parser.parse_args(remaining) - # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) return args, args_text diff --git a/src/models/models.py b/src/models/models.py index 70b8869..40bc8e9 100644 --- a/src/models/models.py +++ b/src/models/models.py @@ -3,6 +3,8 @@ from timm.models.registry import register_model from torch import nn +from src import utils + default_cfgs = { 'cait_s12_224': cait._cfg(input_size=(3, 224, 224)), 'xcit_medium_12_p16_224': xcit._cfg(), @@ -10,6 +12,7 @@ 'xcit_large_12_h8_p16_224': xcit._cfg(), 'xcit_small_12_p4_32': xcit._cfg(input_size=(3, 32, 32)), 'xcit_medium_12_p4_32': xcit._cfg(input_size=(3, 32, 32)), + 'xcit_large_12_p4_32': xcit._cfg(input_size=(3, 32, 32)), 'resnet18_gelu': resnet._cfg(), 'resnet50_gelu': resnet._cfg(interpolation='bicubic', crop_pct=0.95), 'resnext152_32x8d': resnet._cfg(input_size=(3, 380, 380)) @@ -77,9 +80,7 @@ def xcit_small_12_p8_32(pretrained=False, **kwargs): **kwargs) model = xcit._create_xcit('xcit_small_12_p4_32', pretrained=pretrained, **model_kwargs) assert isinstance(model, xcit.XCiT) - # Adapt ConvPatchEmbed module - model.patch_embed.patch_size = 8 - model.patch_embed.proj[0][0].stride = (1, 1) + model = utils.adapt_model_patches(model, 8) return model @@ -95,10 +96,7 @@ def xcit_small_12_p4_32(pretrained=False, **kwargs): **kwargs) model = xcit._create_xcit('xcit_small_12_p4_32', pretrained=pretrained, **model_kwargs) assert isinstance(model, xcit.XCiT) - # Adapt ConvPatchEmbed module - model.patch_embed.patch_size = 4 - for conv_index in [0, 2]: - model.patch_embed.proj[conv_index][0].stride = (1, 1) + model = utils.adapt_model_patches(model, 4) return model @@ -114,10 +112,22 @@ def xcit_medium_12_p4_32(pretrained=False, **kwargs): model = xcit._create_xcit('xcit_medium_12_p4_32', pretrained=pretrained, **model_kwargs) # TODO: make this a function assert isinstance(model, xcit.XCiT) - # Adapt ConvPatchEmbed module - model.patch_embed.patch_size = 4 - for conv_index in [0, 2]: - model.patch_embed.proj[conv_index][0].stride = (1, 1) + model = utils.adapt_model_patches(model, 4) + return model + + +@register_model +def xcit_large_12_p4_32(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=16, + embed_dim=768, + depth=12, + num_heads=16, + eta=1.0, + tokens_norm=True, + **kwargs) + model = xcit._create_xcit('xcit_large_12_p16_224', pretrained=pretrained, **model_kwargs) + assert isinstance(model, xcit.XCiT) + model = utils.adapt_model_patches(model, 4) return model @@ -132,10 +142,7 @@ def xcit_small_12_p2_32(pretrained=False, **kwargs): **kwargs) model = xcit._create_xcit('xcit_small_12_p2_32', pretrained=pretrained, **model_kwargs) assert isinstance(model, xcit.XCiT) - # Adapt ConvPatchEmbed module - model.patch_embed.patch_size = 2 - for conv_index in [0, 2, 4]: - model.patch_embed.proj[conv_index][0].stride = (1, 1) + model = utils.adapt_model_patches(model, 2) return model diff --git a/src/utils.py b/src/utils.py index 6d3982c..f35ba2f 100644 --- a/src/utils.py +++ b/src/utils.py @@ -14,6 +14,7 @@ from timm.data import PreprocessCfg from timm.data.fetcher import Fetcher from timm.data.prefetcher_cuda import PrefetcherCuda +from timm.models import xcit from torch import nn import src.attacks as attacks @@ -79,7 +80,6 @@ def check_bucket_zone(data_dir, prefix): class GCSSummaryCsv(bits.monitor.SummaryCsv): """SummaryCSV version to work with GCS""" - def __init__(self, output_dir, filename='summary.csv'): super().__init__(output_dir, filename) @@ -93,7 +93,6 @@ def update(self, row_dict): class ComputeLossFn(nn.Module): - def __init__(self, loss_fn: nn.Module): super().__init__() self.loss_fn = loss_fn @@ -138,7 +137,6 @@ class MyPreprocessCfg(PreprocessCfg): class ImageNormalizer(nn.Module): """From https://github.com/RobustBench/robustbench/blob/master/robustbench/model_zoo/architectures/utils_architectures.py#L8""" - def __init__(self, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None: super(ImageNormalizer, self).__init__() @@ -158,7 +156,6 @@ def normalize_model(model: nn.Module, mean: Tuple[float, float, float], std: Tup class CombinedLoaders: - def __init__(self, loader_1: Union[Fetcher, PrefetcherCuda], loader_2: Union[Fetcher, PrefetcherCuda]): self.loader_1 = loader_1 self.loader_2 = loader_2 @@ -233,3 +230,14 @@ def interpolate_position_embeddings(model: nn.Module, checkpoint_model: Dict[str pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) return new_pos_embed + + +def adapt_model_patches(model: xcit.XCiT, new_patch_size: int): + to_divide = model.patch_embed.patch_size / new_patch_size + assert int(to_divide) == to_divide, "The new patch size should divide the original patch size" + to_divide = int(to_divide) + assert to_divide % 2 == 0, "The ratio between the original patch size and the new patch size should be divisible by 2" + for conv_index in range(0, to_divide, 2): + model.patch_embed.proj[conv_index][0].stride = (1, 1) + model.patch_embed.patch_size = new_patch_size + return model diff --git a/tests/test_utils.py b/tests/test_utils.py index 29c690d..2391b7f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,13 @@ +import pytest from src import setup_task +from src.utils import adapt_model_patches from src.arg_parser import parse_args +from timm.models import xcit + def test_resolve_attack_cfg(): - args, _ = parse_args() + args, _ = parse_args(additional_args=["foo"]) args.attack_eps = 4 args.attack_steps = 10 attack_cfg = setup_task.resolve_attack_cfg(args) @@ -16,7 +20,7 @@ def test_resolve_attack_cfg(): def test_resolve_attack_cfg_attack_lr(): - args, _ = parse_args() + args, _ = parse_args(additional_args=["foo"]) args.attack_eps = 4 args.attack_steps = 10 args.attack_lr = 15 @@ -26,7 +30,7 @@ def test_resolve_attack_cfg_attack_lr(): def test_resolve_attack_cfg_eval_eps(): - args, _ = parse_args() + args, _ = parse_args(additional_args=["foo"]) args.eval_attack_eps = 8 args.attack_steps = 10 attack_cfg = setup_task.resolve_attack_cfg(args, eval=True) @@ -39,8 +43,49 @@ def test_resolve_attack_cfg_eval_eps(): def test_resolve_attack_cfg_eval_name(): - args, _ = parse_args() + args, _ = parse_args(additional_args=["foo"]) args.attack = "targeted_pgd" attack_cfg = setup_task.resolve_attack_cfg(args, eval=True) assert attack_cfg.name == "pgd" + + +def test_adapt_model_patches_exception(): + model = xcit._create_xcit('xcit_small_12_p16_224') + patch_size = 6 + assert isinstance(model, xcit.XCiT) + with pytest.raises(AssertionError): + adapt_model_patches(model, patch_size) + + +def test_adapt_model_patches_2(): + model = xcit._create_xcit('xcit_small_12_p16_224') + patch_size = 2 + assert isinstance(model, xcit.XCiT) + modified_model = adapt_model_patches(model, patch_size) + assert modified_model.patch_embed.patch_size == patch_size + assert modified_model.patch_embed.proj[0][0].stride == (1, 1) + assert modified_model.patch_embed.proj[2][0].stride == (1, 1) + assert modified_model.patch_embed.proj[4][0].stride == (1, 1) + + +def test_adapt_model_patches_4(): + model = xcit._create_xcit('xcit_small_12_p16_224') + patch_size = 4 + assert isinstance(model, xcit.XCiT) + modified_model = adapt_model_patches(model, patch_size) + assert modified_model.patch_embed.patch_size == patch_size + assert modified_model.patch_embed.proj[0][0].stride == (1, 1) + assert modified_model.patch_embed.proj[2][0].stride == (1, 1) + assert modified_model.patch_embed.proj[4][0].stride == (2, 2) + + +def test_adapt_model_patches_8(): + model = xcit._create_xcit('xcit_small_12_p16_224') + patch_size = 8 + assert isinstance(model, xcit.XCiT) + modified_model = adapt_model_patches(model, patch_size) + assert modified_model.patch_embed.patch_size == patch_size + assert modified_model.patch_embed.proj[0][0].stride == (1, 1) + assert modified_model.patch_embed.proj[2][0].stride == (2, 2) + assert modified_model.patch_embed.proj[4][0].stride == (2, 2)