-
Notifications
You must be signed in to change notification settings - Fork 0
/
api_controller.py
88 lines (74 loc) · 2.21 KB
/
api_controller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
TensorflowKeras_DeepLearning_Samples WEB API
run using: uvicorn api_controller:app --reload
"""
import os
import sys
from contextlib import redirect_stdout
from io import StringIO

from fastapi import FastAPI, HTTPException

from app import features
from lib import neural_networks
# Shared VGG16 CNN model used by every request handler.
# None until api_init() loads it (api_init() is called at module import time).
model = None
def api_init():
    """
    Prepare the API's shared state.

    Calls ``features.initialize()`` and then stores the VGG16 network returned
    by ``neural_networks.get_vgg16()`` in the module-level ``model``, logging a
    progress line after each step.
    """
    global model
    features.initialize()
    print("> api initialized")
    vgg16_network = neural_networks.get_vgg16()
    model = vgg16_network
    print("> model initialized")
# FastAPI application instance; served with `uvicorn api_controller:app`
app = FastAPI()
# Load the shared state (feature setup and the VGG16 model) once, at import
# time, so the request handlers below all reuse the same module-level `model`.
api_init()
@app.get("/classify/best/{image_filename}")
def classify_best(image_filename: str):
    """
    Classify an image stored in the ``data/`` directory; return the top prediction.

    wget http://127.0.0.1:8000/classify/best/CNN-VGG-mug.jpg
    wget http://127.0.0.1:8000/classify/best/talbot-samba-red.jpeg

    :param image_filename: plain file name of an image inside ``data/``
        (taken from the URL; it is always resolved relative to ``data/``)
    :return: dict with the predicted ``type`` and its ``probability``
        (stringified), taken from the first entry of ``predict.label[0]``
    :raises HTTPException: 404 when the name is not a bare file name or no
        such file exists under ``data/``
    """
    image_path = "data/" + image_filename
    print(image_path)
    # The name comes straight from the URL. Reject anything that is not a bare
    # file name (e.g. "..\\secret.jpg", which Windows treats as a path outside
    # data/) and report a missing file as 404 instead of letting the
    # classifier fail with an unhandled error.
    if (os.path.basename(image_filename) != image_filename
            or not os.path.isfile(image_path)):
        raise HTTPException(status_code=404,
                            detail="image not found: " + image_filename)
    predict = features.classify_image_using_model_vgg16_cnn(model, image_path)
    # assumes predict.label[0] holds (class_id, label, score) entries ordered
    # best-first, as Keras decode_predictions returns them — TODO confirm
    best = predict.label[0][0]
    return {
        "type": best[1],
        "probability": str(best[2])
    }
@app.get("/classify/top-five/{image_filename}")
def classify_all(image_filename: str):
    """
    Classify an image stored in the ``data/`` directory; return every ranked prediction.

    wget http://127.0.0.1:8000/classify/top-five/CNN-VGG-mug.jpg
    wget http://127.0.0.1:8000/classify/top-five/talbot-samba-red.jpeg

    :param image_filename: plain file name of an image inside ``data/``
        (taken from the URL; it is always resolved relative to ``data/``)
    :return: list of dicts, one per entry of ``predict.label[0]``, each with
        the predicted ``type`` and its ``probability`` (stringified).
        Presumably five entries, per the route name — confirm against
        ``features.classify_image_using_model_vgg16_cnn``.
    :raises HTTPException: 404 when the name is not a bare file name or no
        such file exists under ``data/``
    """
    image_path = "data/" + image_filename
    print(image_path)
    # The name comes straight from the URL. Reject anything that is not a bare
    # file name (e.g. "..\\secret.jpg", which Windows treats as a path outside
    # data/) and report a missing file as 404 instead of letting the
    # classifier fail with an unhandled error.
    if (os.path.basename(image_filename) != image_filename
            or not os.path.isfile(image_path)):
        raise HTTPException(status_code=404,
                            detail="image not found: " + image_filename)
    predict = features.classify_image_using_model_vgg16_cnn(model, image_path)
    # assumes each entry is (class_id, label, score) — TODO confirm
    return [
        {"type": entry[1], "probability": str(entry[2])}
        for entry in predict.label[0]
    ]
@app.get("/")
def get_info():
    """
    Describe the API and the loaded model.

    wget http://127.0.0.1:8000/

    :return: dict with the API name and the text that ``model.summary()``
        printed to stdout
    """
    captured_stdout = StringIO()
    # redirect_stdout restores sys.stdout even if summary() raises; a manual
    # save/restore would leave stdout pointing at this buffer on error,
    # silently swallowing every later print in the process.
    # NOTE(review): sys.stdout is process-global, so prints from concurrently
    # running handlers can land in this buffer. If `model` is a Keras model,
    # summary(print_fn=...) avoids the global redirect — confirm before
    # switching, as it may format the output differently.
    with redirect_stdout(captured_stdout):
        model.summary()
    return {
        "api": "TensorflowKeras_DeepLearning_Samples",
        "model": captured_stdout.getvalue()
    }