-
Notifications
You must be signed in to change notification settings - Fork 14
/
datasets.py
143 lines (126 loc) · 7.21 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import scipy.misc as misc
import cv2
import matplotlib.pyplot as plt
from flowlib import read_flo, read_pfm
from data_augmentation import *
from utils import imshow
class BasicDataset(object):
    """Build tf.data input pipelines that yield image pairs for optical flow.

    The data list file is a whitespace-separated text file with one sample
    per line. Columns 0 and 1 are the two image file names (relative to
    ``img_dir``); the distillation pipeline additionally reads column 2 as
    the sample name used to locate the stored flow/occlusion PNGs inside
    ``fake_flow_occ_dir``.

    Args:
        crop_h: Height of the random crop used by the training pipelines.
        crop_w: Width of the random crop used by the training pipelines.
        batch_size: Stored on the instance; the iterator factories take
            their own ``batch_size`` argument.
        data_list_file: Path of the data list text file.
        img_dir: Directory that contains the images named in the list.
        fake_flow_occ_dir: Directory that contains the ``flow_occ_fw_*.png``
            and ``flow_occ_bw_*.png`` files used for distillation.
    """

    def __init__(self, crop_h=320, crop_w=896, batch_size=4,
                 data_list_file='path_to_your_data_list_file',
                 img_dir='path_to_your_image_directory',
                 fake_flow_occ_dir='path_to_your_fake_flow_occlusion_directory'):
        self.crop_h = crop_h
        self.crop_w = crop_w
        self.batch_size = batch_size
        self.img_dir = img_dir
        # The builtin `str` replaces the `np.str` alias, which was removed
        # from NumPy. `ndmin=2` keeps a one-line list file as a (1, columns)
        # array so that `data_num` still counts samples rather than columns.
        self.data_list = np.loadtxt(data_list_file, dtype=str, ndmin=2)
        self.data_num = self.data_list.shape[0]
        self.fake_flow_occ_dir = fake_flow_occ_dir

    # KITTI's data format for storing flow and mask:
    # the first two channels are flow, the third channel is mask.
    def extract_flow_and_mask(self, flow):
        """Split a decoded flow image into a flow field and a binary mask.

        Args:
            flow: Tensor indexed as [row, column, channel] with at least
                three channels, holding the raw stored values.

        Returns:
            A tuple ``(optical_flow, mask)``. ``optical_flow`` is the first
            two channels mapped through ``(value - 32768) / 64.0``. ``mask``
            is a float32 tensor with a trailing channel axis that is 1.0
            where the third channel is greater than zero and 0.0 elsewhere.
        """
        optical_flow = (flow[:, :, :2] - 32768) / 64.0
        mask = tf.cast(tf.greater(flow[:, :, 2], 0), tf.float32)
        mask = tf.expand_dims(mask, -1)
        return optical_flow, mask

    def _read_png_pair(self, filename_queue):
        """Decode the two PNG images of one data-list row as float32 RGB."""
        images = []
        for column in (0, 1):
            name = tf.string_join([self.img_dir, '/', filename_queue[column]])
            image = tf.image.decode_png(tf.read_file(name), channels=3)
            images.append(tf.cast(image, tf.float32))
        return images[0], images[1]

    # The default image type is PNG.
    def read_and_decode(self, filename_queue):
        """Read the image pair named by one row of the data list.

        Args:
            filename_queue: String tensor holding one data-list row.

        Returns:
            ``(img1, img2)`` as float32 tensors with 3 channels; values are
            the raw decoded intensities (not yet scaled).
        """
        return self._read_png_pair(filename_queue)

    # For Flying Chairs, the image type is ppm; please use
    # "read_and_decode_ppm" instead of "read_and_decode".
    # Similarly, for other image types, please write your own decode function.
    def read_and_decode_ppm(self, filename_queue, flying_h=384, flying_w=512):
        """Read a PPM image pair through a Python-side reader.

        Args:
            filename_queue: String tensor holding one data-list row.
            flying_h: Image height used to set the static shape.
            flying_w: Image width used to set the static shape.

        Returns:
            ``(img1, img2)`` as float32 tensors of shape
            ``[flying_h, flying_w, 3]``.
        """
        def read_ppm(filename):
            # tf.py_func calls this with exactly one argument (the file
            # name), so the reader must not take an extra `self` parameter.
            return misc.imread(filename).astype('float32')

        img1_name = tf.string_join([self.img_dir, '/', filename_queue[0]])
        img2_name = tf.string_join([self.img_dir, '/', filename_queue[1]])
        img1 = tf.py_func(read_ppm, [img1_name], tf.float32)
        img2 = tf.py_func(read_ppm, [img2_name], tf.float32)
        # Pin the shape explicitly for the ops that follow.
        img1 = tf.reshape(img1, [flying_h, flying_w, 3])
        img2 = tf.reshape(img2, [flying_h, flying_w, 3])
        return img1, img2

    def _read_flow_occ(self, prefix, sample_name):
        """Decode one stored 16-bit flow/occlusion PNG into (flow, occ)."""
        path = tf.string_join([self.fake_flow_occ_dir, prefix, sample_name, '.png'])
        raw = tf.image.decode_png(tf.read_file(path), dtype=tf.uint16, channels=3)
        return self.extract_flow_and_mask(tf.cast(raw, tf.float32))

    def read_and_decode_distillation(self, filename_queue):
        """Read an image pair plus its stored forward/backward flow and masks.

        Args:
            filename_queue: String tensor holding one data-list row; column 2
                is the sample name embedded in the flow/occlusion file names.

        Returns:
            ``(img1, img2, flow_fw, flow_bw, occ_fw, occ_bw)``.
        """
        img1, img2 = self._read_png_pair(filename_queue)
        flow_fw, occ_fw = self._read_flow_occ('/flow_occ_fw_', filename_queue[2])
        flow_bw, occ_bw = self._read_flow_occ('/flow_occ_bw_', filename_queue[2])
        return img1, img2, flow_fw, flow_bw, occ_fw, occ_bw

    def augmentation(self, img1, img2):
        """Apply random crop, random flip and random channel swap to a pair.

        The augmentation helpers are expected to come from the
        ``data_augmentation`` module imported by this file.
        """
        img1, img2 = random_crop([img1, img2], self.crop_h, self.crop_w)
        img1, img2 = random_flip([img1, img2])
        img1, img2 = random_channel_swap([img1, img2])
        return img1, img2

    def augmentation_distillation(self, img1, img2, flow_fw, flow_bw, occ_fw, occ_bw):
        """Augment images, flows and occlusion masks with shared randomness.

        All six tensors go through the same crop; the flip is done by
        ``random_flip_with_flow``, which receives the flows as a separate
        list from the images and masks. The channel swap is applied to the
        two images only.
        """
        img1, img2, flow_fw, flow_bw, occ_fw, occ_bw = random_crop(
            [img1, img2, flow_fw, flow_bw, occ_fw, occ_bw], self.crop_h, self.crop_w)
        [img1, img2, occ_fw, occ_bw], [flow_fw, flow_bw] = random_flip_with_flow(
            [img1, img2, occ_fw, occ_bw], [flow_fw, flow_bw])
        img1, img2 = random_channel_swap([img1, img2])
        return img1, img2, flow_fw, flow_bw, occ_fw, occ_bw

    def preprocess_augmentation(self, filename_queue):
        """Load a PNG pair, scale it by 1/255 and augment it."""
        img1, img2 = self.read_and_decode(filename_queue)
        img1 = img1 / 255.
        img2 = img2 / 255.
        return self.augmentation(img1, img2)

    def preprocess_augmentation_distillation(self, filename_queue):
        """Load a distillation sample, scale the images by 1/255 and augment."""
        img1, img2, flow_fw, flow_bw, occ_fw, occ_bw = self.read_and_decode_distillation(filename_queue)
        img1 = img1 / 255.
        img2 = img2 / 255.
        return self.augmentation_distillation(img1, img2, flow_fw, flow_bw, occ_fw, occ_bw)

    def preprocess_one_shot(self, filename_queue):
        """Load a PNG pair and scale it by 1/255, without augmentation."""
        img1, img2 = self.read_and_decode(filename_queue)
        img1 = img1 / 255.
        img2 = img2 / 255.
        return img1, img2

    def _build_iterator(self, data_list, map_fn, batch_size, shuffle=False,
                        buffer_size=5000, num_parallel_calls=4):
        """Assemble the shared slice -> shuffle -> map -> batch -> repeat pipeline."""
        data_list = tf.convert_to_tensor(data_list, dtype=tf.string)
        dataset = tf.data.Dataset.from_tensor_slices(data_list)
        if shuffle:
            # Shuffle the file-name rows before decoding so the shuffle
            # buffer holds short strings instead of up to `buffer_size`
            # decoded float32 images.
            dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_calls)
        dataset = dataset.batch(batch_size)
        dataset = dataset.repeat()
        return dataset.make_initializable_iterator()

    def create_batch_iterator(self, data_list, batch_size, shuffle=True, buffer_size=5000, num_parallel_calls=4):
        """Create a repeating, batched training iterator over augmented pairs.

        Args:
            data_list: Array of data-list rows (for example ``self.data_list``).
            batch_size: Number of samples per batch.
            shuffle: Whether to shuffle the samples.
            buffer_size: Size of the shuffle buffer.
            num_parallel_calls: Parallelism of the preprocessing map.

        Returns:
            An initializable iterator yielding ``(img1, img2)`` batches.
        """
        return self._build_iterator(
            data_list, self.preprocess_augmentation, batch_size,
            shuffle=shuffle, buffer_size=buffer_size,
            num_parallel_calls=num_parallel_calls)

    def create_batch_distillation_iterator(self, data_list, batch_size, shuffle=True, buffer_size=5000, num_parallel_calls=4):
        """Create a repeating, batched training iterator for distillation.

        Takes the same arguments as ``create_batch_iterator``.

        Returns:
            An initializable iterator yielding
            ``(img1, img2, flow_fw, flow_bw, occ_fw, occ_bw)`` batches.
        """
        return self._build_iterator(
            data_list, self.preprocess_augmentation_distillation, batch_size,
            shuffle=shuffle, buffer_size=buffer_size,
            num_parallel_calls=num_parallel_calls)

    def create_one_shot_iterator(self, data_list, num_parallel_calls=4):
        """ For Validation or Testing
            Generate image and flow one_by_one without cropping, image and flow size may change every iteration
        """
        return self._build_iterator(
            data_list, self.preprocess_one_shot, 1,
            shuffle=False, num_parallel_calls=num_parallel_calls)