-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
153 lines (140 loc) · 7.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os.path
import sys
import argparse
import datetime
import random
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
from pathlib import Path
from timm.models import create_model
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from datasets import build_continual_dataloader
import utils
import warnings
# Limit the size of blocks the CUDA caching allocator is allowed to split
# (PYTORCH_CUDA_ALLOC_CONF / max_split_size_mb), which can reduce
# fragmentation-related out-of-memory errors. setdefault keeps a value the user
# already exported instead of silently overwriting it. This has to be in the
# environment before CUDA is first initialised for it to take effect.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")
# Silence a noisy warning about passing an int interpolation mode (presumably
# emitted by torchvision transforms in the data pipeline — TODO confirm source).
warnings.filterwarnings('ignore', 'Argument interpolation should be of type InterpolationMode instead of int')
# Registered experiment configs: sub-command name -> help text. Each name is
# also the module under ``configs/`` whose ``get_args_parser`` adds that
# config's options to its sub-parser.
_CONFIG_HELP = {
    'cifar100_hideprompt_5e': 'Split-CIFAR100 HiDe-Prompt configs',
    'imr_hideprompt_5e': 'Split-ImageNet-R HiDe-Prompt configs',
    'five_datasets_hideprompt_5e': 'five datasets HiDe-Prompt configs',
    'cub_hideprompt_5e': 'Split-CUB HiDe-Prompt configs',
    'cifar100_dualprompt': 'Split-CIFAR100 dual-prompt configs',
    'imr_dualprompt': 'Split-ImageNet-R dual-prompt configs',
    'five_datasets_dualprompt': 'five datasets dual-prompt configs',
    'cub_dualprompt': 'Split-CUB dual-prompt configs',
    'cifar100_sprompt_5e': 'Split-CIFAR100 s-prompt configs',
    'imr_sprompt_5e': 'Split-ImageNet-R s-prompt configs',
    'five_datasets_sprompt_5e': 'five datasets s-prompt configs',
    'cub_sprompt_5e': 'Split-CUB s-prompt configs',
    'cifar100_l2p': 'Split-CIFAR100 l2p configs',
    'imr_l2p': 'Split-ImageNet-R l2p configs',
    'five_datasets_l2p': 'five datasets l2p configs',
    'cub_l2p': 'Split-CUB l2p configs',
    'cifar100_hidelora': 'Split-CIFAR100 hidelora configs',
    'imr_hidelora': 'Split-ImageNet-R hidelora configs',
    'cifar100_continual_lora': 'Split-CIFAR100 continual lora configs',
    'imr_continual_lora': 'Split-ImageNet-R continual lora configs',
    'cifar100_hideadapter': 'Split-CIFAR100 hideadapter configs',
    'imr_hideadapter': 'Split-ImageNet-R hideadapter configs',
    'imr_continual_prompt': 'Split-ImageNet-R continual prompt config',
    'imr_continual_adapter': 'Split-ImageNet-R continual adapter config',
}


def get_args():
    """Parse the command line for the experiment config named by its first argument.

    Usage looks like ``python main.py <config_name> [config options...]``. The
    config name selects a module ``configs.<config_name>``, which is imported
    lazily. Its ``get_args_parser`` registers that config's options on a
    sub-parser, and the full command line is then parsed against it.

    Returns:
        argparse.Namespace: the parsed arguments, with ``config`` set to the
        selected config name.

    Raises:
        SystemExit: (via ``parser.error``) when no config name is given.
        NotImplementedError: when the config name is not registered.
    """
    import importlib

    parser = argparse.ArgumentParser('DualPrompt training and evaluation configs')
    # The bare parser defines no arguments, so every command-line token lands
    # in the "unknown" list; the first one is the config name.
    unknown = parser.parse_known_args()[-1]
    if not unknown:
        parser.error('a config name is required; choose from: ' + ', '.join(_CONFIG_HELP))
    config = unknown[0]
    if config not in _CONFIG_HELP:
        raise NotImplementedError(
            f'unknown config {config!r}; choose from: ' + ', '.join(_CONFIG_HELP))

    subparser = parser.add_subparsers(dest='subparser_name')
    config_parser = subparser.add_parser(config, help=_CONFIG_HELP[config])
    importlib.import_module(f'configs.{config}').get_args_parser(config_parser)

    args = parser.parse_args()
    args.config = config
    return args
def main(args):
    """Prepare the run (distributed init, output dir, seeds) and dispatch to a trainer.

    If ``train_inference_task_only`` is set, ``trainers.tii_trainer`` is used
    regardless of config. Otherwise the trainer module is chosen by substring
    match on ``args.config``, in the order of the branches below (first match
    wins).

    Args:
        args: Namespace from ``get_args``. Reads ``output_dir``, ``seed``,
            ``config`` and, when the config defines it,
            ``train_inference_task_only``. The same object is passed to
            ``utils.init_distributed_mode`` and to the selected trainer's
            ``train``.

    Raises:
        NotImplementedError: if no trainer matches ``args.config``.
    """
    # Project helper; presumably sets up multi-process state on args (TODO confirm).
    utils.init_distributed_mode(args)

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # NOTE: benchmark mode lets cuDNN choose algorithms by timing, which can
    # make GPU results differ between runs even with the seeds above.
    cudnn.benchmark = True

    # Not every config defines this flag. Read it once with a default so that
    # configs lacking it never hit AttributeError in the branches below. Since
    # the flag's branch is checked first, later branches need not re-test it.
    tii_only = getattr(args, 'train_inference_task_only', False)
    config = args.config

    if tii_only:
        import trainers.tii_trainer as trainer
    elif 'hideprompt' in config:
        import trainers.hideprompt_trainer as trainer
    elif 'l2p' in config or 'dualprompt' in config or 'sprompt' in config:
        import trainers.dp_trainer as trainer
    elif 'hidelora' in config:
        import trainers.hidelora_trainer as trainer
    elif 'continual_lora' in config:
        import trainers.continual_lora_trainer as trainer
    elif 'hideadapter' in config:
        import trainers.hideadapter_trainer as trainer
    elif 'continual_prompt' in config:
        import trainers.continual_prompt_trainer as trainer
    elif 'continual_adapter' in config:
        import trainers.continual_adapter_trainer as trainer
    else:
        raise NotImplementedError(f'no trainer registered for config {config!r}')

    trainer.train(args)
if __name__ == '__main__':
    # Script entry point: parse the CLI, echo the resolved arguments so they
    # appear in the run log, then start the selected training pipeline.
    run_args = get_args()
    print(run_args)
    main(run_args)