forked from echonax07/MMSeaIce
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer.py
70 lines (58 loc) · 2.28 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import json
import random
import os
import os.path as osp
import shutil
from icecream import ic
import pathlib
import warnings
import numpy as np
import torch
from tqdm import tqdm # Progress bar
from mmcv import Config
from functions import create_train_validation_and_test_scene_list, get_model, load_model
from loaders import get_variable_options, AI4ArcticChallengeDataset, AI4ArcticChallengeTestDataset
from utils import colour_str
def parse_args(argv=None):
    """Parse command-line arguments for the inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, preserving the original behavior.

    Returns:
        argparse.Namespace with attributes:
            config   -- pathlib.Path of the mmcv config file (positional, required).
            cnn_path -- pathlib.Path of the trained model checkpoint, or None.
            out_dir  -- pathlib.Path of the output directory, or None.
    """
    parser = argparse.ArgumentParser(description='Run inference with a trained U-NET segmentor')
    # Mandatory arguments
    parser.add_argument('config', type=pathlib.Path, help='train config file path')
    # Optional arguments
    parser.add_argument('--cnn-path', type=pathlib.Path, default=None, help='trained CNN path')
    parser.add_argument('--out-dir', type=pathlib.Path, default=None,
                        help='output directory for inference results')
    return parser.parse_args(argv)
def _select_device(train_options):
    """Return the torch.device to run on.

    Uses ``cuda:<train_options['gpu_id']>`` when CUDA is available and falls
    back to the CPU otherwise. When ``train_options['compile_model']`` is set,
    warns if the selected GPU is not one of the architectures (V100/A100/H100)
    for which torch.compile speedups are expected.
    """
    if not torch.cuda.is_available():
        # Without this fallback the original code had no usable device on CPU-only hosts.
        print(colour_str('GPU not available.', 'red'))
        return torch.device('cpu')

    print(colour_str('GPU available!', 'green'))
    print('Total number of available devices: ',
          colour_str(torch.cuda.device_count(), 'orange'))

    # Setup device to be used
    device = torch.device(f"cuda:{train_options['gpu_id']}")

    # Check if NVIDIA V100, A100, or H100 is available for torch compile speed up.
    # Query the GPU that will actually be used rather than the current default device.
    if train_options['compile_model']:
        if torch.cuda.get_device_capability(device) not in ((7, 0), (8, 0), (9, 0)):
            warnings.warn(
                colour_str("GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.", 'red')
            )
    return device


def main():
    """Build the model described by a config file, load trained weights, and print it.

    Steps:
      1. Parse CLI args and read ``train_options`` from the mmcv config file.
      2. Derive variable options (``get_variable_options``) and build the
         train/validation/test scene lists.
      3. Select a device (CUDA ``gpu_id`` if available, else CPU).
      4. Build the network with ``get_model``, optionally ``torch.compile`` it,
         call ``load_model`` with ``--cnn-path``, and print the network.
    """
    args = parse_args()
    ic(args.config)
    cfg = Config.fromfile(args.config)
    train_options = cfg.train_options
    # Get options for variables, amsrenv grid, cropping and upsampling.
    train_options = get_variable_options(train_options)
    create_train_validation_and_test_scene_list(train_options)

    device = _select_device(train_options)

    net = get_model(train_options, device)
    if train_options['compile_model']:
        # NOTE(review): weights are loaded after compiling, so the checkpoint's
        # state_dict keys must match the compiled wrapper — confirm against how
        # the checkpoint was saved.
        net = torch.compile(net)
    # NOTE(review): args.out_dir is parsed but not used by this script yet.
    load_model(net, args.cnn_path)
    print(net)
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()