-
Notifications
You must be signed in to change notification settings - Fork 2
/
demo.py
84 lines (72 loc) · 2.74 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# coding=utf-8
import os
import sys
import importlib
import time
from flask import request, send_from_directory
from flask import Flask, request, redirect, url_for
import uuid
import tensorflow as tf
from classify_image import run_inference_on_image
# File extensions (the text after the last '.' of the uploaded name) that the
# upload handler accepts. Defined once; both letter cases are listed for every
# format so an exact-match lookup treats '.PNG' like '.JPG'.
ALLOWED_EXTENSIONS = {'jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG'}

# Command-line flags registered through TensorFlow's flags module.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('upload_folder', '/tmp/', '')
tf.app.flags.DEFINE_integer('num_top_predictions', 5,
                            """Display this many predictions.""")
# The default is a real integer (it used to be the string '5001'), and the
# help text now states the default that is actually used.
tf.app.flags.DEFINE_integer('port', 5001,
                            'Port the server listens on (default: 5001).')
tf.app.flags.DEFINE_boolean('debug', False, '')

# NOTE(review): the flag value is read at import time; this presumably relies
# on tf.app.flags parsing sys.argv on first attribute access - confirm for the
# TensorFlow version in use.
UPLOAD_FOLDER = FLAGS.upload_folder

app = Flask(__name__)
# Point the app's static folder at the upload folder; inference() builds
# '/static/<name>' URLs for the saved uploads.
# NOTE(review): this assigns a private Flask attribute - confirm it is still
# honoured by the installed Flask version.
app._static_folder = UPLOAD_FOLDER
def allowed_files(filename):
    """Return True when *filename* carries an accepted image extension.

    The extension is the text after the last '.' in the name. It is looked up
    in the module-level ALLOWED_EXTENSIONS both as given and lower-cased, so
    the check is case-insensitive ('photo.PNG' and 'photo.Jpg' are accepted
    whenever 'png' / 'jpg' are listed).

    Args:
        filename: Name of the uploaded file; may be empty or None.

    Returns:
        bool: False for empty/None names, names without a '.', or names whose
        extension is not listed; True otherwise.
    """
    # Guard first: an upload without a name would otherwise raise a TypeError
    # on the membership test below.
    if not filename or '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    # Keep the exact-match lookup so every name accepted before is still
    # accepted, then fall back to the lower-cased form.
    return (extension in ALLOWED_EXTENSIONS
            or extension.lower() in ALLOWED_EXTENSIONS)
def rename_filename(old_file_name):
    """Build a fresh, unique file name that keeps the original extension.

    Args:
        old_file_name: Original file name or path; only the extension of its
            final path component is reused.

    Returns:
        str: A uuid1 string followed by the original extension (including the
        leading dot), or just the uuid1 string when there is no extension.
    """
    # Only the extension survives; any directory part and the stem are dropped.
    extension = os.path.splitext(os.path.basename(old_file_name))[1]
    return '%s%s' % (uuid.uuid1(), extension)
def inference(file_name):
    """Classify the image at *file_name* and render the result as HTML.

    Calls run_inference_on_image() and prints the raw predictions. If that
    call raises, the exception is printed and an empty string is returned.

    Args:
        file_name: Path of the saved image; its base name is used to build
            the '/static/<name>' URL of the preview image.

    Returns:
        str: An <img> tag for the image followed by one
        '<id> (score:<value>)<BR>' line per entry of the returned top-k ids,
        or "" when inference failed.
    """
    try:
        predictions, top_k, top_names = run_inference_on_image(file_name)
        print(predictions)
    except Exception as ex:
        # Best effort: report the failure and show no result section.
        print(ex)
        return ""

    image_url = '/static/%s' % os.path.basename(file_name)
    image_html = '<img src="%s" width="200" height="200"></img><p>' % image_url

    # One line per predicted id; the human-readable name from top_names is
    # paired in the iteration but not shown.
    score_lines = [
        '%s (score:%.5f)<BR>' % (node_id, predictions[node_id])
        for node_id, _human_name in zip(top_k, top_names)
    ]
    return image_html + ''.join(score_lines) + '<BR>'
@app.route("/", methods=['GET', 'POST'])
def root():
    """Serve the upload form and, for a valid upload, the inference result.

    GET requests, and POST requests whose file is missing a name or has a
    disallowed extension, receive only the upload form. For an accepted
    upload the file is saved into UPLOAD_FOLDER under a name produced by
    rename_filename(), the saved path is printed, and the HTML returned by
    inference() is appended to the form page.

    NOTE(review): a POST without a 'file' part fails at request.files['file'];
    presumably Flask turns that into a 400 response - confirm.

    Returns:
        str: The HTML page to send to the client.
    """
    page = """
<!doctype html>
<title>CapsNet Demo</title>
<h1>Please upload an overlapping digit picture</h1>
<form action="" method=post enctype=multipart/form-data>
<p><input type=file name=file value='Select'>
<input type=submit value='Upload'>
</form>
<p>%s</p>
""" % "<br>"

    if request.method != 'POST':
        return page

    upload = request.files['file']
    original_name = upload.filename
    if not (upload and allowed_files(original_name)):
        return page

    # Store the upload under a generated name rather than the client's name.
    saved_path = os.path.join(UPLOAD_FOLDER, rename_filename(original_name))
    upload.save(saved_path)
    print('file saved to %s' % saved_path)
    return page + inference(saved_path)
if __name__ == "__main__":
    # Script entry point: announce the configured port, then start the Flask
    # server on all interfaces with threaded request handling.
    listen_port = FLAGS.port
    print('listening on port %d' % listen_port)
    app.run(host='0.0.0.0', port=listen_port, debug=FLAGS.debug,
            threaded=True)