forked from MahmoudWahdan/dialog-nlu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: bert_nlu_basic_api.py
113 lines (89 loc) · 3.62 KB
/
bert_nlu_basic_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 11 15:12:18 2020
@author: mwahdan
"""
from flask import Flask, jsonify, request
from vectorizers.bert_vectorizer import BERTVectorizer
from models.joint_bert import JointBertModel
from utils import convert_to_slots
from tensorflow.python.keras.backend import set_session
import tensorflow as tf
import pickle
import argparse
import os
# Create app
# Module-level Flask WSGI application; routes are registered on it below
# and it is served by app.run() in the __main__ block.
app = Flask(__name__)
def initialize():
    """Build the TF session and load all model artifacts into module globals.

    Populates the globals read by the request handlers: ``graph``, ``sess``,
    ``bert_vectorizer``, ``tags_vectorizer``, ``slots_num``,
    ``intents_label_encoder``, ``intents_num`` and ``model``.

    Relies on the module-level ``is_bert``, ``bert_model_hub_path`` and
    ``load_folder_path`` being set by the ``__main__`` block first.

    Raises:
        FileNotFoundError: if ``load_folder_path`` does not exist. The
            original code only printed a warning and then crashed later with
            an opaque pickle/open error; failing fast is clearer.
    """
    global graph
    # Use the compat.v1 API consistently with the v1 Session below;
    # bare tf.get_default_graph() does not exist on TF 2.x.
    graph = tf.compat.v1.get_default_graph()
    global sess
    sess = tf.compat.v1.Session()
    set_session(sess)
    global bert_vectorizer
    bert_vectorizer = BERTVectorizer(sess, is_bert, bert_model_hub_path)
    # loading models
    print('Loading models ...')
    if not os.path.exists(load_folder_path):
        # Fail fast instead of continuing into confusing downstream errors.
        raise FileNotFoundError('Folder `%s` does not exist' % load_folder_path)
    global slots_num
    global tags_vectorizer
    with open(os.path.join(load_folder_path, 'tags_vectorizer.pkl'), 'rb') as handle:
        tags_vectorizer = pickle.load(handle)
    slots_num = len(tags_vectorizer.label_encoder.classes_)
    global intents_num
    global intents_label_encoder
    with open(os.path.join(load_folder_path, 'intents_label_encoder.pkl'), 'rb') as handle:
        intents_label_encoder = pickle.load(handle)
    intents_num = len(intents_label_encoder.classes_)
    global model
    model = JointBertModel.load(load_folder_path, sess)
@app.route('/', methods=['GET', 'POST'])
def hello():
    """Trivial liveness endpoint: always answers with a static greeting."""
    greeting = 'hello from NLU service'
    return greeting
@app.route('/predict', methods=['GET', 'POST'])
def predict():
    """Run joint intent + slot prediction on the request's JSON payload.

    Expects a JSON body like ``{"utterance": "..."}`` and responds with::

        {"intent": {"name": ..., "confidence": ...}, "slots": [...]}

    where each slot carries its label, token start/end indices and the
    whitespace-joined surface value.

    Returns:
        200 with the prediction JSON on success; 400 with an error message
        when the body is missing or has no "utterance" key (previously this
        raised and surfaced as an HTTP 500).
    """
    global sess
    global graph
    with graph.as_default():
        set_session(sess)
        # silent=True yields None (instead of raising) for absent/bad JSON,
        # so we can respond with a proper 400.
        input_json = request.get_json(silent=True)
        if not input_json or "utterance" not in input_json:
            return jsonify({"error": 'JSON body with an "utterance" field is required'}), 400
        utterance = input_json["utterance"]
        tokens = utterance.split()
        print(utterance)
        input_ids, input_mask, segment_ids, valid_positions, data_sequence_lengths = bert_vectorizer.transform([utterance])
        predicted_tags, predicted_intents = model.predict_slots_intent(
            [input_ids, input_mask, segment_ids, valid_positions],
            tags_vectorizer, intents_label_encoder, remove_start_end=True,
            include_intent_prob=True)
        slots = convert_to_slots(predicted_tags[0])
        slots = [{"slot": slot, "start": start, "end": end, "value": ' '.join(tokens[start:end + 1])} for slot, start, end in slots]
        response = {
            "intent": {
                "name": predicted_intents[0][0].strip(),
                # cast to float: the confidence presumably comes back as a
                # numpy scalar, which jsonify cannot serialize — TODO confirm
                "confidence": float(predicted_intents[0][1])
            },
            "slots": slots
        }
        return jsonify(response)
if __name__ == '__main__':
    VALID_TYPES = ['bert', 'albert']
    # read command-line parameters
    parser = argparse.ArgumentParser('Running Joint BERT / ALBERT NLU model basic service')
    parser.add_argument('--model', '-m', help='Path to joint BERT / ALBERT NLU model', type=str, required=True)
    parser.add_argument('--type', '-tp', help='bert or albert', type=str, default='bert', required=False)
    args = parser.parse_args()
    load_folder_path = args.model
    type_ = args.type

    # Map each supported model type to its (TF-Hub module URL, is_bert flag).
    hub_config = {
        'bert': ('https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1', True),
        'albert': ('https://tfhub.dev/google/albert_base/1', False),
    }
    if type_ not in hub_config:
        raise ValueError('type must be one of these values: %s' % str(VALID_TYPES))
    bert_model_hub_path, is_bert = hub_config[type_]

    print('Starting the Server')
    initialize()
    # Run app
    app.run(host='0.0.0.0', port=5000, debug=True, use_reloader=False)