-
Notifications
You must be signed in to change notification settings - Fork 0
/
serve.py
29 lines (23 loc) · 1021 Bytes
/
serve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import pickle
import tensorflow as tf
from config import GENRES
# Keep TensorFlow's Python logger quiet: only ERROR-and-above records pass.
tf.get_logger().setLevel(level='ERROR')
class ModelServer:
    """Serve multi-label genre predictions from two TensorFlow SavedModels.

    An encoder SavedModel turns description text into features, and a trained
    SavedModel maps those features to per-label scores. A label is selected
    when its score is strictly greater than the decision threshold stored
    next to the trained model.
    """

    def __init__(self, ENCODER_PATH, TRAINED_MODEL_PATH):
        """Load both SavedModels and the label/threshold metadata.

        Args:
            ENCODER_PATH: Directory of the encoder SavedModel.
            TRAINED_MODEL_PATH: Directory of the trained SavedModel. It must
                also contain ``reverse_dict.pkl`` (indexed by output-column
                number; its values are used as keys into ``GENRES``) and
                ``hyperparams.pkl`` (a mapping with a ``'threshold'`` entry).
        """
        self.encoder = tf.saved_model.load(ENCODER_PATH)
        self.trained_model = tf.saved_model.load(TRAINED_MODEL_PATH)
        self.reverse_dict = self._load_pickle(
            os.path.join(TRAINED_MODEL_PATH, 'reverse_dict.pkl'))
        self.threshold = self._load_pickle(
            os.path.join(TRAINED_MODEL_PATH, 'hyperparams.pkl'))['threshold']

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at *path*, closing the file."""
        # SECURITY: unpickling can execute arbitrary code. Only point this
        # server at model directories you trust.
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    def predict(self, description):
        """Return the predicted genre labels for each input description.

        Args:
            description: Value passed to ``tf.constant`` and fed to the
                encoder. Presumably a list of strings, one per description.
                TODO confirm whether the encoder accepts a single string.

        Returns:
            A list with one entry per row of model output. Each entry lists
            the ``GENRES`` values whose score is strictly greater than
            ``self.threshold``, in output-column order.
        """
        features = self.encoder(tf.constant(description))
        scores = self.trained_model(features)
        # Materialise the boolean hit mask once. The alternative compares
        # every cell as an eager tensor inside a Python loop.
        hits = tf.greater(scores, self.threshold).numpy()
        return self._decode_hits(hits, self.reverse_dict, GENRES)

    @staticmethod
    def _decode_hits(hit_rows, reverse_dict, genres):
        """Translate each row of a boolean hit mask into its selected labels.

        Column ``i`` of a row maps to ``genres[reverse_dict[i]]``. Only truthy
        cells are kept, and column order is preserved.
        """
        return [
            [genres[reverse_dict[col]] for col, hit in enumerate(row) if hit]
            for row in hit_rows
        ]