-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
63 lines (48 loc) · 2.02 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import filetype
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import cv2 as cv
from argparse import ArgumentParser
from src.utils import is_valid_directory
from src.model import ResNet, ResNetUnet
from settings.config import Config
import warnings
warnings.filterwarnings('ignore')
def _read_rgb_batch(path):
    """Read the image at ``path`` as a float32 RGB batch of one, scaled to [0, 1].

    Args:
        path: Path of the image file to load.

    Returns:
        A ``(1, H, W, 3)`` float32 array, or ``None`` when OpenCV cannot decode
        the file (``cv.imread`` signals failure by returning ``None`` instead
        of raising).
    """
    bgr = cv.imread(path)
    if bgr is None:
        return None
    # OpenCV loads channels in BGR order; convert to RGB before feeding the models.
    rgb = cv.cvtColor(bgr, cv.COLOR_BGR2RGB).astype(np.float32)
    return rgb[np.newaxis, :, :, :] / 255.0


def main():
    """Classify and segment every image in ``--input``, writing masks to ``--output``.

    Each mask is saved under the same file name as its source image. Its pixels
    are 255 where the U-Net output is >= ``Config.threshold`` and 0 elsewhere;
    the mask is all zeros when the classifier reports no objects.

    Entries are skipped when they:
    - are not regular files,
    - are not recognised as images by ``filetype``, or
    - cannot be decoded by OpenCV.

    A mask that cannot be written is reported and the batch continues.
    """
    parser = ArgumentParser()
    parser.add_argument('--input', type=str, required=True,
                        help='path to folder with images to segment')
    parser.add_argument('--output', type=str, required=True,
                        help='folder path to segmented images')
    args = parser.parse_args()
    is_valid_directory(parser, args.input)
    is_valid_directory(parser, args.output)
    # Masks reuse the source file names, so writing them into the input folder
    # would overwrite the original images.
    if os.path.realpath(args.input) == os.path.realpath(args.output):
        parser.error('--output must be a different folder than --input')

    if tf.config.list_physical_devices('GPU'):
        print("TensorFlow has detected GPUs.")
    else:
        print("No GPUs found. TensorFlow is using CPU.")

    resnet = ResNet().build_model()
    resnet.load_weights(os.path.join(Config.model_path, 'resnet.h5'))
    unet = ResNetUnet().build_model()
    unet.load_weights(os.path.join(Config.model_path, 'unet.h5'))

    # Sorted so files are processed (and logged) in a stable order.
    for name in sorted(os.listdir(args.input)):
        src_path = os.path.join(args.input, name)
        # Skip sub-directories and other non-regular entries before sniffing
        # the file header with filetype.
        if not os.path.isfile(src_path) or not filetype.is_image(src_path):
            continue
        image = _read_rgb_batch(src_path)
        if image is None:
            print(f'{src_path}: could not be decoded, skipped.')
            continue

        classification_pred = resnet(image).numpy().flatten()[0]
        # TODO: pretrain the classification model. Until then its prediction
        # is overridden so that every image goes through the U-Net.
        classification_pred = 1
        if classification_pred:
            print(f'{src_path}: objects detected.')
            # Assumes the U-Net returns a (1, H, W, 1) map; the [0, :, :, 0]
            # indexing relies on that layout — confirm against ResNetUnet.
            scores = unet(image).numpy()[0, :, :, 0]
            mask = np.where(scores >= Config.threshold, 255, 0).astype(np.uint8)
        else:
            print(f'{src_path}: objects not detected.')
            mask = np.zeros(image.shape[1:3], dtype=np.uint8)

        out_path = os.path.join(args.output, name)
        # Write an explicit uint8 0/255 mask, and report a failed write instead
        # of aborting the remaining files.
        try:
            written = cv.imwrite(out_path, mask)
        except cv.error:  # e.g. no OpenCV encoder for this file extension
            written = False
        if not written:
            print(f'{out_path}: could not write mask.')


if __name__ == '__main__':
    main()