Skip to content

Commit

Permalink
Add xcit_large_12_p4, add fn to adapt patch size
Browse files Browse the repository at this point in the history
  • Loading branch information
dedeswim committed Apr 17, 2022
1 parent e2b4613 commit f90de6a
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 26 deletions.
7 changes: 4 additions & 3 deletions src/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@
help='Uses torch_xla\'s MP loader')


def parse_args():
def parse_args(additional_args=None):
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
Expand All @@ -555,9 +555,10 @@ def parse_args():
parser.set_defaults(**cfg)

# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
# defaults will have been overridden if config file specified.
if additional_args is not None:
remaining += additional_args
args = parser.parse_args(remaining)

# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
37 changes: 22 additions & 15 deletions src/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
from timm.models.registry import register_model
from torch import nn

from src import utils

default_cfgs = {
'cait_s12_224': cait._cfg(input_size=(3, 224, 224)),
'xcit_medium_12_p16_224': xcit._cfg(),
'xcit_large_12_p16_224': xcit._cfg(),
'xcit_large_12_h8_p16_224': xcit._cfg(),
'xcit_small_12_p4_32': xcit._cfg(input_size=(3, 32, 32)),
'xcit_medium_12_p4_32': xcit._cfg(input_size=(3, 32, 32)),
'xcit_large_12_p4_32': xcit._cfg(input_size=(3, 32, 32)),
'resnet18_gelu': resnet._cfg(),
'resnet50_gelu': resnet._cfg(interpolation='bicubic', crop_pct=0.95),
'resnext152_32x8d': resnet._cfg(input_size=(3, 380, 380))
Expand Down Expand Up @@ -77,9 +80,7 @@ def xcit_small_12_p8_32(pretrained=False, **kwargs):
**kwargs)
model = xcit._create_xcit('xcit_small_12_p4_32', pretrained=pretrained, **model_kwargs)
assert isinstance(model, xcit.XCiT)
# Adapt ConvPatchEmbed module
model.patch_embed.patch_size = 8
model.patch_embed.proj[0][0].stride = (1, 1)
model = utils.adapt_model_patches(model, 8)
return model


Expand All @@ -95,10 +96,7 @@ def xcit_small_12_p4_32(pretrained=False, **kwargs):
**kwargs)
model = xcit._create_xcit('xcit_small_12_p4_32', pretrained=pretrained, **model_kwargs)
assert isinstance(model, xcit.XCiT)
# Adapt ConvPatchEmbed module
model.patch_embed.patch_size = 4
for conv_index in [0, 2]:
model.patch_embed.proj[conv_index][0].stride = (1, 1)
model = utils.adapt_model_patches(model, 4)
return model


Expand All @@ -114,10 +112,22 @@ def xcit_medium_12_p4_32(pretrained=False, **kwargs):
model = xcit._create_xcit('xcit_medium_12_p4_32', pretrained=pretrained, **model_kwargs)
# TODO: make this a function
assert isinstance(model, xcit.XCiT)
# Adapt ConvPatchEmbed module
model.patch_embed.patch_size = 4
for conv_index in [0, 2]:
model.patch_embed.proj[conv_index][0].stride = (1, 1)
model = utils.adapt_model_patches(model, 4)
return model


@register_model
def xcit_large_12_p4_32(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16,
embed_dim=768,
depth=12,
num_heads=16,
eta=1.0,
tokens_norm=True,
**kwargs)
model = xcit._create_xcit('xcit_large_12_p16_224', pretrained=pretrained, **model_kwargs)
assert isinstance(model, xcit.XCiT)
model = utils.adapt_model_patches(model, 4)
return model


Expand All @@ -132,10 +142,7 @@ def xcit_small_12_p2_32(pretrained=False, **kwargs):
**kwargs)
model = xcit._create_xcit('xcit_small_12_p2_32', pretrained=pretrained, **model_kwargs)
assert isinstance(model, xcit.XCiT)
# Adapt ConvPatchEmbed module
model.patch_embed.patch_size = 2
for conv_index in [0, 2, 4]:
model.patch_embed.proj[conv_index][0].stride = (1, 1)
model = utils.adapt_model_patches(model, 2)
return model


Expand Down
16 changes: 12 additions & 4 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from timm.data import PreprocessCfg
from timm.data.fetcher import Fetcher
from timm.data.prefetcher_cuda import PrefetcherCuda
from timm.models import xcit
from torch import nn

import src.attacks as attacks
Expand Down Expand Up @@ -79,7 +80,6 @@ def check_bucket_zone(data_dir, prefix):

class GCSSummaryCsv(bits.monitor.SummaryCsv):
"""SummaryCSV version to work with GCS"""

def __init__(self, output_dir, filename='summary.csv'):
super().__init__(output_dir, filename)

Expand All @@ -93,7 +93,6 @@ def update(self, row_dict):


class ComputeLossFn(nn.Module):

def __init__(self, loss_fn: nn.Module):
super().__init__()
self.loss_fn = loss_fn
Expand Down Expand Up @@ -138,7 +137,6 @@ class MyPreprocessCfg(PreprocessCfg):
class ImageNormalizer(nn.Module):
"""From
https://github.com/RobustBench/robustbench/blob/master/robustbench/model_zoo/architectures/utils_architectures.py#L8"""

def __init__(self, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None:
super(ImageNormalizer, self).__init__()

Expand All @@ -158,7 +156,6 @@ def normalize_model(model: nn.Module, mean: Tuple[float, float, float], std: Tup


class CombinedLoaders:

def __init__(self, loader_1: Union[Fetcher, PrefetcherCuda], loader_2: Union[Fetcher, PrefetcherCuda]):
self.loader_1 = loader_1
self.loader_2 = loader_2
Expand Down Expand Up @@ -233,3 +230,14 @@ def interpolate_position_embeddings(model: nn.Module, checkpoint_model: Dict[str
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
return new_pos_embed


def adapt_model_patches(model: xcit.XCiT, new_patch_size: int):
to_divide = model.patch_embed.patch_size / new_patch_size
assert int(to_divide) == to_divide, "The new patch size should divide the original patch size"
to_divide = int(to_divide)
assert to_divide % 2 == 0, "The ratio between the original patch size and the new patch size should be divisible by 2"
for conv_index in range(0, to_divide, 2):
model.patch_embed.proj[conv_index][0].stride = (1, 1)
model.patch_embed.patch_size = new_patch_size
return model
53 changes: 49 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pytest
from src import setup_task
from src.utils import adapt_model_patches
from src.arg_parser import parse_args

from timm.models import xcit


def test_resolve_attack_cfg():
args, _ = parse_args()
args, _ = parse_args(additional_args=["foo"])
args.attack_eps = 4
args.attack_steps = 10
attack_cfg = setup_task.resolve_attack_cfg(args)
Expand All @@ -16,7 +20,7 @@ def test_resolve_attack_cfg():


def test_resolve_attack_cfg_attack_lr():
args, _ = parse_args()
args, _ = parse_args(additional_args=["foo"])
args.attack_eps = 4
args.attack_steps = 10
args.attack_lr = 15
Expand All @@ -26,7 +30,7 @@ def test_resolve_attack_cfg_attack_lr():


def test_resolve_attack_cfg_eval_eps():
args, _ = parse_args()
args, _ = parse_args(additional_args=["foo"])
args.eval_attack_eps = 8
args.attack_steps = 10
attack_cfg = setup_task.resolve_attack_cfg(args, eval=True)
Expand All @@ -39,8 +43,49 @@ def test_resolve_attack_cfg_eval_eps():


def test_resolve_attack_cfg_eval_name():
args, _ = parse_args()
args, _ = parse_args(additional_args=["foo"])
args.attack = "targeted_pgd"
attack_cfg = setup_task.resolve_attack_cfg(args, eval=True)

assert attack_cfg.name == "pgd"


def test_adapt_model_patches_exception():
model = xcit._create_xcit('xcit_small_12_p16_224')
patch_size = 6
assert isinstance(model, xcit.XCiT)
with pytest.raises(AssertionError):
adapt_model_patches(model, patch_size)


def test_adapt_model_patches_2():
model = xcit._create_xcit('xcit_small_12_p16_224')
patch_size = 2
assert isinstance(model, xcit.XCiT)
modified_model = adapt_model_patches(model, patch_size)
assert modified_model.patch_embed.patch_size == patch_size
assert modified_model.patch_embed.proj[0][0].stride == (1, 1)
assert modified_model.patch_embed.proj[2][0].stride == (1, 1)
assert modified_model.patch_embed.proj[4][0].stride == (1, 1)


def test_adapt_model_patches_4():
model = xcit._create_xcit('xcit_small_12_p16_224')
patch_size = 4
assert isinstance(model, xcit.XCiT)
modified_model = adapt_model_patches(model, patch_size)
assert modified_model.patch_embed.patch_size == patch_size
assert modified_model.patch_embed.proj[0][0].stride == (1, 1)
assert modified_model.patch_embed.proj[2][0].stride == (1, 1)
assert modified_model.patch_embed.proj[4][0].stride == (2, 2)


def test_adapt_model_patches_8():
model = xcit._create_xcit('xcit_small_12_p16_224')
patch_size = 8
assert isinstance(model, xcit.XCiT)
modified_model = adapt_model_patches(model, patch_size)
assert modified_model.patch_embed.patch_size == patch_size
assert modified_model.patch_embed.proj[0][0].stride == (1, 1)
assert modified_model.patch_embed.proj[2][0].stride == (2, 2)
assert modified_model.patch_embed.proj[4][0].stride == (2, 2)

0 comments on commit f90de6a

Please sign in to comment.