-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
39 lines (33 loc) · 1.11 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
from modeltrain import modeltrain
from modelbuild import modelbuild
import os
import argparse
import deepspeed
def add_argument(argv=None):
    """Build this script's command-line parser and parse the arguments.

    The parser accepts ``--local_rank`` plus whatever options
    ``deepspeed.add_config_arguments`` registers on it.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, preserving the
            previous command-line behaviour while allowing programmatic use.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='CFD-CNN')
    # Supplied by the distributed launcher; -1 is used when it is not passed.
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    # Let DeepSpeed add its own config options; it returns the same parser.
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args(argv)
    return args
def main():
    """Run inference with a previously trained model.

    Resolves the model directory ``<this file's directory>/Model/<modelname>``,
    builds its data via ``modelbuild`` in ``"inference"`` mode, wraps it in
    ``modeltrain`` and calls ``muti_inference`` on the result. Command-line
    arguments are parsed by ``add_argument`` and forwarded to ``modelbuild``.
    """
    # Name of the model directory under ./Model.
    modelname = 'MSTA_1001_10to1'
    mode = "inference"
    # Sample indices to run inference on. A wider alternative previously left
    # commented out here was range(-50, -10). Negative values presumably index
    # from the end of the dataset -- TODO confirm against modeltrain.
    infer_num = [-10]
    # Forwarded to modeltrain; presumably the number of inference steps --
    # confirm in modeltrain.
    infer_step = 10
    # Passed positionally to muti_inference; their meaning is defined in
    # modeltrain (not visible here).
    min_max_delt = None
    mean = False

    # Resolve the model directory relative to this file, not the CWD.
    dir_path = os.path.dirname(os.path.abspath(__file__))
    ds_args = add_argument()
    model_path = os.path.join(dir_path, 'Model', modelname)

    total_data = modelbuild(model_path, ds_args, mode)
    model_data = total_data.get_data()
    model = modeltrain(model_data, model_path, mode,
                       infer_num=infer_num, infer_step=infer_step)
    model.muti_inference(min_max_delt, mean)
if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()