test_cpu_cuda.py
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa: E722
import argparse
import mmcv
import numpy as np
import os
import os.path as osp
import time
import torch
import torch.distributed as dist
from mmcv import Config
from mmcv import digit_version as dv
from mmcv import load
from mmcv.cnn import fuse_conv_bn
from mmcv.engine import single_gpu_test
from mmcv.fileio.io import file_handlers
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from pyskl.datasets import build_dataloader, build_dataset
from pyskl.models import build_model
from pyskl.utils import cache_checkpoint, mc_off, mc_on, test_port

def parse_args():
parser = argparse.ArgumentParser(
description='pyskl test (and eval) a model')
parser.add_argument('config', help='test config file path')
parser.add_argument('-C', '--checkpoint', help='checkpoint file', default=None)
parser.add_argument(
'--out',
default=None,
help='output result file in pkl/yaml/json format')
parser.add_argument(
'--fuse-conv-bn',
action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
parser.add_argument(
'--eval',
type=str,
nargs='+',
default=['top_k_accuracy', 'mean_class_accuracy'],
        help='evaluation metrics, which depend on the dataset, e.g.,'
        ' "top_k_accuracy", "mean_class_accuracy" for video datasets')
parser.add_argument(
'--tmpdir',
help='tmp directory used for collecting results from multiple workers')
parser.add_argument(
'--average-clips',
choices=['score', 'prob', None],
default=None,
help='average type when averaging test clips')
parser.add_argument(
'--launcher',
choices=['pytorch', 'slurm'],
default='pytorch',
help='job launcher')
parser.add_argument(
'--compile',
action='store_true',
        help='whether to compile the model before testing (requires PyTorch >= 2.0)')
    parser.add_argument('--local_rank', '--local-rank', type=int, default=-1)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
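

# A hypothetical invocation (the paths below are placeholders, not files shipped
# with this fork): evaluate a checkpoint and write the raw scores to a pickle file.
#
#   python test_cpu_cuda.py configs/stgcn/j.py -C work_dirs/stgcn/latest.pth \
#       --eval top_k_accuracy mean_class_accuracy --out work_dirs/stgcn/result.pkl
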
def move_data_to_device(batch_data, device):
if isinstance(batch_data, torch.Tensor):
return batch_data.to(device)
elif isinstance(batch_data, list):
return [move_data_to_device(item, device) for item in batch_data]
elif isinstance(batch_data, dict):
return {key: move_data_to_device(val, device) for key, val in batch_data.items()}
else:
return batch_data
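

# Minimal usage sketch for the helper above (the tensors are made-up examples):
#   batch = {'keypoint': torch.zeros(1, 2, 48, 17, 2), 'label': torch.tensor([0])}
#   batch = move_data_to_device(batch, torch.device('cuda:0'))
# Tensors nested in dicts or lists are moved to the target device; any other
# values (strings, ints, metadata) are returned unchanged.
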
def inference_pytorch(args, cfg, data_loader):
"""Get predictions by pytorch models."""
if args.average_clips is not None:
        # You can set average_clips during testing; it will override the
        # original setting in the config.
if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
cfg.model.setdefault('test_cfg',
dict(average_clips=args.average_clips))
else:
if cfg.model.get('test_cfg') is not None:
cfg.model.test_cfg.average_clips = args.average_clips
else:
cfg.test_cfg.average_clips = args.average_clips
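    # Example of the override (assumed config, not taken from this repo): with
    #   model = dict(..., test_cfg=dict(average_clips='score'))
    # in the config, passing `--average-clips prob` leaves
    # cfg.model.test_cfg.average_clips == 'prob'.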
# Build the model and load checkpoint
model = build_model(cfg.model)
if dv(torch.__version__) >= dv('2.0.0') and args.compile:
model = torch.compile(model)
if args.checkpoint is None:
work_dir = cfg.work_dir
args.checkpoint = osp.join(work_dir, 'latest.pth')
assert osp.exists(args.checkpoint)
args.checkpoint = cache_checkpoint(args.checkpoint)
load_checkpoint(model, args.checkpoint, map_location='cpu')
if args.fuse_conv_bn:
model = fuse_conv_bn(model)
# Use CUDA if available, otherwise run on CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
    if torch.cuda.is_available():
        # Drive the test loop manually on the GPU.
        outputs = []  # per-batch prediction arrays
        for batch in data_loader:
            batch = move_data_to_device(batch, device)  # move inputs to the GPU
            with torch.no_grad():
                output = model(return_loss=False, **batch)  # per-clip scores / logits
            outputs.append(output)
        # Concatenate the per-batch outputs and return them as a list of
        # per-sample score arrays, the format expected by dataset.dump_results
        # and dataset.evaluate.
        outputs = np.concatenate(outputs, axis=0)
        # y_pred_classes = np.argmax(outputs, axis=1)  # (optional) predicted class ids
        return list(outputs)
    else:
        # CPU fallback: let mmcv's single_gpu_test drive the loop.
        print('Running on CPU since CUDA is not available.')
        return single_gpu_test(model, data_loader)
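

# Sketch (assumption: each element of the returned list is a 1-D score / logit
# vector of length num_classes). Predicted class indices could then be recovered
# with, e.g.:
#   y_pred = [int(np.argmax(score)) for score in outputs]
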
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
out = osp.join(cfg.work_dir, 'result.pkl') if args.out is None else args.out
# Load eval_config from cfg
eval_cfg = cfg.get('evaluation', {})
keys = ['interval', 'tmpdir', 'start', 'save_best', 'rule', 'by_epoch', 'broadcast_bn_buffers']
for key in keys:
eval_cfg.pop(key, None)
if args.eval:
eval_cfg['metrics'] = args.eval
mmcv.mkdir_or_exist(osp.dirname(out))
_, suffix = osp.splitext(out)
assert suffix[1:] in file_handlers, ('The format of the output file should be json, pickle or yaml')
# set cudnn benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.data.test.test_mode = True
if not hasattr(cfg, 'dist_params'):
cfg.dist_params = dict(backend='nccl')
#init_dist(args.launcher, **cfg.dist_params)
rank, world_size = get_dist_info()
cfg.gpu_ids = []
#cfg.gpu_ids = range(world_size)
# build the dataloader
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
dataloader_setting = dict(
videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
shuffle=False)
dataloader_setting = dict(dataloader_setting, **cfg.data.get('test_dataloader', {}))
data_loader = build_dataloader(dataset, **dataloader_setting)
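    # Example (assumed config snippet, not part of this repo):
    #   data = dict(..., test_dataloader=dict(videos_per_gpu=8))
    # would override the default of 1 clip per GPU at test time.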
default_mc_cfg = ('localhost', 22077)
memcached = cfg.get('memcached', False)
if rank == 0 and memcached:
# mc_list is a list of pickle files you want to cache in memory.
# Basically, each pickle file is a dictionary.
mc_cfg = cfg.get('mc_cfg', default_mc_cfg)
assert isinstance(mc_cfg, tuple) and mc_cfg[0] == 'localhost'
if not test_port(mc_cfg[0], mc_cfg[1]):
mc_on(port=mc_cfg[1], launcher=args.launcher)
retry = 3
while not test_port(mc_cfg[0], mc_cfg[1]) and retry > 0:
time.sleep(5)
retry -= 1
        assert test_port(mc_cfg[0], mc_cfg[1]), 'Failed to launch memcached.'
#dist.barrier()
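    # Sketch of the memcached knobs read by the block above (values are assumptions):
    #   memcached = True
    #   mc_cfg = ('localhost', 22077)
    # With these set, a local memcached server is started on the given port if
    # none is already listening there.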
outputs = inference_pytorch(args, cfg, data_loader)
# Save outputs
rank, _ = get_dist_info()
if rank == 0:
print(f'\nwriting results to {out}')
dataset.dump_results(outputs, out=out)
if eval_cfg:
eval_res = dataset.evaluate(outputs, **eval_cfg)
for name, val in eval_res.items():
print(f'{name}: {val:.04f}')
#dist.barrier()
if rank == 0 and memcached:
mc_off()


if __name__ == '__main__':
main()