-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathinference.py
122 lines (88 loc) · 3.29 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
sys.path.append('core')
from PIL import Image
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from core.utils.misc import process_cfg
from utils import flow_viz
from core.Networks import build_network
from utils import frame_utils
from utils.utils import InputPadder, forward_interpolate
import itertools
import imageio
def prepare_image(seq_dir):
    """Load every frame in ``seq_dir`` (sorted by filename) into one tensor.

    Each image is reduced to its first three channels (alpha dropped) and
    converted to a float CHW tensor.  Returns a tensor of shape (T, 3, H, W).
    """
    print(f"preparing image...")
    print(f"Input image sequence dir = {seq_dir}")
    frames = []
    for name in sorted(os.listdir(seq_dir)):
        pixels = np.array(Image.open(os.path.join(seq_dir, name))).astype(np.uint8)[..., :3]
        frames.append(torch.from_numpy(pixels).permute(2, 0, 1).float())
    return torch.stack(frames)
def vis_pre(flow_pre, vis_dir):
    """Render predicted flows as color PNGs under ``vis_dir``.

    The first half of ``flow_pre`` is saved as forward flows
    (frame i+2 -> i+3), the second half as backward flows
    (frame i+2 -> i+1), matching the model's paired output layout.
    """
    os.makedirs(vis_dir, exist_ok=True)

    total = flow_pre.shape[0]
    half = total // 2

    # Forward flows.
    for i in range(half):
        rendered = flow_viz.flow_to_image(flow_pre[i].permute(1, 2, 0).numpy())
        Image.fromarray(rendered).save('{}/flow_{:04}_to_{:04}.png'.format(vis_dir, i + 2, i + 3))

    # Backward flows.
    for i in range(half, total):
        rendered = flow_viz.flow_to_image(flow_pre[i].permute(1, 2, 0).numpy())
        Image.fromarray(rendered).save('{}/flow_{:04}_to_{:04}.png'.format(vis_dir, i - half + 2, i - half + 1))
@torch.no_grad()
def MOF_inference(model, cfg):
    """Run multi-frame optical-flow inference over the sequence in ``cfg.seq_dir``.

    Returns the unpadded flow predictions as a CPU tensor.
    """
    model.eval()
    batch = prepare_image(cfg.seq_dir)[None].cuda()
    padder = InputPadder(batch.shape)
    padded = padder.pad(batch)
    flows, _ = model(padded, {})
    return padder.unpad(flows[0]).cpu()
@torch.no_grad()
def BOF_inference(model, cfg):
    """Run bidirectional (two-frame) optical-flow inference over ``cfg.seq_dir``.

    Returns the unpadded flow predictions as a CPU tensor.
    """
    model.eval()
    batch = prepare_image(cfg.seq_dir)[None].cuda()
    padder = InputPadder(batch.shape)
    padded = padder.pad(batch)
    flows, _ = model(padded, {})
    return padder.unpad(flows[0]).cpu()
def count_parameters(model):
    """Return the total number of trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='MOF')
    parser.add_argument('--seq_dir', default='default')
    parser.add_argument('--vis_dir', default='default')
    args = parser.parse_args()

    # Select the config for the requested mode.  Fail fast on an unknown
    # mode: the original fell through both branches and later crashed with
    # a NameError on get_cfg / an undefined flow_pre.
    if args.mode == 'MOF':
        from configs.multiframes_sintel_submission import get_cfg
    elif args.mode == 'BOF':
        from configs.sintel_submission import get_cfg
    else:
        raise ValueError("unknown --mode %r; expected 'MOF' or 'BOF'" % args.mode)

    cfg = get_cfg()
    cfg.update(vars(args))

    model = torch.nn.DataParallel(build_network(cfg))
    # Load on CPU first so checkpoints saved on another device restore
    # cleanly; the model is moved to CUDA explicitly below.
    model.load_state_dict(torch.load(cfg.model, map_location='cpu'))
    model.cuda()
    model.eval()

    print(cfg.model)
    print("Parameter Count: %d" % count_parameters(model))

    with torch.no_grad():
        if args.mode == 'MOF':
            flow_pre = MOF_inference(model.module, cfg)
        else:  # 'BOF' — any other value was rejected above
            flow_pre = BOF_inference(model.module, cfg)
    vis_pre(flow_pre, cfg.vis_dir)