forked from kennymckormick/pyskl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_cuda
168 lines (139 loc) · 5.72 KB
/
train_cuda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa: E722
import argparse
import mmcv
import os
import os.path as osp
import time
import torch
import torch.distributed as dist
from mmcv import Config
from mmcv import digit_version as dv
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash
import sys
sys.path.insert(0, '/root/pyskl_thesis/pyskl')
from pyskl import __version__
from pyskl.apis import init_random_seed, train_model
from pyskl.datasets import build_dataset
from pyskl.models import build_model
from pyskl.utils import collect_env, get_root_logger, mc_off, mc_on, test_port
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv (list[str] | None): Argument strings to parse. With the default
            ``None`` argparse reads ``sys.argv[1:]``, so existing callers
            that invoke ``parse_args()`` behave as before.

    Returns:
        argparse.Namespace: The parsed options (``config``, ``validate``,
        ``test_last``, ``test_best``, ``seed``, ``deterministic``,
        ``launcher``, ``compile`` and ``local_rank``).

    Side effects:
        If the ``LOCAL_RANK`` environment variable is not set, it is set to
        the string form of ``args.local_rank`` (``'-1'`` when the option was
        not given). An existing ``LOCAL_RANK`` value is never overwritten.
    """
    parser = argparse.ArgumentParser(description='Train a recognizer')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--test-last',
        action='store_true',
        help='whether to test the checkpoint after training')
    parser.add_argument(
        '--test-best',
        action='store_true',
        help='whether to test the best checkpoint (if applicable) after training')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--launcher',
        choices=['pytorch', 'slurm'],
        default='pytorch',
        help='job launcher')
    parser.add_argument(
        '--compile',
        action='store_true',
        help='whether to compile the model before training / testing (only available in pytorch 2.0)')
    # Both spellings are accepted and stored in ``args.local_rank``; they are
    # registered as one option instead of two actions sharing a destination.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=-1)

    args = parser.parse_args(argv)

    # Export the rank only when the environment has not provided one already.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Run one training job described by the config file given on the CLI.

    Steps, in order: parse the CLI options, load the config, resolve the work
    directory and an optional checkpoint to resume from, set up the logger and
    the ``meta`` dict, seed the RNGs, build the model and the training
    dataset, optionally make sure memcached is reachable, call
    ``train_model`` and finally turn memcached off again.

    Raises:
        AssertionError: If ``cfg.workflow`` does not contain exactly one
            entry, or ``mc_cfg`` is not a tuple whose host is ``'localhost'``.
        RuntimeError: If memcached is requested but its port is still not
            reachable after the wait loop.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority:
    # config file > default (base filename)
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        config_stem = osp.splitext(osp.basename(args.config))[0]
        cfg.work_dir = osp.join('./work_dirs', config_stem)

    if not hasattr(cfg, 'dist_params'):
        cfg.dist_params = dict(backend='nccl')

    # NOTE: distributed initialisation is disabled in this script (it used to
    # be ``init_dist(args.launcher, **cfg.dist_params)``), so
    # ``get_dist_info`` is queried without a process group having been set up
    # here -- presumably it then reports rank 0 / world size 1; TODO confirm.
    rank, world_size = get_dist_info()
    # NOTE(review): the GPU id is hard-coded to device 1; the disabled
    # alternative was ``cfg.gpu_ids = range(world_size)``. Confirm this is
    # intended before running on a machine with a different GPU layout.
    cfg.gpu_ids = [1]

    # Resume from ``<work_dir>/latest.pth`` when auto-resume is enabled, no
    # explicit ``resume_from`` is configured and that file exists.
    auto_resume = cfg.get('auto_resume', True)
    if auto_resume and cfg.get('resume_from', None) is None:
        resume_pth = osp.join(cfg.work_dir, 'latest.pth')
        if osp.exists(resume_pth):
            cfg.resume_from = resume_pth

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.get('log_level', 'INFO'))

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Config: {cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)

    cfg.seed = seed
    meta['seed'] = seed
    meta['config_name'] = osp.basename(args.config)
    meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\'))

    model = build_model(cfg.model)
    # torch.compile is only used when requested and torch is >= 2.0.0.
    if dv(torch.__version__) >= dv('2.0.0') and args.compile:
        model = torch.compile(model)

    datasets = [build_dataset(cfg.data.train)]

    cfg.workflow = cfg.get('workflow', [('train', 1)])
    assert len(cfg.workflow) == 1
    if cfg.checkpoint_config is not None:
        # save pyskl version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            pyskl_version=__version__ + get_git_hash(digits=7),
            config=cfg.pretty_text)

    test_option = dict(test_last=args.test_last, test_best=args.test_best)

    default_mc_cfg = ('localhost', 22077)
    memcached = cfg.get('memcached', False)

    if rank == 0 and memcached:
        # mc_list is a list of pickle files you want to cache in memory.
        # Basically, each pickle file is a dictionary.
        mc_cfg = cfg.get('mc_cfg', default_mc_cfg)
        assert isinstance(mc_cfg, tuple) and mc_cfg[0] == 'localhost'
        mc_host, mc_port = mc_cfg[0], mc_cfg[1]
        if not test_port(mc_host, mc_port):
            mc_on(port=mc_port, launcher=args.launcher)
        # Wait for the port to come up: at most 3 pauses of 5 seconds each.
        retry = 3
        while not test_port(mc_host, mc_port) and retry > 0:
            time.sleep(5)
            retry -= 1
        # The retry counter never drops below zero, so it cannot signal a
        # failure; probe the port itself to decide whether the launch worked.
        if not test_port(mc_host, mc_port):
            raise RuntimeError('Failed to launch memcached. ')

    # NOTE: the ``dist.barrier()`` calls that used to surround training are
    # disabled in this script, matching the disabled distributed init above.
    train_model(model, datasets, cfg, validate=args.validate, test=test_option, timestamp=timestamp, meta=meta)

    if rank == 0 and memcached:
        mc_off()
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()