-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a19c591
commit 5e6aa8a
Showing
9 changed files
with
294,670 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from transformers import AutoTokenizer, PreTrainedTokenizerFast | ||
from http.server import HTTPServer, BaseHTTPRequestHandler | ||
import json | ||
import argparse | ||
|
||
|
||
class TokenizerGLM3_Http():
    """Small wrapper around a Hugging Face tokenizer used by the HTTP service.

    Despite the class name, the default tokenizer directory is
    ``minicpm_tokenizer``.

    Attributes are exposed through read-only properties:
    ``bos_id`` / ``eos_id`` (token ids) and ``bos_token`` / ``eos_token``
    (token strings), each forwarded from the underlying tokenizer.
    """

    def __init__(self, path='minicpm_tokenizer'):
        """Load the tokenizer.

        Args:
            path: Directory or model identifier handed to
                ``AutoTokenizer.from_pretrained``. Defaults to the
                previously hard-coded ``'minicpm_tokenizer'``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def encode(self, prompt):
        """Render ``prompt`` as a single user turn and return its token ids.

        The prompt is wrapped in a one-message chat history, rendered to text
        with the tokenizer's chat template (``tokenize=False``,
        ``add_generation_prompt=False``), and the rendered text is then
        encoded with the tokenizer's default ``encode`` settings.

        Args:
            prompt: The user message content.

        Returns:
            Whatever ``self.tokenizer.encode`` returns for the rendered text.
        """
        history = [{"role": "user", "content": prompt}]
        history_str = self.tokenizer.apply_chat_template(
            history,
            tokenize=False,
            add_generation_prompt=False,
        )
        # The rendered prompt is echoed to stdout on every call.
        print(history_str)
        return self.tokenizer.encode(history_str)

    def decode(self, token_ids):
        """Turn ``token_ids`` back into text without tokenization-space cleanup."""
        return self.tokenizer.decode(
            token_ids, clean_up_tokenization_spaces=False)

    @property
    def bos_id(self):
        """Id of the beginning-of-sequence token, as reported by the tokenizer."""
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        """Id of the end-of-sequence token, as reported by the tokenizer."""
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        """Beginning-of-sequence token, as reported by the tokenizer."""
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        """End-of-sequence token, as reported by the tokenizer."""
        return self.tokenizer.eos_token
|
||
|
||
# Module-level tokenizer instance; the Request handler reads this global.
tokenizer = TokenizerGLM3_Http()

# At import time, print the special-token ids/strings and a sample encoding.
print(
    tokenizer.bos_id,
    tokenizer.bos_token,
    tokenizer.eos_id,
    tokenizer.eos_token,
)
print(tokenizer.encode("hello world"))
|
||
|
||
class Request(BaseHTTPRequestHandler):
    """HTTP front end for the module-level ``tokenizer``.

    Routes:
        GET  /bos_id  -> ``{"bos_id": <id>}``   (``-1`` when the id is None)
        GET  /eos_id  -> ``{"eos_id": <id>}``   (``-1`` when the id is None)
        POST /encode  -> request ``{"text": ...}``,
                         reply ``{"token_ids": ...}``
        POST /decode  -> request ``{"token_ids": ...}``,
                         reply ``{"text": ...}``

    Unknown paths are answered with 404 and malformed POST requests with 400;
    in both cases the body is the literal text ``error``. The status line is
    only written after the request has been validated, so a bad request no
    longer produces a ``200`` followed by a dropped connection.
    """

    timeout = 5
    server_version = 'Apache'

    @staticmethod
    def _get_payload(path, tok):
        """Return ``(status, body_text)`` for a GET on ``path`` using ``tok``."""
        if path == '/bos_id':
            key, value = 'bos_id', tok.bos_id
        elif path == '/eos_id':
            key, value = 'eos_id', tok.eos_id
        else:
            return 404, 'error'
        # A tokenizer without the special token reports None; clients get -1.
        return 200, json.dumps({key: -1 if value is None else value})

    @staticmethod
    def _post_payload(path, raw, tok):
        """Return ``(status, body_text)`` for a POST of ``raw`` bytes to ``path``."""
        if path == '/encode':
            field = 'text'
        elif path == '/decode':
            field = 'token_ids'
        else:
            return 404, 'error'

        # Only parsing and field lookup are guarded: bad bytes, bad JSON,
        # a non-object payload or a missing field are all client errors.
        try:
            argument = json.loads(raw.decode())[field]
        except (ValueError, KeyError, TypeError):
            return 400, 'error'

        if path == '/encode':
            token_ids = tok.encode(argument)
            result = {'token_ids': -1 if token_ids is None else token_ids}
        else:
            text = tok.decode(argument)
            result = {'text': "" if text is None else text}
        return 200, json.dumps(result)

    def _reply(self, status, msg, kind):
        """Log ``msg`` to stdout and send it as the complete HTTP response.

        ``kind`` is the value of the custom ``type`` response header
        (``"get"`` or ``"post"``).
        """
        print(msg)
        payload = str(msg).encode()  # response bodies must be bytes
        self.send_response(status)
        self.send_header("type", kind)
        self.send_header(
            "Content-Type",
            "application/json" if status == 200 else "text/plain")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def do_GET(self):
        """Handle GET requests for the tokenizer's special-token ids."""
        print(self.path)
        status, msg = self._get_payload(self.path, tokenizer)
        self._reply(status, msg, "get")

    def do_POST(self):
        """Handle POST requests that encode text or decode token ids."""
        # A missing Content-Length header is treated as an empty body.
        try:
            length = int(self.headers.get('content-length') or 0)
        except ValueError:
            length = -1
        if length < 0:
            self._reply(400, 'error', "post")
            return

        raw = self.rfile.read(length)  # request body as bytes
        status, msg = self._post_payload(self.path, raw, tokenizer)
        self._reply(status, msg, "post")
|
||
if __name__ == "__main__":
    # Command-line options: the host and port to bind.
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=8080)
    options = parser.parse_args()

    # (host, port) pair handed to the server; also echoed as a URL.
    address = (options.host, options.port)
    print('http://%s:%s' % address)

    # Build the server around the Request handler and serve until stopped.
    HTTPServer(address, Request).serve_forever()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
{ | ||
"bos_token": { | ||
"content": "<s>", | ||
"lstrip": false, | ||
"normalized": false, | ||
"rstrip": false, | ||
"single_word": false | ||
}, | ||
"eos_token": { | ||
"content": "</s>", | ||
"lstrip": false, | ||
"normalized": false, | ||
"rstrip": false, | ||
"single_word": false | ||
}, | ||
"unk_token": { | ||
"content": "<unk>", | ||
"lstrip": false, | ||
"normalized": false, | ||
"rstrip": false, | ||
"single_word": false | ||
} | ||
} |
Oops, something went wrong.