-
Notifications
You must be signed in to change notification settings - Fork 8
/
code_29_serving.py
156 lines (107 loc) · 4.7 KB
/
code_29_serving.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-
"""
Created on Sun Dec 6 07:13:17 2020
@author: ljh
"""
from abc import ABC
import json
import logging
import os
import traceback
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer,AutoConfig
from ts.torch_handler.base_handler import BaseHandler
logger = logging.getLogger(__name__)
class TransformersClassifierHandler(BaseHandler, ABC):
    """
    TorchServe handler for a Transformers sequence-classification model.

    Takes a text (string) request payload and returns the predicted class,
    mapped to a human-readable label via index_to_name.json when available.
    """
    def __init__(self):
        super(TransformersClassifierHandler, self).__init__()
        self.initialized = False
        # Always define the mapping attribute so inference() can test it
        # safely even when index_to_name.json is absent (the original code
        # raised AttributeError in that case).
        self.mapping = None
    def initialize(self, ctx):
        """Load the model, tokenizer and optional label mapping.

        Args:
            ctx: TorchServe context carrying the manifest and system
                properties (including ``model_dir``).
        """
        self.manifest = ctx.manifest
        logger.info("initialize: manifest=%s", self.manifest)
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        # Read model serialize/pt file
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        logger.debug('Transformer initialize tokenizer: {0}'.format(model_dir))
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # self.model.to(self.device)
        self.model.eval()
        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))
        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path) as f:
                self.mapping = json.load(f)
        else:
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')
        self.initialized = True
        logger.info("initialize finished OK")
    def preprocess(self, data):
        """Tokenize the first request's text payload.

        Accepts the payload under either the "data" or "body" key.
        Handles both bytes and already-decoded str payloads (the original
        assumed bytes and crashed with AttributeError on str input).
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        # TorchServe may deliver the body as raw bytes or as str depending
        # on the request content type.
        if isinstance(text, (bytes, bytearray)):
            sentences = text.decode('utf-8')
        else:
            sentences = text
        logger.info("Received text: '%s'", sentences)
        inputs = self.tokenizer.encode_plus(
            sentences,
            add_special_tokens=True,
            return_tensors="pt"
        )
        return inputs
    def inference(self, inputs):
        """
        Predict the class of a text using a trained transformer model.
        """
        # NOTE: This makes the assumption that your model expects text to be tokenized
        # with "input_ids" and "attention_mask" - which is true for some popular
        # transformer models, e.g. bert. If your transformer model expects different
        # tokenization, adapt this code to suit its expected input format.
        # Inference only: disable autograd bookkeeping to save memory/time.
        with torch.no_grad():
            prediction = self.model(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask']
            )[0].argmax().item()
        logger.info("Model predicted: '%s'", prediction)
        if self.mapping:
            # JSON object keys are strings, so look up by str(index).
            prediction = self.mapping[str(prediction)]
        return [prediction]
    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output
_service = TransformersClassifierHandler()
def handle(data, context):
    """TorchServe entry point: lazily initialize, then run the pipeline.

    Args:
        data: List of request dicts, or None (e.g. during worker warm-up).
        context: TorchServe context object passed to initialize().

    Returns:
        The postprocessed prediction list, or None when data is None.

    Raises:
        Exception: any error from the pipeline is logged and re-raised so
            TorchServe reports the failure to the client.
    """
    try:
        if not _service.initialized:
            _service.initialize(context)
        if data is None:
            return None
        data = _service.preprocess(data)
        data = _service.inference(data)
        data = _service.postprocess(data)
        return data
    except Exception:
        # Print the full traceback for the worker log, then re-raise.
        # Bare `raise` preserves the original traceback; `raise e` did not
        # add anything and is less idiomatic.
        traceback.print_exc()
        raise
if __name__ == "__main__":
    # Export the pretrained model and tokenizer into model_store/ so that
    # torch-model-archiver can package them for TorchServe.
    # The duplicate `from transformers import ...` was removed: AutoConfig,
    # AutoModelForSequenceClassification and AutoTokenizer are already
    # imported at module level.
    MODEL_PATH = r'./distilbert-base-uncased/'
    config = AutoConfig.from_pretrained(MODEL_PATH)
    modelbert = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH, config=config)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    NEW_DIR = 'model_store'
    # save_pretrained creates NEW_DIR if it does not already exist.
    modelbert.save_pretrained(NEW_DIR)
    tokenizer.save_pretrained(NEW_DIR)