-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathmain_seg.py
191 lines (149 loc) · 6.06 KB
/
main_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# ------------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License.
# To view a copy of this license, visit
# https://github.com/NVlabs/GroupViT/blob/main/LICENSE
#
# Written by Jiarui Xu
# ------------------------------------------------------------------------------
import argparse
import os
import os.path as osp
import mmcv
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from datasets import build_text_transform
from main_group_vit import validate_seg
from mmcv.image import tensor2imgs
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import set_random_seed
from models import build_model
from omegaconf import OmegaConf, read_write
from segmentation.evaluation import build_seg_dataloader, build_seg_dataset, build_seg_inference
from utils import get_config, get_logger, load_checkpoint
# apex is an optional dependency: when it cannot be imported, ``amp`` is left
# as None and main() refuses any amp opt level other than 'O0'.
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
def parse_args():
    """Parse the command-line arguments of the evaluation/visualization script.

    The local rank can be supplied in three ways, so the script works with
    every PyTorch launcher:

    * ``--local_rank N``  (legacy ``torch.distributed.launch``),
    * ``--local-rank N``  (spelling passed by the launcher from PyTorch 2.0 on),
    * the ``LOCAL_RANK`` environment variable (``torchrun``), used when
      neither flag is given; 0 is used if the variable is unset as well.

    Returns:
        argparse.Namespace: with attributes ``cfg``, ``opts``, ``resume``,
        ``output``, ``tag``, ``vis`` and ``local_rank``.
    """
    parser = argparse.ArgumentParser('GroupViT segmentation evaluation and visualization')
    parser.add_argument(
        '--cfg',
        type=str,
        required=True,
        help='path to config file',
    )
    parser.add_argument(
        '--opts',
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument(
        '--output', type=str, help='root of output folder, '
        'the full path is <output>/<model_name>/<tag>')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument(
        '--vis',
        help='Specify the visualization mode, '
        'could be a list, support input, pred, input_seg, input_pred_seg_label, all_groups, first_group, last_group',
        default=None,
        nargs='+')

    # distributed training
    # Both spellings map onto ``args.local_rank``; the default is taken from
    # the environment so that launchers which pass no flag still work.
    parser.add_argument(
        '--local_rank',
        '--local-rank',
        dest='local_rank',
        type=int,
        default=int(os.environ.get('LOCAL_RANK', 0)),
        help='local rank for DistributedDataParallel')

    args = parser.parse_args()

    return args
def inference(cfg):
    """Build the segmentation data and model described by ``cfg`` and evaluate it.

    The model is moved to the GPU, wrapped by apex amp when
    ``cfg.train.amp_opt_level`` is anything other than ``'O0'``, and then
    passed to ``load_checkpoint``. mIoU is computed and logged when ``'seg'``
    is contained in ``cfg.evaluate.task``; visualizations are produced when
    ``cfg.vis`` is truthy.

    Args:
        cfg: Experiment config. Reads ``evaluate.seg``, ``evaluate.task``,
            ``model``, ``model_name``, ``train.amp_opt_level`` and ``vis``.
    """
    log = get_logger()

    seg_cfg = cfg.evaluate.seg
    data_loader = build_seg_dataloader(build_seg_dataset(seg_cfg))
    dataset = data_loader.dataset
    log.info(f'Evaluating dataset: {dataset}')

    log.info(f'Creating model:{cfg.model.type}/{cfg.model_name}')
    net = build_model(cfg.model)
    net.cuda()
    log.info(str(net))

    # 'O0' means no mixed precision, so the amp wrapper is skipped.
    opt_level = cfg.train.amp_opt_level
    if opt_level != 'O0':
        net = amp.initialize(net, None, opt_level=opt_level)

    # Count trainable parameters only.
    n_parameters = 0
    for param in net.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    log.info(f'number of params: {n_parameters}')

    # No optimizer / scheduler is passed: this is evaluation only.
    load_checkpoint(cfg, net, None, None)

    if 'seg' not in cfg.evaluate.task:
        log.info('No segmentation evaluation specified')
    else:
        miou = validate_seg(cfg, data_loader, net)
        log.info(f'mIoU of the network on the {len(data_loader.dataset)} test images: {miou:.2f}%')

    vis_modes = cfg.vis
    if vis_modes:
        vis_seg(cfg, data_loader, net, vis_modes)
@torch.no_grad()
def vis_seg(config, data_loader, model, vis_modes):
    """Write segmentation visualizations for the images yielded by ``data_loader``.

    ``model`` is turned into a segmentation-inference model, wrapped in
    ``MMDistributedDataParallel`` and run over the loader. For every image and
    every requested mode, ``show_result`` is called with the output path
    ``<config.output>/vis_imgs/<vis_mode>/<index:04d>.jpg``, where ``index``
    comes from the loader's batch sampler.

    Requires an initialized ``torch.distributed`` process group (it calls
    ``dist.barrier`` and ``dist.get_rank``).

    Args:
        config: Config object; ``data.text_aug``, ``evaluate.seg`` and
            ``output`` are read here.
        data_loader: Loader whose ``dataset`` and ``batch_sampler`` are used.
        model: Model to visualize; if it exposes a ``module`` attribute the
            wrapped module is used instead.
        vis_modes: Iterable of mode names, each forwarded to ``show_result``.
    """
    # Make sure every rank has arrived before starting.
    dist.barrier()
    model.eval()

    # Unwrap a parallel wrapper, if any, to get at the bare model.
    if hasattr(model, 'module'):
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    text_transform = build_text_transform(False, config.data.text_aug, with_dc=False)
    seg_model = build_seg_inference(model_without_ddp, data_loader.dataset, text_transform, config.evaluate.seg)

    mmddp_model = MMDistributedDataParallel(
        seg_model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
    mmddp_model.eval()
    # From here on ``model`` is the unwrapped segmentation-inference model;
    # it is only used for ``show_result`` and to find the target device.
    model = mmddp_model.module
    device = next(model.parameters()).device
    dataset = data_loader.dataset

    # Only rank 0 owns (and later updates) the progress bar.
    if dist.get_rank() == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    # The batch sampler is iterated in lockstep with the loader so that every
    # batch is paired with the indices it was built from.
    # NOTE(review): assumes the sampler yields the same order on both
    # iterations (i.e. no shuffling) - confirm against the loader builder.
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        # Redundant with the decorator above, but harmless.
        with torch.no_grad():
            result = mmddp_model(return_loss=False, **data)

        img_tensor = data['img'][0]
        img_metas = data['img_metas'][0].data[0]
        # Undo the input normalization to get displayable images back.
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
        assert len(imgs) == len(img_metas)

        for batch_idx, img, img_meta in zip(batch_indices, imgs, img_metas):
            # Keep the top-left ``img_shape`` region
            # (presumably this strips batch padding - TODO confirm).
            h, w, _ = img_meta['img_shape']
            img_show = img[:h, :w, :]

            # Resize back to the original resolution; size is given as (w, h).
            ori_h, ori_w = img_meta['ori_shape'][:-1]
            img_show = mmcv.imresize(img_show, (ori_w, ori_h))

            for vis_mode in vis_modes:
                out_file = osp.join(config.output, 'vis_imgs', vis_mode, f'{batch_idx:04d}.jpg')
                model.show_result(img_show, img_tensor.to(device), result, out_file, vis_mode)

        # Advance the bar by this batch's results multiplied by the number of
        # ranks, since each rank handles its own batch.
        if dist.get_rank() == 0:
            batch_size = len(result) * dist.get_world_size()
            for _ in range(batch_size):
                prog_bar.update()
def main():
    """Script entry point: set up distributed evaluation and run ``inference``.

    Parses the arguments, builds the config and forces eval-only mode, binds
    the process to its GPU, initializes the NCCL process group from the
    environment, seeds the RNGs, creates the output directory, lets rank 0
    dump the full config, and finally runs ``inference``.

    Raises:
        RuntimeError: if ``cfg.train.amp_opt_level`` is not ``'O0'`` but apex
            ``amp`` could not be imported.
    """
    args = parse_args()
    cfg = get_config(args)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # otherwise the failure would only surface later as an AttributeError on
    # ``amp.initialize`` inside inference().
    if cfg.train.amp_opt_level != 'O0' and amp is None:
        raise RuntimeError('amp not installed!')

    with read_write(cfg):
        cfg.evaluate.eval_only = True

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
    else:
        # -1 is the "unspecified" value of init_process_group.
        rank = -1
        world_size = -1
    torch.cuda.set_device(cfg.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)

    # Wait until every rank has joined the group.
    dist.barrier()

    set_random_seed(cfg.seed, use_rank_shift=True)
    cudnn.benchmark = True

    os.makedirs(cfg.output, exist_ok=True)
    logger = get_logger(cfg)

    # Only rank 0 writes the config, to avoid concurrent writes to one file.
    if dist.get_rank() == 0:
        # NOTE(review): OmegaConf.save emits YAML, so this file holds YAML
        # despite the '.json' extension; the name is kept for compatibility.
        path = os.path.join(cfg.output, 'config.json')
        OmegaConf.save(cfg, path)
        logger.info(f'Full config saved to {path}')

    # print config
    logger.info(OmegaConf.to_yaml(cfg))

    inference(cfg)
    # Keep all ranks alive until everyone has finished.
    dist.barrier()
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()